[mlir][sparse] convert a sparse tensor slice to sparse tensor correctly.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147074
This commit is contained in:
Peiming Liu
2023-03-28 19:54:34 +00:00
parent 9f15f1f0f3
commit 33267f4007
5 changed files with 51 additions and 2 deletions

View File

@@ -1058,9 +1058,14 @@ public:
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
// The output tensor can not be a slice and those cases should have been
// rejected by ConvertOp::verify() already.
assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
// Different encoding (except for different bitwidth) should be handled by
// rewriting.
if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
// We need further rewrites if the input tensor is a slice too.
if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
encSrc.isSlice()) {
return failure();
}