[mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp

The new class helps encapsulate the arguments to `_mlir_ciface_newSparseTensor` so that client code doesn't depend on the details of the API.  (This makes way for the next differential which significantly alters the API.)

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137680
Author: wren romano
Date:   2022-11-08 16:43:44 -08:00
Parent: 3b70e8b012
Commit: f5ce99afa7

@@ -70,16 +70,6 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
.getResult(0);
}
/// Emits a call to the "swiss army knife" routine of the sparse runtime
/// support library, which materializes a sparse tensor into the computation.
static Value genNewCall(OpBuilder &builder, Location loc,
                        ArrayRef<Value> params) {
  // The runtime entry point returns an opaque pointer to the tensor storage.
  Type pTp = getOpaquePointerType(builder);
  auto call = createFuncCall(builder, loc, "newSparseTensor", pTp, params,
                             EmitCInterface::On);
  return call.getResult(0);
}
/// Compute the size from type (for static sizes) or from an already-converted
/// opaque pointer source (for dynamic sizes) at the given dimension.
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
@@ -168,41 +158,132 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
return buffer;
}
/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
Location loc, ShapedType stp,
SparseTensorEncodingAttr &enc, Action action,
ValueRange szs, Value ptr = Value()) {
ArrayRef<DimLevelType> dlt = enc.getDimLevelType();
unsigned sz = dlt.size();
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need to make modifications to client code whenever that
/// API changes.
class NewCallParams final {
public:
/// Allocates the `ValueRange` for the `func::CallOp` parameters,
/// but does not initialize them.
NewCallParams(OpBuilder &builder, Location loc)
: builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
/// Initializes all static parameters (i.e., those which indicate
/// type-level information such as the encoding and sizes), generating
/// MLIR buffers as needed, and returning `this` for method chaining.
/// This method does not set the action and pointer arguments, since
/// those are handled by `genNewCall` instead.
NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes,
ShapedType stp);
/// (Re)sets the C++ template type parameters, and returns `this`
/// for method chaining. This is already done as part of `genBuffers`,
/// but is factored out so that it can also be called independently
/// whenever subsequent `genNewCall` calls want to reuse the same
/// buffers but different type parameters.
//
// TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
// is there a better way to handle that than this one-off setter method?
NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc,
ShapedType stp) {
params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc);
params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc);
params[kParamValTp] =
constantPrimaryTypeEncoding(builder, loc, stp.getElementType());
return *this;
}
/// Checks whether all the static parameters have been initialized.
bool isInitialized() const {
for (unsigned i = 0; i < kNumStaticParams; ++i)
if (!params[i])
return false;
return true;
}
/// Gets the dimension-to-level mapping.
//
// TODO: This is only ever used for passing into `genAddEltCall`;
// is there a better way to encapsulate that pattern (both to avoid
// this one-off getter, and to avoid potential mixups)?
Value getDim2LvlMap() const {
assert(isInitialized() && "Must initialize before getDim2LvlMap");
return params[kParamDim2Lvl];
}
/// Generates a function call, with the current static parameters
/// and the given dynamic arguments.
Value genNewCall(Action action, Value ptr = Value()) {
assert(isInitialized() && "Must initialize before genNewCall");
StringRef name = "newSparseTensor";
params[kParamAction] = constantAction(builder, loc, action);
params[kParamPtr] = ptr ? ptr : builder.create<LLVM::NullOp>(loc, pTp);
return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
.getResult(0);
}
private:
static constexpr unsigned kNumStaticParams = 6;
static constexpr unsigned kNumDynamicParams = 2;
static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
static constexpr unsigned kParamLvlTypes = 0;
static constexpr unsigned kParamDimSizes = 1;
static constexpr unsigned kParamDim2Lvl = 2;
static constexpr unsigned kParamPtrTp = 3;
static constexpr unsigned kParamIndTp = 4;
static constexpr unsigned kParamValTp = 5;
static constexpr unsigned kParamAction = 6;
static constexpr unsigned kParamPtr = 7;
OpBuilder &builder;
Location loc;
Type pTp;
Value params[kNumParams];
};
// TODO: see the note at `_mlir_ciface_newSparseTensor` about how
// the meaning of the various arguments (e.g., "sizes" vs "shapes")
// is inconsistent between the different actions.
//
// Initializes the six static parameter slots: the level-types buffer,
// the dimension-sizes buffer, the dimension-to-level mapping buffer,
// and (via `setTemplateTypes`) the three type-encoding constants.
NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
                                         ValueRange dimSizes, ShapedType stp) {
  const unsigned lvlRank = enc.getDimLevelType().size();
  const unsigned dimRank = stp.getRank();
  // Sparsity annotations: one level-type constant per storage level.
  SmallVector<Value, 4> lvlTypes;
  for (auto dlt : enc.getDimLevelType())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
  params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
  // Dimension-sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
  params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
  // The dimension-to-level mapping. We must preinitialize `dim2lvl`
  // so that the true branch below can perform random-access `operator[]`
  // assignment.
  SmallVector<Value, 4> dim2lvl(dimRank);
  auto dimOrder = enc.getDimOrdering();
  if (dimOrder) {
    assert(dimOrder.isPermutation());
    for (unsigned l = 0; l < lvlRank; l++) {
      // The `d`th source variable occurs in the `l`th result position.
      uint64_t d = dimOrder.getDimPosition(l);
      dim2lvl[d] = constantIndex(builder, loc, l);
    }
  } else {
    // No ordering attribute: the mapping is the identity.
    assert(dimRank == lvlRank && "Rank mismatch");
    for (unsigned i = 0; i < lvlRank; i++)
      dim2lvl[i] = constantIndex(builder, loc, i);
  }
  params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
  // Secondary and primary types encoding.
  setTemplateTypes(enc, stp);
  // Finally, make note that initialization is complete.
  assert(isInitialized() && "Initialization failed");
  // And return `this` for method chaining.
  return *this;
}
/// Generates a call to obtain the values array.
@@ -387,14 +468,12 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> srcSizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes,
adaptor.getSrc());
Value iter = genNewCall(rewriter, loc, params);
NewCallParams params(rewriter, loc);
Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
.genNewCall(Action::kToIterator, adaptor.getSrc());
// Start a new COO for the destination tensor.
SmallVector<Value, 4> dstSizes;
params.clear();
if (dstTp.hasStaticShape()) {
sizesFromType(rewriter, dstSizes, loc, dstTp);
} else {
@@ -402,9 +481,9 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
}
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes);
Value coo = genNewCall(rewriter, loc, params);
Value dstPerm = params[2];
Value coo =
params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
Value dstPerm = params.getDim2LvlMap();
// Construct a while loop over the iterator.
Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
@@ -426,9 +505,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
rewriter.create<scf::YieldOp>(loc);
// Final call to construct sparse tensor storage and free temporary resources.
rewriter.setInsertionPointAfter(whileOp);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, loc, params);
Value dst = params.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, elemTp, coo);
genDelIteratorCall(rewriter, loc, elemTp, iter);
rewriter.replaceOp(op, dst);
@@ -458,11 +535,10 @@ static void genSparseCOOIterationLoop(
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
enc.getPointerBitWidth(), enc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
newParams(rewriter, params, loc, tensorTp, noPerm, Action::kToIterator, sizes,
t);
Value iter = genNewCall(rewriter, loc, params);
Value iter = NewCallParams(rewriter, loc)
.genBuffers(noPerm, sizes, tensorTp)
.genNewCall(Action::kToIterator, t);
// Construct a while loop over the iterator.
Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -611,12 +687,12 @@ public:
// Generate the call to construct tensor from ptr. The sizes are
// inferred from the result type of the new operator.
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
ShapedType stp = resType.cast<ShapedType>();
sizesFromType(rewriter, sizes, loc, stp);
Value ptr = adaptor.getOperands()[0];
newParams(rewriter, params, loc, stp, enc, Action::kFromFile, sizes, ptr);
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
rewriter.replaceOp(op, NewCallParams(rewriter, loc)
.genBuffers(enc, sizes, stp)
.genNewCall(Action::kFromFile, ptr));
return success();
}
};
@@ -650,10 +726,10 @@ public:
}
// Generate the call to construct empty tensor. The sizes are
// explicitly defined by the arguments to the alloc operator.
SmallVector<Value, 8> params;
ShapedType stp = resType.cast<ShapedType>();
newParams(rewriter, params, loc, stp, enc, Action::kEmpty, sizes);
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
rewriter.replaceOp(op,
NewCallParams(rewriter, loc)
.genBuffers(enc, sizes, resType.cast<ShapedType>())
.genNewCall(Action::kEmpty));
return success();
}
};
@@ -690,7 +766,7 @@ public:
return success();
}
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
NewCallParams params(rewriter, loc);
ShapedType stp = srcType.cast<ShapedType>();
sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
bool useDirectConversion;
@@ -708,9 +784,8 @@ public:
break;
}
if (useDirectConversion) {
newParams(rewriter, params, loc, stp, encDst, Action::kSparseToSparse,
sizes, src);
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
.genNewCall(Action::kSparseToSparse, src));
} else { // use via-COO conversion.
// Set up encoding with right mix of src and dst so that the two
// method calls can share most parameters, while still providing
@@ -719,13 +794,13 @@ public:
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
encSrc.getIndexBitWidth());
newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, loc, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, loc, params);
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
// is called with a non-identity permutation. Is there any clean
// way to push the permutation over to the `kFromCOO` side instead?
Value coo =
params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
Value dst = params.setTemplateTypes(encDst, stp)
.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
}
@@ -743,7 +818,7 @@ public:
RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
unsigned rank = dstTensorTp.getRank();
Type elemTp = dstTensorTp.getElementType();
// Fabricate a no-permutation encoding for newParams().
// Fabricate a no-permutation encoding for NewCallParams
// The pointer/index types must be those of `src`.
// The dimLevelTypes aren't actually used by Action::kToIterator.
encDst = SparseTensorEncodingAttr::get(
@@ -751,11 +826,10 @@ public:
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
newParams(rewriter, params, loc, dstTensorTp, encDst, Action::kToIterator,
sizes, src);
Value iter = genNewCall(rewriter, loc, params);
Value iter = NewCallParams(rewriter, loc)
.genBuffers(encDst, sizes, dstTensorTp)
.genNewCall(Action::kToIterator, src);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
Block *insertionBlock = rewriter.getInsertionBlock();
@@ -817,12 +891,12 @@ public:
ShapedType stp = resType.cast<ShapedType>();
unsigned rank = stp.getRank();
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromSrc(rewriter, sizes, loc, src);
newParams(rewriter, params, loc, stp, encDst, Action::kEmptyCOO, sizes);
Value coo = genNewCall(rewriter, loc, params);
NewCallParams params(rewriter, loc);
Value coo =
params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value perm = params[2];
Value perm = params.getDim2LvlMap();
Type eltType = stp.getElementType();
Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
genDenseTensorOrSparseConstantIterLoop(
@@ -836,9 +910,7 @@ public:
genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm);
});
// Final call to construct sparse tensor storage.
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, loc, params);
Value dst = params.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, eltType, coo);
rewriter.replaceOp(op, dst);
return success();
@@ -1117,15 +1189,15 @@ public:
Value offset = constantIndex(rewriter, loc, 0);
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
NewCallParams params(rewriter, loc);
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
concatDim);
if (encDst) {
// Start a new COO for the destination tensor.
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
dst = genNewCall(rewriter, loc, params);
dstPerm = params[2];
dst =
params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
dstPerm = params.getDim2LvlMap();
elemPtr = genAllocaScalar(rewriter, loc, elemTp);
dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
} else {
@@ -1188,11 +1260,9 @@ public:
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
if (encDst) {
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
// In sparse output case, the destination holds the COO.
Value coo = dst;
params[7] = coo;
dst = genNewCall(rewriter, loc, params);
dst = params.genNewCall(Action::kFromCOO, coo);
// Release resources.
genDelCOOCall(rewriter, loc, elemTp, coo);
rewriter.replaceOp(op, dst);
@@ -1216,27 +1286,25 @@ public:
Value src = adaptor.getOperands()[0];
auto encSrc = getSparseTensorEncoding(srcType);
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, loc, params);
Value coo = NewCallParams(rewriter, loc)
.genBuffers(enc, sizes, srcType)
.genNewCall(Action::kToCOO, src);
// Then output the tensor to external file with indices in the externally
// visible lexicographic index order. A sort is required if the source was
// not in that order yet (note that the sort can be dropped altogether if
// external format does not care about the order at all, but here we assume
// it does).
bool sort =
encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
params.clear();
params.push_back(coo);
params.push_back(adaptor.getOperands()[1]);
params.push_back(constantI1(rewriter, loc, sort));
Value sort = constantI1(rewriter, loc,
encSrc.getDimOrdering() &&
!encSrc.getDimOrdering().isIdentity());
SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
Type eltType = srcType.getElementType();
SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
createFuncCall(rewriter, loc, name, {}, params, EmitCInterface::Off);
createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
genDelCOOCall(rewriter, loc, eltType, coo);
rewriter.eraseOp(op);
return success();