[mlir][sparse] Using SparseTensorType in SparsePackOpConverter

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147465
This commit is contained in:
wren romano
2023-04-03 12:55:59 -07:00
parent b0ba8fe6ba
commit 34c9c59ce4

View File

@@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
matchAndRewrite(PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto rtp = getRankedTensorType(op.getResult());
assert(isUniqueCOOType(rtp));
const auto stt = getSparseTensorType(op.getResult());
assert(isUniqueCOOType(stt));
SmallVector<Value> fields;
Location loc = op.getLoc();
foreachFieldAndTypeInSparseTensor(
rtp,
[&rewriter, &fields, &op, rtp,
stt,
[&rewriter, &fields, &op, stt,
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
auto enc = getSparseTensorEncoding(rtp);
Value field;
switch (fKind) {
case SparseTensorFieldKind::StorageSpec:
field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
break;
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
// By creating a constant value for it, we avoid the complexity of
// memory management.
const auto posTp = enc.getPosType();
const auto posTp = stt.getPosType();
auto tensorType = RankedTensorType::get({2}, posTp);
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
@@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
return true;
});
MutSparseTensorDescriptor desc(rtp, fields);
MutSparseTensorDescriptor desc(stt, fields);
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
// FIXME: should use `SparseTensorType::getLvlRank` in lieu of
// `RankedTensorType::getRank`, because the latter introduces dim/lvl
// ambiguity.
for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
const auto sh = rtp.getShape()[lvl];
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
// FIXME: dim/lvl confusion!
const auto sh = stt.getDimShape()[lvl];
assert(!ShapedType::isDynamic(sh));
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
if (lvl == 0)