[mlir][sparse] factoring out getRankedTensorType helper function

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D142074
This commit is contained in:
wren romano
2023-01-18 19:11:48 -08:00
parent eaabc1bbea
commit 255c3f1159
6 changed files with 27 additions and 25 deletions

View File

@@ -756,8 +756,7 @@ public:
return failure();
Location loc = op->getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
auto srcType = getRankedTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Type idxType = rewriter.getIndexType();