[mlir][sparse] make UnpackOp return the actual filled length of unpacked memory

This might simplify frontend implementation by avoiding recomputation for the same value.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D154244
This commit is contained in:
Peiming Liu
2023-06-30 18:07:21 +00:00
parent ab345bde81
commit a63d6a0014
7 changed files with 50 additions and 35 deletions

View File

@@ -1311,7 +1311,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Location loc = op.getLoc();
SmallVector<Value> retMem;
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
SmallVector<Value> retLen;
desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
FieldIndex fid,
SparseTensorFieldKind fKind, Level lvl,
DimLevelType dlt) -> bool {
@@ -1329,6 +1330,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// TODO: maybe change unpack/pack operation instead to be
// consistent.
retMem.insert(retMem.begin(), dst);
retLen.insert(retLen.begin(), sz);
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1339,6 +1341,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
src = desc.getMemRefField(fid);
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
retMem.push_back(dst);
retLen.push_back(sz);
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {
@@ -1352,12 +1355,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
});
// Converts MemRefs back to Tensors.
SmallVector<Value> retTensor = llvm::to_vector(
SmallVector<Value> retValues = llvm::to_vector(
llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
return rewriter.create<bufferization::ToTensorOp>(loc, v);
}));
rewriter.replaceOp(op, retTensor);
// Appends the actual memory length used in each buffer returned.
retValues.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retValues);
return success();
}
};