[mlir][sparse] fix bugs when computing the memory size when lowering pack op.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151481
This commit is contained in:
Peiming Liu
2023-05-25 18:35:28 +00:00
parent 5c082e7e15
commit f7b8b005ff
3 changed files with 41 additions and 22 deletions

View File

@@ -1242,10 +1242,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
});
MutSparseTensorDescriptor desc(stt, fields);
Value c0 = constantIndex(rewriter, loc, 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value c2 = constantIndex(rewriter, loc, 2);
Value posBack = c1; // index to the last value in the postion array
Value memSize = c2; // memory size for current array
Value posBack = c0; // index to the last value in the postion array
Value memSize = c1; // memory size for current array
Level trailCOOStart = getCOOStart(stt.getEncoding());
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
@@ -1266,7 +1267,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
DimLevelType dlt = stt.getLvlType(lvl);
// Simply forwards the position index when this is a dense level.
if (isDenseDLT(dlt)) {
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
continue;
}
@@ -1276,6 +1277,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
if (isCompressedWithHiDLT(dlt)) {
memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
} else {
assert(isCompressedDLT(dlt));
posBack = memSize;
memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
}
desc.setPosMemSize(rewriter, loc, lvl, memSize);
// The last value in position array is the memory size for next level.