[mlir][sparse] Making way for SparseTensorRuntime to support non-permutations

Systematically updates the SparseTensorRuntime to properly distinguish tensor-dimensions from storage-levels (and their associated ranks, shapes, sizes, indices, etc).  With a few exceptions which are noted in the code, this ensures the runtime has all the **semantic** changes necessary to support non-permutations.

(Whereas **operationally**, since we're still using `std::vector<uing64_t>` to represent the mappings, there's no way to pass in any interesting non-permutations.  Changing the representation to `std::function` will be done in a separate differential.)

Depends On D137680

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137681
This commit is contained in:
wren romano
2022-11-09 13:33:01 -08:00
parent 65f9992865
commit c518745bba
22 changed files with 1734 additions and 1089 deletions

View File

@@ -57,14 +57,14 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
operands);
}
/// Generates dimension size call.
static Value genDimSizeCall(OpBuilder &builder, Location loc,
/// Generates call to lookup a level-size.
static Value genLvlSizeCall(OpBuilder &builder, Location loc,
SparseTensorEncodingAttr &enc, Value src,
uint64_t idx) {
uint64_t lvl) {
// Generate the call.
StringRef name = "sparseDimSize";
StringRef name = "sparseLvlSize";
SmallVector<Value, 2> params{
src, constantIndex(builder, loc, toStoredDim(enc, idx))};
src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
Type iTp = builder.getIndexType();
return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
.getResult(0);
@@ -72,13 +72,22 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
/// Compute the size from type (for static sizes) or from an already-converted
/// opaque pointer source (for dynamic sizes) at the given dimension.
//
// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
// match the naming convention used in the runtime library. However, it's
// not entirely clear that all callsites of this function properly make the
// "level"-vs-"dimension" distinction; so need to audit each callsite to
// ensure this still does what they mean (possibly by having two separate
// functions, one for levels and one for dimensions). That also means
// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
// they mean to be referring to level-sizes vs dimension-sizes.
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
SparseTensorEncodingAttr &enc, ShapedType stp,
Value src, unsigned dim) {
Value src, unsigned i) {
auto shape = stp.getShape();
if (shape[dim] == ShapedType::kDynamicSize)
return genDimSizeCall(builder, loc, enc, src, dim);
return constantIndex(builder, loc, shape[dim]);
if (shape[i] == ShapedType::kDynamicSize)
return genLvlSizeCall(builder, loc, enc, src, i);
return constantIndex(builder, loc, shape[i]);
}
/// Populates given sizes array from type (for static sizes) and from
@@ -225,17 +234,19 @@ public:
}
private:
static constexpr unsigned kNumStaticParams = 6;
static constexpr unsigned kNumStaticParams = 8;
static constexpr unsigned kNumDynamicParams = 2;
static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
static constexpr unsigned kParamLvlTypes = 0;
static constexpr unsigned kParamDimSizes = 1;
static constexpr unsigned kParamDim2Lvl = 2;
static constexpr unsigned kParamPtrTp = 3;
static constexpr unsigned kParamIndTp = 4;
static constexpr unsigned kParamValTp = 5;
static constexpr unsigned kParamAction = 6;
static constexpr unsigned kParamPtr = 7;
static constexpr unsigned kParamDimSizes = 0;
static constexpr unsigned kParamLvlSizes = 1;
static constexpr unsigned kParamLvlTypes = 2;
static constexpr unsigned kParamLvl2Dim = 3;
static constexpr unsigned kParamDim2Lvl = 4;
static constexpr unsigned kParamPtrTp = 5;
static constexpr unsigned kParamIndTp = 6;
static constexpr unsigned kParamValTp = 7;
static constexpr unsigned kParamAction = 8;
static constexpr unsigned kParamPtr = 9;
OpBuilder &builder;
Location loc;
@@ -260,10 +271,17 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
// verification of external data, or for construction of internal data.
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
// The dimension-to-level mapping. We must preinitialize `dim2lvl`
// so that the true branch below can perform random-access `operator[]`
// assignment.
// The level-sizes array must be passed as well, since for arbitrary
// dim2lvl mappings it cannot be trivially reconstructed at runtime.
// For now however, since we're still assuming permutations, we will
// initialize this parameter alongside the `dim2lvl` and `lvl2dim`
// parameters below. We preinitialize `lvlSizes` for code symmetry.
SmallVector<Value, 4> lvlSizes(lvlRank);
// The dimension-to-level mapping and its inverse. We must preinitialize
// `dim2lvl` so that the true branch below can perform random-access
// `operator[]` assignment. We preinitialize `lvl2dim` for code symmetry.
SmallVector<Value, 4> dim2lvl(dimRank);
SmallVector<Value, 4> lvl2dim(lvlRank);
auto dimOrder = enc.getDimOrdering();
if (dimOrder) {
assert(dimOrder.isPermutation());
@@ -271,13 +289,20 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
// The `d`th source variable occurs in the `l`th result position.
uint64_t d = dimOrder.getDimPosition(l);
dim2lvl[d] = constantIndex(builder, loc, l);
lvl2dim[l] = constantIndex(builder, loc, d);
lvlSizes[l] = dimSizes[d];
}
} else {
assert(dimRank == lvlRank && "Rank mismatch");
for (unsigned i = 0; i < lvlRank; i++)
dim2lvl[i] = constantIndex(builder, loc, i);
for (unsigned i = 0; i < lvlRank; i++) {
dim2lvl[i] = lvl2dim[i] = constantIndex(builder, loc, i);
lvlSizes[i] = dimSizes[i];
}
}
params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
params[kParamDim2Lvl] =
dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
// Secondary and primary types encoding.
setTemplateTypes(enc, stp);
// Finally, make note that initialization is complete.
@@ -316,9 +341,10 @@ static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
/// if val != 0
/// t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
Value ptr, Value valPtr, Value ind, Value perm) {
Value lvlCOO, Value valPtr, Value dimInd,
Value dim2lvl) {
SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
SmallVector<Value, 4> params{lvlCOO, valPtr, dimInd, dim2lvl};
Type pTp = getOpaquePointerType(builder);
createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On);
}
@@ -632,7 +658,7 @@ public:
Value src = adaptor.getOperands()[0];
int64_t idx = *index;
rewriter.replaceOp(op,
genDimSizeCall(rewriter, op->getLoc(), enc, src, idx));
genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
return success();
}
};