[mlir][sparse] introduce MapRef, unify conversion/codegen for reader (#68360)
This revision introduces MapRef, which will support a future generalization beyond permutations (e.g., block sparsity). This revision also unifies the conversion/codegen paths for the sparse_tensor.new operation from file (i.e., the readers). Note that more unification is planned, as well as general affine dim2lvl and lvl2dim mappings (all marked with TODOs).
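To make the dim2lvl/lvl2dim idea concrete, here is a minimal standalone sketch of the kind of permutation-based map that MapRef generalizes. All class and method names below are illustrative assumptions for this sketch, not the actual API introduced by this change:

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch only: a permutation-based map between dimension
// coordinates (the user-visible dim space) and level coordinates (the
// stored lvl space). The real MapRef is more general than this.
class PermutationMapRef {
public:
  // dim2lvl[d] = the level that dimension d maps to.
  explicit PermutationMapRef(std::vector<uint64_t> dim2lvl)
      : dim2lvl_(std::move(dim2lvl)), lvl2dim_(dim2lvl_.size()) {
    // Invert the permutation to obtain the lvl2dim direction.
    for (uint64_t d = 0; d < dim2lvl_.size(); ++d) {
      assert(dim2lvl_[d] < dim2lvl_.size() && "not a permutation");
      lvl2dim_[dim2lvl_[d]] = d;
    }
  }

  // Forward map: dim coordinates -> lvl coordinates.
  void pushForward(const uint64_t *dimCrds, uint64_t *lvlCrds) const {
    for (uint64_t d = 0; d < dim2lvl_.size(); ++d)
      lvlCrds[dim2lvl_[d]] = dimCrds[d];
  }

  // Backward map: lvl coordinates -> dim coordinates.
  void pushBackward(const uint64_t *lvlCrds, uint64_t *dimCrds) const {
    for (uint64_t l = 0; l < lvl2dim_.size(); ++l)
      dimCrds[lvl2dim_[l]] = lvlCrds[l];
  }

private:
  std::vector<uint64_t> dim2lvl_;
  std::vector<uint64_t> lvl2dim_;
};

A block-sparse generalization would replace the plain permutation loops above with more general (affine) coordinate translations, which is the future direction mentioned in the commit message.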
@@ -1428,7 +1428,7 @@ struct SparseDisassembleOpConverter
   }
 };
 
-struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
+struct SparseNewConverter : public OpConversionPattern<NewOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -1440,7 +1440,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
       return failure();
 
-    // Implement the NewOp(filename) as follows:
+    // Implement as follows:
     //   %reader = @createCheckedSparseTensorReader(%filename)
     //   %nse = @getSparseTensorNSE(%reader)
     //   %coo = bufferization.alloc_tensor an ordered COO with
@@ -1451,74 +1451,39 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
     //   update storage specifier
     //   @delSparseTensorReader(%reader)
+    SmallVector<Value> dimShapesValues;
+    Value dimSizesBuffer;
+    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
+                             dimShapesValues, dimSizesBuffer);
 
-    // Allocate `SparseTensorReader` and perform all initial setup that
-    // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
-    const Type opaqueTp = getOpaquePointerType(rewriter);
-    const Value fileName = op.getSource();
-    SmallVector<Value> dimShapeValues;
-    for (const DynSize sh : dstTp.getDimShape()) {
-      const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
-      dimShapeValues.push_back(constantIndex(rewriter, loc, s));
-    }
-    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
-    Value valTp =
-        constantPrimaryTypeEncoding(rewriter, loc, dstTp.getElementType());
-    Value reader =
-        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
-                       opaqueTp, {fileName, dimShapeBuffer, valTp},
-                       EmitCInterface::On)
-            .getResult(0);
+    // Get the number of stored entries.
     const Type indexTp = rewriter.getIndexType();
-    const Dimension dimRank = dstTp.getDimRank();
-    const Level lvlRank = dstTp.getLvlRank();
+    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
+                               {indexTp}, {reader}, EmitCInterface::Off)
+                    .getResult(0);
 
-    // If the result tensor has dynamic dimensions, get the dynamic sizes from
-    // the sparse tensor reader.
+    // Construct allocation for each field.
     SmallVector<Value> dynSizes;
     if (dstTp.hasDynamicDimShape()) {
-      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
-      Value dimSizesBuffer =
-          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
-                         reader, EmitCInterface::On)
-              .getResult(0);
       for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
         if (ShapedType::isDynamic(d.value()))
          dynSizes.push_back(rewriter.create<memref::LoadOp>(
              loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
     }
-
-    // Get the number of stored entries.
-    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
-                               {indexTp}, {reader}, EmitCInterface::Off)
-                    .getResult(0);
-    // Construct allocation for each field.
     SmallVector<Value> fields;
     createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                       fields, nse);
     MutSparseTensorDescriptor desc(dstTp, fields);
 
-    // Construct the `dimToLvl` buffer for handing off to the runtime library.
-    SmallVector<Value> dimToLvlValues(dimRank);
-    if (!dstTp.isIdentity()) {
-      const auto dimToLvl = dstTp.getDimToLvl();
-      assert(dimToLvl.isPermutation() && "Got non-permutation");
-      for (Level l = 0; l < lvlRank; l++) {
-        const Dimension d = dimToLvl.getDimPosition(l);
-        dimToLvlValues[d] = constantIndex(rewriter, loc, l);
-      }
-    } else {
-      // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
-      // when `isIdentity`; so no need to re-assert it here.
-      for (Dimension d = 0; d < dimRank; d++)
-        dimToLvlValues[d] = constantIndex(rewriter, loc, d);
-    }
-    Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
+    // Now construct the dim2lvl and lvl2dim buffers.
+    Value dim2lvlBuffer;
+    Value lvl2dimBuffer;
+    genReaderBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
+                     dim2lvlBuffer, lvl2dimBuffer);
 
     // Read the COO tensor data.
     Value xs = desc.getAOSMemRef();
     Value ys = desc.getValMemRef();
 
     const Type boolTp = rewriter.getIntegerType(1);
     const Type elemTp = dstTp.getElementType();
     const Type crdTp = dstTp.getCrdType();
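Note how the deleted block built only a dimToLvl buffer by inverting the permutation inline, while the unified path obtains both a dim2lvl and a lvl2dim buffer from genReaderBuffers. As a rough sketch of the underlying computation for the permutation case (a hypothetical standalone helper, not the actual genReaderBuffers implementation):

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: given the permutation queried per level (the
// getDimPosition(l) values used in the deleted code), produce the two
// index buffers handed off to the runtime reader.
static std::pair<std::vector<uint64_t>, std::vector<uint64_t>>
buildMapBuffers(const std::vector<uint64_t> &dimPosAtLvl) {
  const uint64_t rank = dimPosAtLvl.size();
  std::vector<uint64_t> dim2lvl(rank), lvl2dim(rank);
  for (uint64_t l = 0; l < rank; ++l) {
    const uint64_t d = dimPosAtLvl[l]; // dimension stored at level l
    dim2lvl[d] = l;                    // mirrors dimToLvlValues[d] = l above
    lvl2dim[l] = d;                    // the inverse direction, now also passed
  }
  return {dim2lvl, lvl2dim};
}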
@@ -1527,11 +1492,13 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
                                  primaryTypeFunctionSuffix(elemTp)};
     Value isSorted =
         createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
-                       {reader, dimToLvl, xs, ys}, EmitCInterface::On)
+                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
+                       EmitCInterface::On)
             .getResult(0);
 
     // If the destination tensor is a sorted COO, we need to sort the COO tensor
     // data if the input elements aren't sorted yet.
+    const Level lvlRank = dstTp.getLvlRank();
     if (dstTp.isOrderedLvl(lvlRank - 1)) {
       Value kFalse = constantI1(rewriter, loc, false);
       Value notSorted = rewriter.create<arith::CmpIOp>(
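The guard emitted here mirrors a simple runtime idea: sort the COO data lexicographically by level coordinates only when the reader reports the input as unsorted. A plain C++ sketch of that behavior; the AOS coordinate layout is an assumption stated in the comments:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative sketch: sort nse COO entries lexicographically by their
// level coordinates if the reader reported them unsorted. Assumes an
// AOS layout where entry i owns coordinates
// xs[i * lvlRank .. (i + 1) * lvlRank - 1] and value ys[i].
void sortCooIfNeeded(bool isSorted, uint64_t nse, uint64_t lvlRank,
                     std::vector<uint64_t> &xs, std::vector<double> &ys) {
  if (isSorted)
    return; // mirrors the `if (! %isSorted)` guard in the generated code
  std::vector<uint64_t> perm(nse);
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(), [&](uint64_t a, uint64_t b) {
    return std::lexicographical_compare(
        xs.begin() + a * lvlRank, xs.begin() + (a + 1) * lvlRank,
        xs.begin() + b * lvlRank, xs.begin() + (b + 1) * lvlRank);
  });
  // Apply the permutation out-of-place for simplicity.
  std::vector<uint64_t> newXs(xs.size());
  std::vector<double> newYs(ys.size());
  for (uint64_t i = 0; i < nse; ++i) {
    std::copy(xs.begin() + perm[i] * lvlRank,
              xs.begin() + (perm[i] + 1) * lvlRank,
              newXs.begin() + i * lvlRank);
    newYs[i] = ys[perm[i]];
  }
  xs = std::move(newXs);
  ys = std::move(newYs);
}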
@@ -1593,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                                StorageSpecifierKind::DimStride>,
               SparseToPositionsConverter, SparseToCoordinatesConverter,
               SparseToCoordinatesBufferConverter, SparseToValuesConverter,
-              SparseConvertConverter, SparseNewOpConverter,
+              SparseConvertConverter, SparseNewConverter,
               SparseNumberOfEntriesConverter>(typeConverter,
                                               patterns.getContext());
     patterns.add<SparseTensorDeallocConverter>(