[mlir][sparse] Breaking up openSparseTensor to better support non-permutations

This commit updates how the `SparseTensorConversion` pass handles `NewOp`.  It breaks up the underlying `openSparseTensor` function into two parts (`SparseTensorReader::create` and `SparseTensorReader::readSparseTensor`) so that the pass can inject code for constructing `lvlSizes` between those two parts.  Migrating the construction of `lvlSizes` out of the runtime and into the pass is a necessary first step toward fully supporting non-permutations.  (The alternative would be for the pass to generate a `FuncOp` for performing the construction and then passing that to the runtime; which doesn't seem to have any benefits over the design of this commit.)  And since the pass now generates the code to call these two functions, this change also removes the `Action::kFromFile` value from the enum used by `_mlir_ciface_newSparseTensor`.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138363
This commit is contained in:
wren romano
2022-12-01 18:18:33 -08:00
parent ca23b7ca47
commit 2af2e4dbb7
10 changed files with 447 additions and 149 deletions

View File

@@ -95,7 +95,9 @@ static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
Location loc, SparseTensorEncodingAttr &enc,
ShapedType stp, Value src) {
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
unsigned rank = stp.getRank();
sizes.reserve(rank);
for (unsigned i = 0; i < rank; i++)
sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
}
@@ -103,7 +105,9 @@ static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
Location loc, ShapedType stp) {
auto shape = stp.getShape();
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
unsigned rank = stp.getRank();
sizes.reserve(rank);
for (unsigned i = 0; i < rank; i++) {
uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
sizes.push_back(constantIndex(builder, loc, s));
}
@@ -167,6 +171,17 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
return buffer;
}
/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
SparseTensorEncodingAttr enc) {
SmallVector<Value> lvlTypes;
auto dlts = enc.getDimLevelType();
lvlTypes.reserve(dlts.size());
for (auto dlt : dlts)
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
return genBuffer(builder, loc, lvlTypes);
}
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
@@ -262,11 +277,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
const unsigned lvlRank = enc.getDimLevelType().size();
const unsigned dimRank = stp.getRank();
// Sparsity annotations.
SmallVector<Value> lvlTypes;
for (auto dlt : enc.getDimLevelType())
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc);
// Dimension-sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
@@ -715,19 +726,98 @@ public:
matchAndRewrite(NewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resType = op.getType();
auto enc = getSparseTensorEncoding(resType);
auto stp = op.getType().cast<ShapedType>();
auto enc = getSparseTensorEncoding(stp);
if (!enc)
return failure();
// Generate the call to construct tensor from ptr. The sizes are
// inferred from the result type of the new operator.
SmallVector<Value> sizes;
ShapedType stp = resType.cast<ShapedType>();
sizesFromType(rewriter, sizes, loc, stp);
Value ptr = adaptor.getOperands()[0];
rewriter.replaceOp(op, NewCallParams(rewriter, loc)
.genBuffers(enc, sizes, stp)
.genNewCall(Action::kFromFile, ptr));
const unsigned dimRank = stp.getRank();
const unsigned lvlRank = enc.getDimLevelType().size();
// Construct the dimShape.
const auto dimShape = stp.getShape();
SmallVector<Value> dimShapeValues;
sizesFromType(rewriter, dimShapeValues, loc, stp);
Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
// Allocate `SparseTensorReader` and perform all initial setup that
// does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
Type opaqueTp = getOpaquePointerType(rewriter);
Value valTp =
constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType());
Value reader =
createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
opaqueTp,
{adaptor.getOperands()[0], dimShapeBuffer, valTp},
EmitCInterface::On)
.getResult(0);
// Construct the lvlSizes. If the dimShape is static, then it's
// identical to dimSizes: so we can compute lvlSizes entirely at
// compile-time. If dimShape is dynamic, then we'll need to generate
// code for computing lvlSizes from the `reader`'s actual dimSizes.
//
// TODO: For now we're still assuming `dim2lvl` is a permutation.
// But since we're computing lvlSizes here (rather than in the runtime),
// we can easily generalize that simply by adjusting this code.
//
// FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
Value dimSizesBuffer;
if (!stp.hasStaticShape()) {
Type indexTp = rewriter.getIndexType();
auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
dimSizesBuffer =
createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
reader, EmitCInterface::On)
.getResult(0);
}
Value lvlSizesBuffer;
Value lvl2dimBuffer;
Value dim2lvlBuffer;
if (auto dimOrder = enc.getDimOrdering()) {
assert(dimOrder.isPermutation() && "Got non-permutation");
// We preinitialize `dim2lvlValues` since we need random-access writing.
// And we preinitialize the others for stylistic consistency.
SmallVector<Value> lvlSizeValues(lvlRank);
SmallVector<Value> lvl2dimValues(lvlRank);
SmallVector<Value> dim2lvlValues(dimRank);
for (unsigned l = 0; l < lvlRank; l++) {
// The `d`th source variable occurs in the `l`th result position.
uint64_t d = dimOrder.getDimPosition(l);
Value lvl = constantIndex(rewriter, loc, l);
Value dim = constantIndex(rewriter, loc, d);
dim2lvlValues[d] = lvl;
lvl2dimValues[l] = dim;
lvlSizeValues[l] =
(dimShape[d] == ShapedType::kDynamic)
? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
: dimShapeValues[d];
}
lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
} else {
assert(dimRank == lvlRank && "Rank mismatch");
SmallVector<Value> iotaValues;
iotaValues.reserve(lvlRank);
for (unsigned i = 0; i < lvlRank; i++)
iotaValues.push_back(constantIndex(rewriter, loc, i));
lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
}
// Use the `reader` to parse the file.
SmallVector<Value, 8> params{
reader,
lvlSizesBuffer,
genLvlTypesBuffer(rewriter, loc, enc),
lvl2dimBuffer,
dim2lvlBuffer,
constantPointerTypeEncoding(rewriter, loc, enc),
constantIndexTypeEncoding(rewriter, loc, enc),
valTp};
Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
opaqueTp, params, EmitCInterface::On)
.getResult(0);
// Free the memory for `reader`.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
rewriter.replaceOp(op, tensor);
return success();
}
};