[mlir][sparse] force a properly sized view on pos/crd/val under codegen (#91288)
Codegen "vectors" for pos/crd/val use the capacity as memref size, not the actual used size. Although the sparsifier itself always uses just the defined pos/crd/val parts, printing these and passing them back to a runtime environment could benefit from wrapping the basic pos/crd/val getters into a proper memref view that sets the right size.
This commit is contained in:
@@ -1050,10 +1050,14 @@ public:
|
||||
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested position access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
// The view is restricted to the actual size to ensure clients
|
||||
// of this operation truly observe size, not capacity!
|
||||
Location loc = op.getLoc();
|
||||
Level lvl = op.getLevel();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
|
||||
auto mem = desc.getPosMemRef(lvl);
|
||||
auto size = desc.getPosMemSize(rewriter, loc, lvl);
|
||||
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1068,12 +1072,17 @@ public:
|
||||
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested coordinates access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
// The view is restricted to the actual size to ensure clients
|
||||
// of this operation truly observe size, not capacity!
|
||||
Location loc = op.getLoc();
|
||||
Level lvl = op.getLevel();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(
|
||||
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
|
||||
|
||||
auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
|
||||
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
|
||||
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
|
||||
mem = genSliceToSize(rewriter, loc, mem, size);
|
||||
}
|
||||
rewriter.replaceOp(op, mem);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1088,11 +1097,14 @@ public:
|
||||
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested coordinates access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
// The view is restricted to the actual size to ensure clients
|
||||
// of this operation truly observe size, not capacity!
|
||||
Location loc = op.getLoc();
|
||||
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getAOSMemRef());
|
||||
|
||||
auto mem = desc.getAOSMemRef();
|
||||
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
|
||||
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1106,10 +1118,13 @@ public:
|
||||
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested values access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
// The view is restricted to the actual size to ensure clients
|
||||
// of this operation truly observe size, not capacity!
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getValMemRef());
|
||||
auto mem = desc.getValMemRef();
|
||||
auto size = desc.getValMemSize(rewriter, loc);
|
||||
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user