[mlir][sparse] Add AOS optimization.

Use an array of structures to represent the indices for the tailing COO region
of a sparse tensor.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140870
This commit is contained in:
bixia1
2023-01-04 13:08:06 -08:00
parent ff66d410fd
commit 3fdd85da06
8 changed files with 370 additions and 127 deletions

View File

@@ -122,7 +122,7 @@ static std::optional<Value> sizeFromTensorAtDim(OpBuilder &builder,
// Gets the dimension size at the given stored dimension 'd', either as a
// constant for a static size, or otherwise dynamically through memSizes.
Value sizeAtStoredDim(OpBuilder &builder, Location loc,
SparseTensorDescriptor desc, unsigned d) {
MutSparseTensorDescriptor desc, unsigned d) {
RankedTensorType rtp = desc.getTensorType();
unsigned dim = toOrigDim(rtp, d);
auto shape = rtp.getShape();
@@ -293,15 +293,20 @@ static Value genCompressed(OpBuilder &builder, Location loc,
SmallVector<Type> types;
Type indexType = builder.getIndexType();
Type boolType = builder.getIntegerType(1);
unsigned idxIndex = desc.getIdxMemRefIndex(d);
unsigned idxIndex;
unsigned idxStride;
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
unsigned ptrIndex = desc.getPtrMemRefIndex(d);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
Value msz = desc.getIdxMemSize(builder, loc, d);
// Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex));
Value idxStrideC;
if (idxStride > 1) {
idxStrideC = constantIndex(builder, loc, idxStride);
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
}
Value phim1 = builder.create<arith::SubIOp>(
loc, toType(builder, loc, phi, indexType), one);
// Conditional expression.
@@ -311,7 +316,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
types.pop_back();
builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
Value crd = genLoad(builder, loc, desc.getMemRefField(idxIndex), phim1);
Value crd = genLoad(
builder, loc, desc.getMemRefField(idxIndex),
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
: phim1);
Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
toType(builder, loc, crd, indexType),
indices[d]);
@@ -631,8 +639,10 @@ public:
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
Location loc = op.getLoc();
auto desc =
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
if (!sz)
return failure();
@@ -707,7 +717,8 @@ public:
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,7 +757,8 @@ public:
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc =
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
@@ -889,7 +901,8 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
return success();
@@ -907,7 +920,8 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
Value field = desc.getIdxMemRef(dim);
@@ -934,7 +948,8 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
return success();
}