[mlir][sparse] Making way for SparseTensorRuntime to support non-permutations
Systematically updates the SparseTensorRuntime to properly distinguish tensor-dimensions from storage-levels (and their associated ranks, shapes, sizes, indices, etc). With a few exceptions which are noted in the code, this ensures the runtime has all the **semantic** changes necessary to support non-permutations. (Whereas **operationally**, since we're still using `std::vector<uing64_t>` to represent the mappings, there's no way to pass in any interesting non-permutations. Changing the representation to `std::function` will be done in a separate differential.) Depends On D137680 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137681
This commit is contained in:
@@ -57,14 +57,14 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
|
||||
operands);
|
||||
}
|
||||
|
||||
/// Generates dimension size call.
|
||||
static Value genDimSizeCall(OpBuilder &builder, Location loc,
|
||||
/// Generates call to lookup a level-size.
|
||||
static Value genLvlSizeCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, Value src,
|
||||
uint64_t idx) {
|
||||
uint64_t lvl) {
|
||||
// Generate the call.
|
||||
StringRef name = "sparseDimSize";
|
||||
StringRef name = "sparseLvlSize";
|
||||
SmallVector<Value, 2> params{
|
||||
src, constantIndex(builder, loc, toStoredDim(enc, idx))};
|
||||
src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
|
||||
Type iTp = builder.getIndexType();
|
||||
return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
|
||||
.getResult(0);
|
||||
@@ -72,13 +72,22 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
|
||||
|
||||
/// Compute the size from type (for static sizes) or from an already-converted
|
||||
/// opaque pointer source (for dynamic sizes) at the given dimension.
|
||||
//
|
||||
// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
|
||||
// match the naming convention used in the runtime library. However, it's
|
||||
// not entirely clear that all callsites of this function properly make the
|
||||
// "level"-vs-"dimension" distinction; so need to audit each callsite to
|
||||
// ensure this still does what they mean (possibly by having two separate
|
||||
// functions, one for levels and one for dimensions). That also means
|
||||
// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
|
||||
// they mean to be referring to level-sizes vs dimension-sizes.
|
||||
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value src, unsigned dim) {
|
||||
Value src, unsigned i) {
|
||||
auto shape = stp.getShape();
|
||||
if (shape[dim] == ShapedType::kDynamicSize)
|
||||
return genDimSizeCall(builder, loc, enc, src, dim);
|
||||
return constantIndex(builder, loc, shape[dim]);
|
||||
if (shape[i] == ShapedType::kDynamicSize)
|
||||
return genLvlSizeCall(builder, loc, enc, src, i);
|
||||
return constantIndex(builder, loc, shape[i]);
|
||||
}
|
||||
|
||||
/// Populates given sizes array from type (for static sizes) and from
|
||||
@@ -225,17 +234,19 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr unsigned kNumStaticParams = 6;
|
||||
static constexpr unsigned kNumStaticParams = 8;
|
||||
static constexpr unsigned kNumDynamicParams = 2;
|
||||
static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
|
||||
static constexpr unsigned kParamLvlTypes = 0;
|
||||
static constexpr unsigned kParamDimSizes = 1;
|
||||
static constexpr unsigned kParamDim2Lvl = 2;
|
||||
static constexpr unsigned kParamPtrTp = 3;
|
||||
static constexpr unsigned kParamIndTp = 4;
|
||||
static constexpr unsigned kParamValTp = 5;
|
||||
static constexpr unsigned kParamAction = 6;
|
||||
static constexpr unsigned kParamPtr = 7;
|
||||
static constexpr unsigned kParamDimSizes = 0;
|
||||
static constexpr unsigned kParamLvlSizes = 1;
|
||||
static constexpr unsigned kParamLvlTypes = 2;
|
||||
static constexpr unsigned kParamLvl2Dim = 3;
|
||||
static constexpr unsigned kParamDim2Lvl = 4;
|
||||
static constexpr unsigned kParamPtrTp = 5;
|
||||
static constexpr unsigned kParamIndTp = 6;
|
||||
static constexpr unsigned kParamValTp = 7;
|
||||
static constexpr unsigned kParamAction = 8;
|
||||
static constexpr unsigned kParamPtr = 9;
|
||||
|
||||
OpBuilder &builder;
|
||||
Location loc;
|
||||
@@ -260,10 +271,17 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
|
||||
// verification of external data, or for construction of internal data.
|
||||
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
|
||||
params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
|
||||
// The dimension-to-level mapping. We must preinitialize `dim2lvl`
|
||||
// so that the true branch below can perform random-access `operator[]`
|
||||
// assignment.
|
||||
// The level-sizes array must be passed as well, since for arbitrary
|
||||
// dim2lvl mappings it cannot be trivially reconstructed at runtime.
|
||||
// For now however, since we're still assuming permutations, we will
|
||||
// initialize this parameter alongside the `dim2lvl` and `lvl2dim`
|
||||
// parameters below. We preinitialize `lvlSizes` for code symmetry.
|
||||
SmallVector<Value, 4> lvlSizes(lvlRank);
|
||||
// The dimension-to-level mapping and its inverse. We must preinitialize
|
||||
// `dim2lvl` so that the true branch below can perform random-access
|
||||
// `operator[]` assignment. We preinitialize `lvl2dim` for code symmetry.
|
||||
SmallVector<Value, 4> dim2lvl(dimRank);
|
||||
SmallVector<Value, 4> lvl2dim(lvlRank);
|
||||
auto dimOrder = enc.getDimOrdering();
|
||||
if (dimOrder) {
|
||||
assert(dimOrder.isPermutation());
|
||||
@@ -271,13 +289,20 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
|
||||
// The `d`th source variable occurs in the `l`th result position.
|
||||
uint64_t d = dimOrder.getDimPosition(l);
|
||||
dim2lvl[d] = constantIndex(builder, loc, l);
|
||||
lvl2dim[l] = constantIndex(builder, loc, d);
|
||||
lvlSizes[l] = dimSizes[d];
|
||||
}
|
||||
} else {
|
||||
assert(dimRank == lvlRank && "Rank mismatch");
|
||||
for (unsigned i = 0; i < lvlRank; i++)
|
||||
dim2lvl[i] = constantIndex(builder, loc, i);
|
||||
for (unsigned i = 0; i < lvlRank; i++) {
|
||||
dim2lvl[i] = lvl2dim[i] = constantIndex(builder, loc, i);
|
||||
lvlSizes[i] = dimSizes[i];
|
||||
}
|
||||
}
|
||||
params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
|
||||
params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
|
||||
params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
|
||||
params[kParamDim2Lvl] =
|
||||
dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
|
||||
// Secondary and primary types encoding.
|
||||
setTemplateTypes(enc, stp);
|
||||
// Finally, make note that initialization is complete.
|
||||
@@ -316,9 +341,10 @@ static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
|
||||
/// if val != 0
|
||||
/// t->add(&val, [i1,..,ik], [p1,..,pk]);
|
||||
static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
|
||||
Value ptr, Value valPtr, Value ind, Value perm) {
|
||||
Value lvlCOO, Value valPtr, Value dimInd,
|
||||
Value dim2lvl) {
|
||||
SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
|
||||
SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
|
||||
SmallVector<Value, 4> params{lvlCOO, valPtr, dimInd, dim2lvl};
|
||||
Type pTp = getOpaquePointerType(builder);
|
||||
createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On);
|
||||
}
|
||||
@@ -632,7 +658,7 @@ public:
|
||||
Value src = adaptor.getOperands()[0];
|
||||
int64_t idx = *index;
|
||||
rewriter.replaceOp(op,
|
||||
genDimSizeCall(rewriter, op->getLoc(), enc, src, idx));
|
||||
genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user