[sparse] allow unpack op to return 0-ranked tensor type. (#66269)

Many frontends canonicalize scalar into 0-ranked tensor, it change will
hopefully make the operation easier to use for those cases.
This commit is contained in:
Peiming Liu
2023-09-13 11:33:01 -07:00
committed by GitHub
parent 372115fadd
commit 098f46dce3
5 changed files with 26 additions and 6 deletions

View File

@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
return reassociation;
}
static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
Type dstTp) {
if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
// Scalars can only be converted to 0-ranked tensors.
if (rtp.getRank() != 0)
return nullptr;
elem = genCast(builder, loc, elem, rtp.getElementType());
return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
}
return genCast(builder, loc, elem, dstTp);
}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// consistent.
retMem.insert(retMem.begin(), dst);
Type valLenTp = op.getValLen().getType();
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
retLen.insert(retLen.begin(),
genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
retMem.push_back(dst);
// Retrieves the corresponding level length type.
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {