[mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)

Use SparseTensorDescriptor whenever no setters are called, to avoid creating a temporary buffer for simple query purposes.

Reviewed By: bixia, wrengr

Differential Revision: https://reviews.llvm.org/D141953
commit 83a50839b7
parent bf1ba6bb52
Author: Peiming Liu
Date:   2023-01-17 19:25:40 +00:00

3 changed files with 77 additions and 89 deletions
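
The gist of the change, as a minimal sketch (helper names are taken from the
diff below; `builder`, `loc`, and `d` are assumed to be in scope):

  // Before: a read-only query still went through the mutable descriptor,
  // which must be backed by a temporary buffer of fields:
  //   SmallVector<Value> fields;
  //   auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
  //   Value sz = desc.getDimSize(builder, loc, d);
  //
  // After: the immutable descriptor reads the field tuple in place:
  //   auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
  //   Value sz = desc.getDimSize(builder, loc, d);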


@@ -102,11 +102,9 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 }
 
 /// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
-/// attached to the given tensor type.
-static std::optional<Value>
-sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                    const SparseTensorDescriptor &desc, unsigned dim) {
+/// original dimension 'dim'.
+static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned dim) {
   RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
@@ -119,17 +117,12 @@ sizeFromTensorAtDim(OpBuilder &builder, Location loc,
   return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
 }
 
-// Gets the dimension size at the given stored dimension 'd', either as a
+// Gets the dimension size at the given stored level 'lvl', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
-  unsigned dim = toOrigDim(rtp, d);
-  auto shape = rtp.getShape();
-  if (!ShapedType::isDynamic(shape[dim]))
-    return constantIndex(builder, loc, shape[dim]);
-  return desc.getDimSize(builder, loc, d);
+static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned lvl) {
+  return sizeFromTensorAtDim(builder, loc, desc,
+                             toOrigDim(desc.getTensorType(), lvl));
 }
 
 static void createPushback(OpBuilder &builder, Location loc,
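
The new helper simply forwards a stored level `lvl` through `toOrigDim` to
`sizeFromTensorAtDim`, so the static-shape fast path is shared rather than
duplicated. A small illustration (the permuted encoding is hypothetical):

  // With a dimOrdering mapping (d0, d1) -> (d1, d0), stored level 0 holds
  // original dimension 1, so this query returns the size of dimension 1:
  Value sz = sizeFromTensorAtLvl(builder, loc, desc, /*lvl=*/0);
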
@@ -174,7 +167,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, desc, r);
+    Value size = sizeFromTensorAtLvl(builder, loc, desc, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
@@ -436,7 +429,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
       // Construct the new position as:
      //   pos[d] = size * pos[d-1] + i[d]
      //   <insert @ pos[d] at next dimension d + 1>
-      Value size = sizeAtStoredDim(builder, loc, desc, d);
+      Value size = sizeFromTensorAtLvl(builder, loc, desc, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
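
A quick worked instance of the `pos[d] = size * pos[d-1] + i[d]` recurrence
(the 3x4 shape is hypothetical): inserting at indices (2, 1) into an all-dense
3x4 tensor gives pos[0] = 2, then pos[1] = 4 * 2 + 1 = 9, the usual row-major
offset of element (2, 1).
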
@@ -517,7 +510,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
 
 /// Generates insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
+                         SparseTensorDescriptor desc) {
   RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   for (unsigned d = 0; d < rank; d++) {
@@ -654,10 +647,7 @@ public:
     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
     auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
-    if (!sz)
-      return failure();
-
-    rewriter.replaceOp(op, *sz);
+    rewriter.replaceOp(op, sz);
     return success();
   }
 };
@@ -727,8 +717,7 @@ public:
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     for (auto input : desc.getMemRefFields())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +735,7 @@ public:
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -780,11 +768,10 @@ public:
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
-    assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
       auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
-      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
     };
     // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
@@ -957,8 +944,7 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getAOSMemRef());
     return success();