[mlir][sparse] introduce MapRef, unify conversion/codegen for reader (#68360)

This revision introduces MapRef, which will support a future
generalization beyond permutations (e.g., block sparsity). It also
unifies the conversion and codegen paths for the sparse_tensor.new
operation from file (i.e., the readers). Note that more unification
is planned, as well as support for general affine dim2lvl and
lvl2dim maps (all marked with TODOs).
Aart Bik, 2023-10-06 13:42:01 -07:00 (committed by GitHub)
parent f045f2c26d, commit d3af65358d
14 changed files with 438 additions and 484 deletions
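
For orientation before the diff: dim2lvl maps a tensor's dimension
coordinates to its storage level coordinates, and lvl2dim is the inverse
map. A MapRef bundles the two directions so coordinates can be translated
either way. The following is a minimal sketch of the permutation-only case;
the class and member names here are illustrative, not the actual interface
added by this commit.

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch only: a permutation map between dimension coordinates (dim)
// and storage level coordinates (lvl). The real MapRef is designed to
// generalize beyond permutations (e.g., block sparsity).
class PermutationMapSketch {
public:
  PermutationMapSketch(std::vector<uint64_t> d2l, std::vector<uint64_t> l2d)
      : dim2lvl(std::move(d2l)), lvl2dim(std::move(l2d)) {
    assert(dim2lvl.size() == lvl2dim.size() && "permutation case only");
  }
  // dim -> lvl: dimension d contributes its coordinate to level dim2lvl[d].
  void pushforward(const uint64_t *dimCrds, uint64_t *lvlCrds) const {
    for (uint64_t d = 0, rank = dim2lvl.size(); d < rank; d++)
      lvlCrds[dim2lvl[d]] = dimCrds[d];
  }
  // lvl -> dim: level l restores the coordinate of dimension lvl2dim[l].
  void pushbackward(const uint64_t *lvlCrds, uint64_t *dimCrds) const {
    for (uint64_t l = 0, rank = lvl2dim.size(); l < rank; l++)
      dimCrds[lvl2dim[l]] = lvlCrds[l];
  }

private:
  const std::vector<uint64_t> dim2lvl; // lvl position of each dim
  const std::vector<uint64_t> lvl2dim; // dim position of each lvl
};

For example, with dim2lvl = {1, 0} (a transposition) and lvl2dim as its
inverse, pushforward turns {i, j} into {j, i}, and pushbackward undoes it.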

@@ -1428,7 +1428,7 @@ struct SparseDisassembleOpConverter
   }
 };
-struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
+struct SparseNewConverter : public OpConversionPattern<NewOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -1440,7 +1440,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
       return failure();
-    // Implement the NewOp(filename) as follows:
+    // Implement as follows:
     //   %reader = @createCheckedSparseTensorReader(%filename)
     //   %nse = @getSparseTensorNSE(%reader)
     //   %coo = bufferization.alloc_tensor an ordered COO with
@@ -1451,74 +1451,39 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     //       if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
     //       update storage specifier
     //   @delSparseTensorReader(%reader)
+    SmallVector<Value> dimShapesValues;
+    Value dimSizesBuffer;
+    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
+                             dimShapesValues, dimSizesBuffer);
-    // Allocate `SparseTensorReader` and perform all initial setup that
-    // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
-    const Type opaqueTp = getOpaquePointerType(rewriter);
-    const Value fileName = op.getSource();
-    SmallVector<Value> dimShapeValues;
-    for (const DynSize sh : dstTp.getDimShape()) {
-      const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
-      dimShapeValues.push_back(constantIndex(rewriter, loc, s));
-    }
-    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
-    Value valTp =
-        constantPrimaryTypeEncoding(rewriter, loc, dstTp.getElementType());
-    Value reader =
-        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
-                       opaqueTp, {fileName, dimShapeBuffer, valTp},
-                       EmitCInterface::On)
-            .getResult(0);
+    // Get the number of stored entries.
     const Type indexTp = rewriter.getIndexType();
-    const Dimension dimRank = dstTp.getDimRank();
-    const Level lvlRank = dstTp.getLvlRank();
+    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
+                               {indexTp}, {reader}, EmitCInterface::Off)
+                    .getResult(0);
-    // If the result tensor has dynamic dimensions, get the dynamic sizes from
-    // the sparse tensor reader.
+    // Construct allocation for each field.
     SmallVector<Value> dynSizes;
     if (dstTp.hasDynamicDimShape()) {
-      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
-      Value dimSizesBuffer =
-          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
-                         reader, EmitCInterface::On)
-              .getResult(0);
       for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
         if (ShapedType::isDynamic(d.value()))
           dynSizes.push_back(rewriter.create<memref::LoadOp>(
               loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
     }
-    // Get the number of stored entries.
-    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
-                               {indexTp}, {reader}, EmitCInterface::Off)
-                    .getResult(0);
-    // Construct allocation for each field.
     SmallVector<Value> fields;
     createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                       fields, nse);
     MutSparseTensorDescriptor desc(dstTp, fields);
-    // Construct the `dimToLvl` buffer for handing off to the runtime library.
-    SmallVector<Value> dimToLvlValues(dimRank);
-    if (!dstTp.isIdentity()) {
-      const auto dimToLvl = dstTp.getDimToLvl();
-      assert(dimToLvl.isPermutation() && "Got non-permutation");
-      for (Level l = 0; l < lvlRank; l++) {
-        const Dimension d = dimToLvl.getDimPosition(l);
-        dimToLvlValues[d] = constantIndex(rewriter, loc, l);
-      }
-    } else {
-      // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
-      // when `isIdentity`; so no need to re-assert it here.
-      for (Dimension d = 0; d < dimRank; d++)
-        dimToLvlValues[d] = constantIndex(rewriter, loc, d);
-    }
-    Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
+    // Now construct the dim2lvl and lvl2dim buffers.
+    Value dim2lvlBuffer;
+    Value lvl2dimBuffer;
+    genReaderBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
+                     dim2lvlBuffer, lvl2dimBuffer);
     // Read the COO tensor data.
     Value xs = desc.getAOSMemRef();
     Value ys = desc.getValMemRef();
     const Type boolTp = rewriter.getIntegerType(1);
     const Type elemTp = dstTp.getElementType();
     const Type crdTp = dstTp.getCrdType();
@@ -1527,11 +1492,13 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
                                           primaryTypeFunctionSuffix(elemTp)};
     Value isSorted =
         createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
-                       {reader, dimToLvl, xs, ys}, EmitCInterface::On)
+                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
+                       EmitCInterface::On)
             .getResult(0);
     // If the destination tensor is a sorted COO, we need to sort the COO tensor
     // data if the input elements aren't sorted yet.
+    const Level lvlRank = dstTp.getLvlRank();
     if (dstTp.isOrderedLvl(lvlRank - 1)) {
       Value kFalse = constantI1(rewriter, loc, false);
       Value notSorted = rewriter.create<arith::CmpIOp>(
@@ -1593,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                   StorageSpecifierKind::DimStride>,
               SparseToPositionsConverter, SparseToCoordinatesConverter,
               SparseToCoordinatesBufferConverter, SparseToValuesConverter,
-              SparseConvertConverter, SparseNewOpConverter,
+              SparseConvertConverter, SparseNewConverter,
               SparseNumberOfEntriesConverter>(typeConverter,
                                               patterns.getContext());
   patterns.add<SparseTensorDeallocConverter>(
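
The bodies of the new genReader and genReaderBuffers helpers fall outside
the hunks shown above. As a reading aid, here is a plausible reconstruction
of genReader assembled from the inline code deleted in this diff; the actual
helper in the commit may differ in details such as parameter order.

// Reconstruction from the removed lines above, not the verbatim helper:
// creates the SparseTensorReader and performs all initial setup that does
// not depend on lvlSizes (nor dimToLvl, lvlToDim, etc.).
static Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
                       Value fileName, SmallVectorImpl<Value> &dimShapesValues,
                       Value &dimSizesBuffer) {
  // Pass the dimension shape (with 0 for each dynamic size) to the reader.
  for (const DynSize sh : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
    dimShapesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues);
  Value valTp =
      constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader",
                     getOpaquePointerType(builder),
                     {fileName, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // With dynamic dimensions, also fetch the actual sizes up front, so the
  // call site can read from dimSizesBuffer where the old inline code used
  // to re-query the reader.
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
  }
  return reader;
}

This packaging is what lets the conversion and codegen paths share one
reader setup, with genReaderBuffers playing the analogous role for the
dim2lvl and lvl2dim buffers handed to the runtime.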