[mlir][sparse] Generate AOS subviews on-demand.

Previously, we generate AOS subviews for indices buffers when constructing an
immutable sparse tensor descriptor. We now only generate such subviews when
getIdxMemRefOrView is requested.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D141325
This commit is contained in:
bixia1
2023-01-10 12:33:10 -08:00
parent a18fe67b9f
commit 52028c1a48
4 changed files with 88 additions and 154 deletions

View File

@@ -295,11 +295,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
unsigned idxIndex;
unsigned idxStride;
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
unsigned ptrIndex = desc.getPtrMemRefIndex(d);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos);
Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1);
Value msz = desc.getIdxMemSize(builder, loc, d);
Value idxStrideC;
if (idxStride > 1) {
@@ -325,7 +324,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (d > 0)
genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
genStore(builder, loc, msz, desc.getPtrMemRef(d), pos);
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
builder.setInsertionPointAfter(ifOp1);
Value p = ifOp1.getResult(0);
@@ -352,7 +351,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
// If !present (changes fields, update next).
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
indices[d]);
// Prepare the next dimension "as needed".
@@ -638,10 +637,8 @@ public:
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
Location loc = op.getLoc();
auto desc =
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
if (!sz)
return failure();
@@ -756,8 +753,7 @@ public:
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
auto desc =
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
@@ -900,8 +896,7 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
return success();
@@ -919,17 +914,17 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
Value field = desc.getIdxMemRef(dim);
Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
// the runtime will issue an error.
Type resType = op.getResult().getType();
if (resType != field.getType())
field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
field = rewriter.create<memref::CastOp>(loc, resType, field);
rewriter.replaceOp(op, field);
return success();
@@ -967,8 +962,7 @@ public:
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
return success();
}