[mlir][sparse] Combining dimOrdering+higherOrdering fields into dimToLvl

This is a major step along the way towards the new STEA design.  While a great deal of this patch is simple renaming, there are several significant changes as well.  I've done my best to ensure that this patch retains the previous behavior and error-conditions, even though those are at odds with the eventual intended semantics of the `dimToLvl` mapping.  Since the majority of the compiler does not yet support non-permutations, I've also added explicit assertions in places that previously had implicitly assumed it was dealing with permutations.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151505
This commit is contained in:
wren romano
2023-05-30 13:16:29 -07:00
parent 510f4168cf
commit 76647fce13
62 changed files with 486 additions and 442 deletions

View File

@@ -481,7 +481,7 @@ public:
nameOstream << sh << "_";
// Permutation information is also used in generating insertion.
if (!stt.isIdentity())
nameOstream << stt.getDimToLvlMap() << "_";
nameOstream << stt.getDimToLvl() << "_";
nameOstream << stt.getElementType() << "_";
nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
return nameOstream.str().str();
@@ -1139,8 +1139,7 @@ public:
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
@@ -1168,7 +1167,7 @@ public:
// FIXME: we need to distinguish level sizes and dimension size for slices
// here. Maybe we should store slice level sizes in a different array
// instead of reusing it.
assert(srcEnc.hasIdDimOrdering());
assert(srcEnc.isIdentity());
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
sizeV);
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
@@ -1428,26 +1427,26 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
fields, nse);
MutSparseTensorDescriptor desc(dstTp, fields);
// Construct the `dim2lvl` buffer for handing off to the runtime library.
// Construct the `dimToLvl` buffer for handing off to the runtime library.
// FIXME: This code is (mostly) copied from the SparseTensorConversion.cpp
// handling of `NewOp`, and only handles permutations. Fixing this
// requires waiting for wrengr to finish redoing the CL that handles
// all dim<->lvl stuff more robustly.
SmallVector<Value> dim2lvlValues(dimRank);
SmallVector<Value> dimToLvlValues(dimRank);
if (!dstTp.isIdentity()) {
const auto dimOrder = dstTp.getDimToLvlMap();
assert(dimOrder.isPermutation() && "Got non-permutation");
const auto dimToLvl = dstTp.getDimToLvl();
assert(dimToLvl.isPermutation() && "Got non-permutation");
for (Level l = 0; l < lvlRank; l++) {
const Dimension d = dimOrder.getDimPosition(l);
dim2lvlValues[d] = constantIndex(rewriter, loc, l);
const Dimension d = dimToLvl.getDimPosition(l);
dimToLvlValues[d] = constantIndex(rewriter, loc, l);
}
} else {
// The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
// when `isIdentity`; so no need to re-assert it here.
for (Dimension d = 0; d < dimRank; d++)
dim2lvlValues[d] = constantIndex(rewriter, loc, d);
dimToLvlValues[d] = constantIndex(rewriter, loc, d);
}
Value dim2lvl = allocaBuffer(rewriter, loc, dim2lvlValues);
Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
// Read the COO tensor data.
Value xs = desc.getAOSMemRef();
@@ -1463,7 +1462,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
primaryTypeFunctionSuffix(elemTp)};
Value isSorted =
createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
{reader, dim2lvl, xs, ys}, EmitCInterface::On)
{reader, dimToLvl, xs, ys}, EmitCInterface::On)
.getResult(0);
// If the destination tensor is a sorted COO, we need to sort the COO tensor