[mlir][sparse] use a consistent order between [dis]assembleOp and sto… (#84079)

…rage layout.
This commit is contained in:
Peiming Liu
2024-03-06 09:57:41 -08:00
committed by GitHub
parent 8277e308c0
commit fc9f1d49aa
14 changed files with 168 additions and 175 deletions

View File

@@ -738,13 +738,7 @@ public:
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
SmallVector<Value> retLen;
// Get the values buffer first.
auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
auto valLenTp = op.getValLen().getType();
auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
retVal.push_back(vals);
retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
// Get the positions and coordinates buffers.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
for (Level l = 0; l < lvlRank; l++) {
@@ -761,7 +755,7 @@ public:
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
@@ -769,7 +763,7 @@ public:
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
@@ -784,14 +778,13 @@ public:
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
auto buf =
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ public:
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
rewriter.setInsertionPointAfter(forOp);
auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
// Get the values buffer last.
auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
auto valLenTp = op.getValLen().getType();
auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
retVal.push_back(vals);
retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Converts MemRefs back to Tensors.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ public:
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);