[mlir][sparse] implement sparse_tensor.lvl operation. (#69993)
This commit is contained in:
@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
|
||||
return forOp;
|
||||
}
|
||||
|
||||
/// Gets the dimension size for the given sparse tensor at the given
|
||||
/// original dimension 'dim'.
|
||||
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
|
||||
SparseTensorDescriptor desc, Dimension dim) {
|
||||
const SparseTensorType stt(desc.getRankedTensorType());
|
||||
// Access into static dimension can query original type directly.
|
||||
// Note that this is typically already done by DimOp's folding.
|
||||
if (auto sz = stt.getStaticDimSize(dim))
|
||||
return constantIndex(builder, loc, *sz);
|
||||
|
||||
// Any other query can consult the dimSizes array at field DimSizesIdx,
|
||||
// accounting for the reordering applied to the sparse storage.
|
||||
// FIXME: `toStoredDim` is deprecated.
|
||||
const Level lvl = toStoredDim(stt, dim);
|
||||
return desc.getLvlSize(builder, loc, lvl);
|
||||
}
|
||||
|
||||
// Gets the dimension size at the given stored level 'lvl', either as a
|
||||
// constant for a static size, or otherwise dynamically through memSizes.
|
||||
static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
|
||||
SparseTensorDescriptor desc, Level lvl) {
|
||||
// FIXME: `toOrigDim` is deprecated.
|
||||
return sizeFromTensorAtDim(builder, loc, desc,
|
||||
toOrigDim(desc.getRankedTensorType(), lvl));
|
||||
}
|
||||
|
||||
static void createPushback(OpBuilder &builder, Location loc,
|
||||
MutSparseTensorDescriptor desc,
|
||||
SparseTensorFieldKind kind, std::optional<Level> lvl,
|
||||
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
|
||||
// at this level. We will eventually reach a compressed level or
|
||||
// otherwise the values array for the from-here "all-dense" case.
|
||||
assert(isDenseDLT(dlt));
|
||||
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
|
||||
Value size = desc.getLvlSize(builder, loc, l);
|
||||
linear = builder.create<arith::MulIOp>(loc, linear, size);
|
||||
}
|
||||
// Reached values array so prepare for an insertion.
|
||||
@@ -448,7 +422,7 @@ public:
|
||||
// Construct the new position as:
|
||||
// positions[l] = size * positions[l-1] + coords[l]
|
||||
// <insert @ positions[l] at next level l + 1>
|
||||
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
|
||||
Value size = desc.getLvlSize(builder, loc, l);
|
||||
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
|
||||
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
|
||||
}
|
||||
@@ -658,19 +632,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for dimension accesses.
|
||||
class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
|
||||
/// Sparse codegen rule for level accesses.
|
||||
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
std::optional<int64_t> dim = op.getConstantIndex();
|
||||
if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
|
||||
std::optional<int64_t> lvl = op.getConstantLvlIndex();
|
||||
if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
|
||||
return failure();
|
||||
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
|
||||
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
|
||||
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
|
||||
|
||||
rewriter.replaceOp(op, sz);
|
||||
return success();
|
||||
@@ -922,12 +896,10 @@ public:
|
||||
Type idxType = rewriter.getIndexType();
|
||||
// All initialization should be done on entry of the loop nest.
|
||||
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
|
||||
|
||||
// Determine the size for access expansion (always the innermost stored
|
||||
// level size, translated back to original dimension). Note that we
|
||||
// recursively rewrite the new DimOp on the **original** tensor.
|
||||
// FIXME: `toOrigDim` is deprecated.
|
||||
const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
|
||||
const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
|
||||
// level size).
|
||||
const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
|
||||
// Generate a memref for `sz` elements of type `t`.
|
||||
const auto genAlloc = [&](Type t) {
|
||||
const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
|
||||
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
bool createSparseDeallocs, bool enableBufferInitialization) {
|
||||
patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
|
||||
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
|
||||
SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
|
||||
SparseCastConverter, SparseExtractSliceConverter,
|
||||
SparseTensorLoadConverter, SparseExpandConverter,
|
||||
SparseCompressConverter, SparseInsertConverter,
|
||||
|
||||
Reference in New Issue
Block a user