[sparse] allow unpack op to return any integer type. (#66161)
This commit is contained in:
@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
// TODO: maybe change unpack/pack operation instead to be
|
||||
// consistent.
|
||||
retMem.insert(retMem.begin(), dst);
|
||||
retLen.insert(retLen.begin(), sz);
|
||||
Type valLenTp = op.getValLen().getType();
|
||||
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
|
||||
} else {
|
||||
assert(fKind == SparseTensorFieldKind::PosMemRef ||
|
||||
fKind == SparseTensorFieldKind::CrdMemRef);
|
||||
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
src = desc.getMemRefField(fid);
|
||||
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
|
||||
retMem.push_back(dst);
|
||||
retLen.push_back(sz);
|
||||
// Retrieves the corresponding level length type.
|
||||
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
|
||||
}
|
||||
Value flatOut = dst;
|
||||
if (dst.getType().getRank() != 1) {
|
||||
|
||||
Reference in New Issue
Block a user