[mlir][sparse] fix bugs when computing the memory size when lowering pack op.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D151481
This commit is contained in:
@@ -1242,10 +1242,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
});
|
||||
|
||||
MutSparseTensorDescriptor desc(stt, fields);
|
||||
Value c0 = constantIndex(rewriter, loc, 0);
|
||||
Value c1 = constantIndex(rewriter, loc, 1);
|
||||
Value c2 = constantIndex(rewriter, loc, 2);
|
||||
Value posBack = c1; // index to the last value in the postion array
|
||||
Value memSize = c2; // memory size for current array
|
||||
Value posBack = c0; // index to the last value in the postion array
|
||||
Value memSize = c1; // memory size for current array
|
||||
|
||||
Level trailCOOStart = getCOOStart(stt.getEncoding());
|
||||
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
|
||||
@@ -1266,7 +1267,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
DimLevelType dlt = stt.getLvlType(lvl);
|
||||
// Simply forwards the position index when this is a dense level.
|
||||
if (isDenseDLT(dlt)) {
|
||||
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
|
||||
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
|
||||
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
|
||||
continue;
|
||||
}
|
||||
@@ -1276,6 +1277,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
if (isCompressedWithHiDLT(dlt)) {
|
||||
memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
|
||||
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
|
||||
} else {
|
||||
assert(isCompressedDLT(dlt));
|
||||
posBack = memSize;
|
||||
memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
|
||||
}
|
||||
desc.setPosMemSize(rewriter, loc, lvl, memSize);
|
||||
// The last value in position array is the memory size for next level.
|
||||
|
||||
Reference in New Issue
Block a user