[mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops (#83343)
The sparse structure buffers might not always be memrefs with rank == 1 with the presence of batch levels.
This commit is contained in:
@@ -1058,17 +1058,9 @@ public:
|
||||
// Replace the requested coordinates access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
|
||||
|
||||
// Insert a cast to bridge the actual type to the user expected type. If the
|
||||
// actual type and the user expected type aren't compatible, the compiler or
|
||||
// the runtime will issue an error.
|
||||
Type resType = op.getResult().getType();
|
||||
if (resType != field.getType())
|
||||
field = rewriter.create<memref::CastOp>(loc, resType, field);
|
||||
rewriter.replaceOp(op, field);
|
||||
rewriter.replaceOp(
|
||||
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user