[mlir][sparse] Generate AOS subviews on-demand.
Previously, we generate AOS subviews for indices buffers when constructing an immutable sparse tensor descriptor. We now only generate such subviews when getIdxMemRefOrView is requested. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D141325
This commit is contained in:
@@ -295,11 +295,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
unsigned idxIndex;
|
||||
unsigned idxStride;
|
||||
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
|
||||
unsigned ptrIndex = desc.getPtrMemRefIndex(d);
|
||||
Value one = constantIndex(builder, loc, 1);
|
||||
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
|
||||
Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
|
||||
Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
|
||||
Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos);
|
||||
Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1);
|
||||
Value msz = desc.getIdxMemSize(builder, loc, d);
|
||||
Value idxStrideC;
|
||||
if (idxStride > 1) {
|
||||
@@ -325,7 +324,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
builder.create<scf::YieldOp>(loc, eq);
|
||||
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
|
||||
if (d > 0)
|
||||
genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
|
||||
genStore(builder, loc, msz, desc.getPtrMemRef(d), pos);
|
||||
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
|
||||
builder.setInsertionPointAfter(ifOp1);
|
||||
Value p = ifOp1.getResult(0);
|
||||
@@ -352,7 +351,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
// If !present (changes fields, update next).
|
||||
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
|
||||
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
|
||||
genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
|
||||
genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1);
|
||||
createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
|
||||
indices[d]);
|
||||
// Prepare the next dimension "as needed".
|
||||
@@ -638,10 +637,8 @@ public:
|
||||
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto desc =
|
||||
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
|
||||
auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
|
||||
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
|
||||
|
||||
if (!sz)
|
||||
return failure();
|
||||
@@ -756,8 +753,7 @@ public:
|
||||
if (!getSparseTensorEncoding(op.getTensor().getType()))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
auto desc =
|
||||
getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
RankedTensorType srcType =
|
||||
op.getTensor().getType().cast<RankedTensorType>();
|
||||
Type eltType = srcType.getElementType();
|
||||
@@ -900,8 +896,7 @@ public:
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
|
||||
return success();
|
||||
@@ -919,17 +914,17 @@ public:
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
Value field = desc.getIdxMemRef(dim);
|
||||
Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
|
||||
|
||||
// Insert a cast to bridge the actual type to the user expected type. If the
|
||||
// actual type and the user expected type aren't compatible, the compiler or
|
||||
// the runtime will issue an error.
|
||||
Type resType = op.getResult().getType();
|
||||
if (resType != field.getType())
|
||||
field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
|
||||
field = rewriter.create<memref::CastOp>(loc, resType, field);
|
||||
rewriter.replaceOp(op, field);
|
||||
|
||||
return success();
|
||||
@@ -967,8 +962,7 @@ public:
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
|
||||
adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getValMemRef());
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user