[mlir][sparse] Add AOS optimization.
Use an array of structures (AOS) to represent the indices for the trailing COO region of a sparse tensor.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140870
This commit is contained in:
@@ -122,7 +122,7 @@ static std::optional<Value> sizeFromTensorAtDim(OpBuilder &builder,
|
||||
// Gets the dimension size at the given stored dimension 'd', either as a
|
||||
// constant for a static size, or otherwise dynamically through memSizes.
|
||||
Value sizeAtStoredDim(OpBuilder &builder, Location loc,
|
||||
SparseTensorDescriptor desc, unsigned d) {
|
||||
MutSparseTensorDescriptor desc, unsigned d) {
|
||||
RankedTensorType rtp = desc.getTensorType();
|
||||
unsigned dim = toOrigDim(rtp, d);
|
||||
auto shape = rtp.getShape();
|
||||
@@ -293,15 +293,20 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
SmallVector<Type> types;
|
||||
Type indexType = builder.getIndexType();
|
||||
Type boolType = builder.getIntegerType(1);
|
||||
unsigned idxIndex = desc.getIdxMemRefIndex(d);
|
||||
unsigned idxIndex;
|
||||
unsigned idxStride;
|
||||
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
|
||||
unsigned ptrIndex = desc.getPtrMemRefIndex(d);
|
||||
Value one = constantIndex(builder, loc, 1);
|
||||
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
|
||||
Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
|
||||
Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
|
||||
Value msz = desc.getIdxMemSize(builder, loc, d);
|
||||
// Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex));
|
||||
|
||||
Value idxStrideC;
|
||||
if (idxStride > 1) {
|
||||
idxStrideC = constantIndex(builder, loc, idxStride);
|
||||
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
|
||||
}
|
||||
Value phim1 = builder.create<arith::SubIOp>(
|
||||
loc, toType(builder, loc, phi, indexType), one);
|
||||
// Conditional expression.
|
||||
@@ -311,7 +316,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
|
||||
types.pop_back();
|
||||
builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
|
||||
Value crd = genLoad(builder, loc, desc.getMemRefField(idxIndex), phim1);
|
||||
Value crd = genLoad(
|
||||
builder, loc, desc.getMemRefField(idxIndex),
|
||||
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
|
||||
: phim1);
|
||||
Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
toType(builder, loc, crd, indexType),
|
||||
indices[d]);
|
||||
@@ -631,8 +639,10 @@ public:
|
||||
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
|
||||
return failure();
|
||||
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
|
||||
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
|
||||
Location loc = op.getLoc();
|
||||
auto desc =
|
||||
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
|
||||
auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
|
||||
|
||||
if (!sz)
|
||||
return failure();
|
||||
@@ -707,7 +717,8 @@ public:
|
||||
|
||||
// Replace the sparse tensor deallocation with field deallocations.
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
SmallVector<Value> fields;
|
||||
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
|
||||
for (auto input : desc.getMemRefFields())
|
||||
// Deallocate every buffer used to store the sparse tensor handler.
|
||||
rewriter.create<memref::DeallocOp>(loc, input);
|
||||
@@ -746,7 +757,8 @@ public:
|
||||
if (!getSparseTensorEncoding(op.getTensor().getType()))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
auto desc =
|
||||
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
|
||||
RankedTensorType srcType =
|
||||
op.getTensor().getType().cast<RankedTensorType>();
|
||||
Type eltType = srcType.getElementType();
|
||||
@@ -889,7 +901,8 @@ public:
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
|
||||
return success();
|
||||
@@ -907,7 +920,8 @@ public:
|
||||
// Replace the requested indices access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
Value field = desc.getIdxMemRef(dim);
|
||||
|
||||
@@ -934,7 +948,8 @@ public:
|
||||
// Replace the requested values access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getValMemRef());
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user