[sparse] allow unpack op to return 0-ranked tensor type. (#66269)
Many frontends canonicalize scalar into 0-ranked tensor, it change will hopefully make the operation easier to use for those cases.
This commit is contained in:
@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
|
||||
Type dstTp) {
|
||||
if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
|
||||
// Scalars can only be converted to 0-ranked tensors.
|
||||
if (rtp.getRank() != 0)
|
||||
return nullptr;
|
||||
elem = genCast(builder, loc, elem, rtp.getElementType());
|
||||
return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
|
||||
}
|
||||
return genCast(builder, loc, elem, dstTp);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Codegen rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
// consistent.
|
||||
retMem.insert(retMem.begin(), dst);
|
||||
Type valLenTp = op.getValLen().getType();
|
||||
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
|
||||
retLen.insert(retLen.begin(),
|
||||
genScalarToTensor(rewriter, loc, sz, valLenTp));
|
||||
} else {
|
||||
assert(fKind == SparseTensorFieldKind::PosMemRef ||
|
||||
fKind == SparseTensorFieldKind::CrdMemRef);
|
||||
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
retMem.push_back(dst);
|
||||
// Retrieves the corresponding level length type.
|
||||
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
|
||||
}
|
||||
Value flatOut = dst;
|
||||
if (dst.getType().getRank() != 1) {
|
||||
|
||||
Reference in New Issue
Block a user