[mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)
Use SparseTensorDescriptor whenever not calling setters, to avoid needing to create a temporary buffer for simple query purposes.

Reviewed By: bixia, wrengr

Differential Revision: https://reviews.llvm.org/D141953
@@ -102,11 +102,9 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 }
 
 /// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
-/// attached to the given tensor type.
-static std::optional<Value>
-sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                    const SparseTensorDescriptor &desc, unsigned dim) {
+/// original dimension 'dim'.
+static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned dim) {
   RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
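The helper keeps its existing fast path: a statically known extent is answered straight from the tensor type, and only dynamic extents consult the stored sizes. A small standalone sketch of that dispatch, with a hypothetical sizeAtDim helper and -1 standing in for ShapedType::kDynamic:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins: a static shape with -1 marking dynamic extents,
// plus the runtime sizes that the descriptor would otherwise provide.
constexpr int64_t kDynamic = -1;

static int64_t sizeAtDim(const std::vector<int64_t> &staticShape,
                         const std::vector<int64_t> &runtimeSizes,
                         unsigned dim) {
  // Static extent: answer directly from the type, no runtime query needed
  // (mirrors the constantIndex fast path in sizeFromTensorAtDim).
  if (staticShape[dim] != kDynamic)
    return staticShape[dim];
  // Dynamic extent: fall back to the stored sizes.
  return runtimeSizes[dim];
}

int main() {
  std::vector<int64_t> staticShape = {8, kDynamic};
  std::vector<int64_t> runtimeSizes = {8, 1000};
  assert(sizeAtDim(staticShape, runtimeSizes, 0) == 8);
  assert(sizeAtDim(staticShape, runtimeSizes, 1) == 1000);
  return 0;
}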
@@ -119,17 +117,12 @@ sizeFromTensorAtDim(OpBuilder &builder, Location loc,
   return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
 }
 
-// Gets the dimension size at the given stored dimension 'd', either as a
+// Gets the dimension size at the given stored level 'lvl', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
-  unsigned dim = toOrigDim(rtp, d);
-  auto shape = rtp.getShape();
-  if (!ShapedType::isDynamic(shape[dim]))
-    return constantIndex(builder, loc, shape[dim]);
-
-  return desc.getDimSize(builder, loc, d);
+static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned lvl) {
+  return sizeFromTensorAtDim(builder, loc, desc,
+                             toOrigDim(desc.getTensorType(), lvl));
 }
 
 static void createPushback(OpBuilder &builder, Location loc,
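The new sizeFromTensorAtLvl simply translates a stored level back to its original dimension with toOrigDim and reuses sizeFromTensorAtDim. A tiny self-contained sketch of that level/dimension round trip, assuming the mapping is a plain permutation (the real toOrigDim/toStoredDim consult the tensor's sparse encoding; the names below are stand-ins):

#include <cassert>
#include <vector>

// Hypothetical stand-ins: a dimension-to-level permutation and its inverse.
static unsigned toStoredLvl(const std::vector<unsigned> &dimToLvl,
                            unsigned dim) {
  return dimToLvl[dim];
}

static unsigned origDimOf(const std::vector<unsigned> &dimToLvl,
                          unsigned lvl) {
  for (unsigned d = 0; d < dimToLvl.size(); ++d)
    if (dimToLvl[d] == lvl)
      return d;
  assert(false && "level not mapped by any dimension");
  return 0;
}

int main() {
  // E.g. a (d0, d1) -> (d1, d0) ordering: dimension 0 is stored at level 1.
  std::vector<unsigned> dimToLvl = {1, 0};
  std::vector<int> dimSizes = {100, 8}; // sizes in original dimension order

  // Size at stored level 0 is the size of original dimension 1, and so on,
  // mirroring sizeFromTensorAtLvl = sizeFromTensorAtDim(toOrigDim(lvl)).
  assert(dimSizes[origDimOf(dimToLvl, 0)] == 8);
  assert(dimSizes[origDimOf(dimToLvl, 1)] == 100);
  assert(toStoredLvl(dimToLvl, origDimOf(dimToLvl, 0)) == 0);
  return 0;
}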
@@ -174,7 +167,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, desc, r);
+    Value size = sizeFromTensorAtLvl(builder, loc, desc, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
@@ -436,7 +429,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
       // Construct the new position as:
       //   pos[d] = size * pos[d-1] + i[d]
       //   <insert @ pos[d] at next dimension d + 1>
-      Value size = sizeAtStoredDim(builder, loc, desc, d);
+      Value size = sizeFromTensorAtLvl(builder, loc, desc, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
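The recurrence in the comment above, pos[d] = size * pos[d-1] + i[d], is ordinary row-major linearization over the leading dense levels; the call-site change only swaps which helper supplies size. A standalone check with a hypothetical linearize helper:

#include <cassert>
#include <vector>

// Hypothetical helper: fold the indices of leading dense levels into one
// linear position using the recurrence pos[d] = size[d] * pos[d-1] + i[d].
static unsigned linearize(const std::vector<unsigned> &sizes,
                          const std::vector<unsigned> &indices) {
  unsigned pos = 0;
  for (unsigned d = 0; d < sizes.size(); ++d) {
    assert(indices[d] < sizes[d]);
    pos = sizes[d] * pos + indices[d];
  }
  return pos;
}

int main() {
  // A 3x4 dense prefix: element (2, 1) lands at linear position 2*4 + 1 = 9.
  assert(linearize({3, 4}, {2, 1}) == 9);
  return 0;
}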
@@ -517,7 +510,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
 
 /// Generations insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
+                         SparseTensorDescriptor desc) {
   RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   for (unsigned d = 0; d < rank; d++) {
@@ -654,10 +647,7 @@ public:
     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
     auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
 
-    if (!sz)
-      return failure();
-
-    rewriter.replaceOp(op, *sz);
+    rewriter.replaceOp(op, sz);
     return success();
   }
 };
@@ -727,8 +717,7 @@ public:
 
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     for (auto input : desc.getMemRefFields())
       // Deallocate every buffer used to store the sparse tensor handler.
      rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +735,7 @@ public:
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -780,11 +768,10 @@ public:
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
-    assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
       auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
-      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
     };
     // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
@@ -957,8 +944,7 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getAOSMemRef());
 
     return success();