[mlir][sparse] use a consistent order between [dis]assembleOp and sto… (#84079)
…rage layout.
This commit is contained in:
@@ -738,13 +738,7 @@ public:
|
||||
auto stt = getSparseTensorType(op.getTensor());
|
||||
SmallVector<Value> retVal;
|
||||
SmallVector<Value> retLen;
|
||||
// Get the values buffer first.
|
||||
auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
|
||||
auto valLenTp = op.getValLen().getType();
|
||||
auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
|
||||
retVal.push_back(vals);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
|
||||
// Then get the positions and coordinates buffers.
|
||||
// Get the positions and coordinates buffers.
|
||||
const Level lvlRank = stt.getLvlRank();
|
||||
Level trailCOOLen = 0;
|
||||
for (Level l = 0; l < lvlRank; l++) {
|
||||
@@ -761,7 +755,7 @@ public:
|
||||
auto poss =
|
||||
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
|
||||
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
|
||||
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
|
||||
auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retVal.push_back(poss);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
|
||||
}
|
||||
@@ -769,7 +763,7 @@ public:
|
||||
auto crds =
|
||||
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
|
||||
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
|
||||
auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
|
||||
auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retVal.push_back(crds);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
|
||||
}
|
||||
@@ -784,14 +778,13 @@ public:
|
||||
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
|
||||
cooStartLvl);
|
||||
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
|
||||
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
|
||||
auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retVal.push_back(poss);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
|
||||
// Coordinates, copied over with:
|
||||
// for (i = 0; i < crdLen; i++)
|
||||
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
|
||||
auto buf =
|
||||
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
|
||||
auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
|
||||
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
|
||||
cooStartLvl);
|
||||
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
|
||||
@@ -814,10 +807,17 @@ public:
|
||||
args[1] = one;
|
||||
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
|
||||
rewriter.setInsertionPointAfter(forOp);
|
||||
auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
|
||||
auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
|
||||
retVal.push_back(buf);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
|
||||
}
|
||||
// Get the values buffer last.
|
||||
auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
|
||||
auto valLenTp = op.getValLen().getType();
|
||||
auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
|
||||
retVal.push_back(vals);
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
|
||||
|
||||
// Converts MemRefs back to Tensors.
|
||||
assert(retVal.size() + retLen.size() == op.getNumResults());
|
||||
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
|
||||
@@ -825,6 +825,7 @@ public:
|
||||
retVal[i] =
|
||||
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
|
||||
}
|
||||
|
||||
// Appends the actual memory length used in each buffer returned.
|
||||
retVal.append(retLen.begin(), retLen.end());
|
||||
rewriter.replaceOp(op, retVal);
|
||||
|
||||
Reference in New Issue
Block a user