[mlir][sparse] make UnpackOp return the actual filled length of unpacked memory
This might simplify frontend implementation by avoiding recomputation for the same value. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D154244
This commit is contained in:
@@ -1311,7 +1311,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> retMem;
|
||||
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
|
||||
SmallVector<Value> retLen;
|
||||
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
|
||||
FieldIndex fid,
|
||||
SparseTensorFieldKind fKind, Level lvl,
|
||||
DimLevelType dlt) -> bool {
|
||||
@@ -1329,6 +1330,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
// TODO: maybe change unpack/pack operation instead to be
|
||||
// consistent.
|
||||
retMem.insert(retMem.begin(), dst);
|
||||
retLen.insert(retLen.begin(), sz);
|
||||
} else {
|
||||
assert(fKind == SparseTensorFieldKind::PosMemRef ||
|
||||
fKind == SparseTensorFieldKind::CrdMemRef);
|
||||
@@ -1339,6 +1341,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
src = desc.getMemRefField(fid);
|
||||
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
|
||||
retMem.push_back(dst);
|
||||
retLen.push_back(sz);
|
||||
}
|
||||
Value flatOut = dst;
|
||||
if (dst.getType().getRank() != 1) {
|
||||
@@ -1352,12 +1355,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
});
|
||||
|
||||
// Converts MemRefs back to Tensors.
|
||||
SmallVector<Value> retTensor = llvm::to_vector(
|
||||
SmallVector<Value> retValues = llvm::to_vector(
|
||||
llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
|
||||
return rewriter.create<bufferization::ToTensorOp>(loc, v);
|
||||
}));
|
||||
|
||||
rewriter.replaceOp(op, retTensor);
|
||||
// Appends the actual memory length used in each buffer returned.
|
||||
retValues.append(retLen.begin(), retLen.end());
|
||||
rewriter.replaceOp(op, retValues);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user