[mlir][sparse] force a properly sized view on pos/crd/val under codegen (#91288)

Codegen "vectors" for pos/crd/val use the capacity as memref size, not
the actual used size. Although the sparsifier itself always uses just
the defined pos/crd/val parts, printing these and passing them back to a
runtime environment could benefit from wrapping the basic pos/crd/val
getters into a proper memref view that sets the right size.
This commit is contained in:
Aart Bik
2024-05-07 09:20:56 -07:00
committed by GitHub
parent 8bcb073705
commit 5c5116556f
10 changed files with 240 additions and 239 deletions

View File

@@ -1050,10 +1050,14 @@ public:
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested position access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
auto mem = desc.getPosMemRef(lvl);
auto size = desc.getPosMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1068,12 +1072,17 @@ public:
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
mem = genSliceToSize(rewriter, loc, mem, size);
}
rewriter.replaceOp(op, mem);
return success();
}
};
@@ -1088,11 +1097,14 @@ public:
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getAOSMemRef());
auto mem = desc.getAOSMemRef();
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1106,10 +1118,13 @@ public:
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested values access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
auto mem = desc.getValMemRef();
auto size = desc.getValMemSize(rewriter, loc);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};