[mlir][sparse] support sparse bufferization.alloc_tensor with copy argument.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D147358
This commit is contained in:
@@ -749,11 +749,29 @@ public:
|
||||
const auto resType = getSparseTensorType(op);
|
||||
if (!resType.hasEncoding())
|
||||
return failure();
|
||||
if (op.getCopy())
|
||||
return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
|
||||
|
||||
// Construct allocation for each field.
|
||||
const Location loc = op.getLoc();
|
||||
if (op.getCopy()) {
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
|
||||
SmallVector<Value> fields;
|
||||
fields.reserve(desc.getNumFields());
|
||||
// Memcpy on memref fields.
|
||||
for (auto field : desc.getMemRefFields()) {
|
||||
auto memrefTp = field.getType().cast<MemRefType>();
|
||||
auto size = rewriter.create<memref::DimOp>(loc, field, 0);
|
||||
auto copied =
|
||||
rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
|
||||
rewriter.create<memref::CopyOp>(loc, field, copied);
|
||||
fields.push_back(copied);
|
||||
}
|
||||
// Reuses specifier.
|
||||
fields.push_back(desc.getSpecifier());
|
||||
assert(fields.size() == desc.getNumFields());
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
|
||||
return success();
|
||||
}
|
||||
|
||||
const Value sizeHint = op.getSizeHint();
|
||||
const ValueRange dynSizes = adaptor.getDynamicSizes();
|
||||
const size_t found = dynSizes.size();
|
||||
|
||||
Reference in New Issue
Block a user