[mlir][sparse] provide an AoS "view" into sparse runtime support lib (#87116)

Note that even though the sparse runtime support lib always uses SoA
storage for COO storage (and provides correct codegen by means of views
into this storage), in some rare cases we need the true physical SoA
storage as a coordinate buffer. This PR provides that functionality by
means of a (costly) coordinate buffer call.

Since this is currently only used for testing/debugging by means of the
sparse_tensor.print method, this solution is acceptable. If we ever want
a performing version of this, we should truly support AoS storage of COO
in addition to the SoA used right now.
This commit is contained in:
Aart Bik
2024-03-29 15:30:36 -07:00
committed by GitHub
parent 038e66fe59
commit dc4cfdbb8f
7 changed files with 152 additions and 23 deletions

View File

@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
.getResult(0);
}
/// Generates a call to obtain the coordindates array.
/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr, Level l) {
Type crdTp = stt.getCrdType();
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
.getResult(0);
}
/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr,
Level l) {
Type crdTp = stt.getCrdType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
Value lvl = constantIndex(builder, loc, l);
SmallString<25> name{"sparseCoordinatesBuffer",
overheadTypeFunctionSuffix(crdTp)};
return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
EmitCInterface::On)
.getResult(0);
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -518,13 +532,35 @@ public:
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
adaptor.getTensor(), op.getLevel());
auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
};
/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesBufferCall(
rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
SparseHasRuntimeLibraryConverter>(typeConverter,
patterns.getContext());
SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
SparseTensorInsertConverter, SparseTensorExpandConverter,
SparseTensorCompressConverter, SparseTensorAssembleConverter,
SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
typeConverter, patterns.getContext());
}