[mlir][sparse] Combining dimOrdering+higherOrdering fields into dimToLvl
This is a major step along the way towards the new STEA design. While a great deal of this patch is simple renaming, there are several significant changes as well. I've done my best to ensure that this patch retains the previous behavior and error-conditions, even though those are at odds with the eventual intended semantics of the `dimToLvl` mapping. Since the majority of the compiler does not yet support non-permutations, I've also added explicit assertions in places that previously had implicitly assumed it was dealing with permutations. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D151505
This commit is contained in:
@@ -481,7 +481,7 @@ public:
|
||||
nameOstream << sh << "_";
|
||||
// Permutation information is also used in generating insertion.
|
||||
if (!stt.isIdentity())
|
||||
nameOstream << stt.getDimToLvlMap() << "_";
|
||||
nameOstream << stt.getDimToLvl() << "_";
|
||||
nameOstream << stt.getElementType() << "_";
|
||||
nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
|
||||
return nameOstream.str().str();
|
||||
@@ -1139,8 +1139,7 @@ public:
|
||||
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
|
||||
return failure();
|
||||
assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
|
||||
assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
|
||||
assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
|
||||
assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
|
||||
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
|
||||
assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
|
||||
|
||||
@@ -1168,7 +1167,7 @@ public:
|
||||
// FIXME: we need to distinguish level sizes and dimension size for slices
|
||||
// here. Maybe we should store slice level sizes in a different array
|
||||
// instead of reusing it.
|
||||
assert(srcEnc.hasIdDimOrdering());
|
||||
assert(srcEnc.isIdentity());
|
||||
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
|
||||
sizeV);
|
||||
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
|
||||
@@ -1428,26 +1427,26 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
fields, nse);
|
||||
MutSparseTensorDescriptor desc(dstTp, fields);
|
||||
|
||||
// Construct the `dim2lvl` buffer for handing off to the runtime library.
|
||||
// Construct the `dimToLvl` buffer for handing off to the runtime library.
|
||||
// FIXME: This code is (mostly) copied from the SparseTensorConversion.cpp
|
||||
// handling of `NewOp`, and only handles permutations. Fixing this
|
||||
// requires waiting for wrengr to finish redoing the CL that handles
|
||||
// all dim<->lvl stuff more robustly.
|
||||
SmallVector<Value> dim2lvlValues(dimRank);
|
||||
SmallVector<Value> dimToLvlValues(dimRank);
|
||||
if (!dstTp.isIdentity()) {
|
||||
const auto dimOrder = dstTp.getDimToLvlMap();
|
||||
assert(dimOrder.isPermutation() && "Got non-permutation");
|
||||
const auto dimToLvl = dstTp.getDimToLvl();
|
||||
assert(dimToLvl.isPermutation() && "Got non-permutation");
|
||||
for (Level l = 0; l < lvlRank; l++) {
|
||||
const Dimension d = dimOrder.getDimPosition(l);
|
||||
dim2lvlValues[d] = constantIndex(rewriter, loc, l);
|
||||
const Dimension d = dimToLvl.getDimPosition(l);
|
||||
dimToLvlValues[d] = constantIndex(rewriter, loc, l);
|
||||
}
|
||||
} else {
|
||||
// The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
|
||||
// when `isIdentity`; so no need to re-assert it here.
|
||||
for (Dimension d = 0; d < dimRank; d++)
|
||||
dim2lvlValues[d] = constantIndex(rewriter, loc, d);
|
||||
dimToLvlValues[d] = constantIndex(rewriter, loc, d);
|
||||
}
|
||||
Value dim2lvl = allocaBuffer(rewriter, loc, dim2lvlValues);
|
||||
Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
|
||||
|
||||
// Read the COO tensor data.
|
||||
Value xs = desc.getAOSMemRef();
|
||||
@@ -1463,7 +1462,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
primaryTypeFunctionSuffix(elemTp)};
|
||||
Value isSorted =
|
||||
createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
|
||||
{reader, dim2lvl, xs, ys}, EmitCInterface::On)
|
||||
{reader, dimToLvl, xs, ys}, EmitCInterface::On)
|
||||
.getResult(0);
|
||||
|
||||
// If the destination tensor is a sorted COO, we need to sort the COO tensor
|
||||
|
||||
Reference in New Issue
Block a user