[sparse] allow unpack op to return any integer type. (#66161)

This commit is contained in:
Peiming Liu
2023-09-12 17:27:51 -07:00
committed by GitHub
parent 749ec26d83
commit 64df1c08d0
3 changed files with 11 additions and 8 deletions

View File

@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// TODO: maybe change unpack/pack operation instead to be
// consistent.
retMem.insert(retMem.begin(), dst);
retLen.insert(retLen.begin(), sz);
Type valLenTp = op.getValLen().getType();
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
src = desc.getMemRefField(fid);
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
retMem.push_back(dst);
retLen.push_back(sz);
// Retrieves the corresponding level length type.
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {