[mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops (#83343)

The sparse structure buffers might not always be memrefs with rank == 1
with the presence of batch levels.
This commit is contained in:
Peiming Liu
2024-02-28 16:10:20 -08:00
committed by GitHub
parent 43b7dfcc1d
commit 6bc7c9df7f
5 changed files with 129 additions and 160 deletions

View File

@@ -1058,17 +1058,9 @@ public:
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
// the runtime will issue an error.
Type resType = op.getResult().getType();
if (resType != field.getType())
field = rewriter.create<memref::CastOp>(loc, resType, field);
rewriter.replaceOp(op, field);
rewriter.replaceOp(
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
return success();
}