[mlir][sparse] provide an AoS "view" into sparse runtime support lib (#87116)
Note that even though the sparse runtime support lib always uses SoA storage for COO storage (and provides correct codegen by means of views into this storage), in some rare cases we need the true physical SoA storage as a coordinate buffer. This PR provides that functionality by means of a (costly) coordinate buffer call. Since this is currently only used for testing/debugging by means of the sparse_tensor.print method, this solution is acceptable. If we ever want a performing version of this, we should truly support AoS storage of COO in addition to the SoA used right now.
This commit is contained in:
@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call to obtain the coordindates array.
|
||||
/// Generates a call to obtain the coordinates array.
|
||||
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorType stt, Value ptr, Level l) {
|
||||
Type crdTp = stt.getCrdType();
|
||||
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call to obtain the coordinates array (AoS view).
|
||||
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorType stt, Value ptr,
|
||||
Level l) {
|
||||
Type crdTp = stt.getCrdType();
|
||||
auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
|
||||
Value lvl = constantIndex(builder, loc, l);
|
||||
SmallString<25> name{"sparseCoordinatesBuffer",
|
||||
overheadTypeFunctionSuffix(crdTp)};
|
||||
return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
|
||||
EmitCInterface::On)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -518,13 +532,35 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const Location loc = op.getLoc();
|
||||
auto stt = getSparseTensorType(op.getTensor());
|
||||
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
|
||||
adaptor.getTensor(), op.getLevel());
|
||||
auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
|
||||
op.getLevel());
|
||||
// Cast the MemRef type to the type expected by the users, though these
|
||||
// two types should be compatible at runtime.
|
||||
if (op.getType() != crds.getType())
|
||||
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
|
||||
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
|
||||
rewriter.replaceOp(op, crds);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for coordinate accesses (AoS style).
|
||||
class SparseToCoordinatesBufferConverter
|
||||
: public OpConversionPattern<ToCoordinatesBufferOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const Location loc = op.getLoc();
|
||||
auto stt = getSparseTensorType(op.getTensor());
|
||||
auto crds = genCoordinatesBufferCall(
|
||||
rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
|
||||
// Cast the MemRef type to the type expected by the users, though these
|
||||
// two types should be compatible at runtime.
|
||||
if (op.getType() != crds.getType())
|
||||
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
|
||||
rewriter.replaceOp(op, crds);
|
||||
return success();
|
||||
}
|
||||
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
|
||||
SparseTensorAllocConverter, SparseTensorEmptyConverter,
|
||||
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
|
||||
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
|
||||
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
|
||||
SparseTensorLoadConverter, SparseTensorInsertConverter,
|
||||
SparseTensorExpandConverter, SparseTensorCompressConverter,
|
||||
SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
|
||||
SparseHasRuntimeLibraryConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
|
||||
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
|
||||
SparseTensorInsertConverter, SparseTensorExpandConverter,
|
||||
SparseTensorCompressConverter, SparseTensorAssembleConverter,
|
||||
SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user