[mlir][sparse] implement sparse_tensor.lvl operation. (#69993)

This commit is contained in:
Peiming Liu
2023-10-24 13:23:28 -07:00
committed by GitHub
parent 260dbb45ac
commit c780352de9
15 changed files with 132 additions and 106 deletions

View File

@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
return forOp;
}
/// Returns the size of the given original dimension `dim` for the sparse
/// tensor described by `desc`: a constant index value when the dimension
/// is statically known, otherwise a runtime lookup into the stored
/// level-size array.
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
                                 SparseTensorDescriptor desc, Dimension dim) {
  const SparseTensorType stt(desc.getRankedTensorType());
  // A static dimension can be answered directly from the tensor type.
  // Note that DimOp folding typically resolves this case already.
  if (const auto staticSz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *staticSz);
  // Otherwise consult the dynamically stored level sizes, translating the
  // dimension through the reordering applied by the sparse storage scheme.
  // FIXME: `toStoredDim` is deprecated.
  return desc.getLvlSize(builder, loc, toStoredDim(stt, dim));
}
// Gets the dimension size at the given stored level 'lvl', either as a
// constant for a static size, or otherwise dynamically through memSizes.
static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
                                 SparseTensorDescriptor desc, Level lvl) {
  // Map the stored level back to its original dimension and delegate to
  // the dimension-based query.
  // FIXME: `toOrigDim` is deprecated.
  const Dimension dim = toOrigDim(desc.getRankedTensorType(), lvl);
  return sizeFromTensorAtDim(builder, loc, desc, dim);
}
static void createPushback(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc,
SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// at this level. We will eventually reach a compressed level or
// otherwise the values array for the from-here "all-dense" case.
assert(isDenseDLT(dlt));
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
Value size = desc.getLvlSize(builder, loc, l);
linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
@@ -448,7 +422,7 @@ public:
// Construct the new position as:
// positions[l] = size * positions[l-1] + coords[l]
// <insert @ positions[l] at next level l + 1>
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
Value size = desc.getLvlSize(builder, loc, l);
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
}
@@ -658,19 +632,19 @@ public:
}
};
/// Sparse codegen rule for dimension accesses.
class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
/// Sparse codegen rule for level accesses.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> dim = op.getConstantIndex();
if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
std::optional<int64_t> lvl = op.getConstantLvlIndex();
if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
rewriter.replaceOp(op, sz);
return success();
@@ -922,12 +896,10 @@ public:
Type idxType = rewriter.getIndexType();
// All initialization should be done on entry of the loop nest.
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
// Determine the size for access expansion (always the innermost stored
// level size, translated back to original dimension). Note that we
// recursively rewrite the new DimOp on the **original** tensor.
// FIXME: `toOrigDim` is deprecated.
const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
// level size).
const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
// Generate a memref for `sz` elements of type `t`.
const auto genAlloc = [&](Type t) {
const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) {
patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,