[mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion

This change cleans up the conversion pass with respect to the "dim"-vs-"lvl" and "sizes"-vs-"shape" distinctions of the runtime. A quick synopsis:

* Adds a new `SparseTensorStorageBase::getDimSize` method, with a `sparseDimSize` wrapper in SparseTensorRuntime.h and a `genDimSizeCall` generator in SparseTensorConversion.cpp.
* Changes `genLvlSizeCall` to perform no logic, just generate the function call.
* Adds `createOrFold{Dim,Lvl}Call` functions to handle the logic of replacing `gen{Dim,Lvl}SizeCall` with constants whenever possible (see the sketch after this list). The `createOrFoldDimCall` function replaces the old `sizeFromPtrAtDim`.
* Adds `{get,fill}DimSizes` functions for iterating `createOrFoldDimCall` across the whole type. These functions replace the old `sizesFromPtr`.
* Adds `{get,fill}DimShape` functions for lowering a `ShapedType` into constants. These functions replace the old `sizesFromType`.
* Changes the `DimOp` rewrite to query dimension-sizes rather than level-sizes, folding to a constant when the dimension is statically known.
* Changes the `ExpandOp` rewrite to compute the proper expansion size, namely the size of the innermost storage level.
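
For readers new to the dim/lvl terminology, below is a minimal standalone sketch of the fold-or-query pattern that `createOrFold{Dim,Lvl}Call` implements. It is not the pass code itself: `createOrFoldLvlSize`, the `std::vector` shape, and the `queryLvlSize` callback are hypothetical stand-ins for `createOrFoldLvlCall`, MLIR's `ShapedType`, and the generated `sparseLvlSize` runtime call, respectively.

#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Mirrors ShapedType::kDynamic: marks a dimension whose size is unknown
// until runtime.
constexpr int64_t kDynamic = -1;

// Returns the level-size, folding to a constant when the corresponding
// dimension is static, and falling back to a runtime query otherwise.
// `dimOrdering` is the lvl -> dim permutation from the encoding (empty
// for the identity ordering); `queryLvlSize` stands in for emitting a
// `sparseLvlSize` call.
int64_t createOrFoldLvlSize(
    const std::vector<int64_t> &dimShape,
    const std::vector<unsigned> &dimOrdering, unsigned lvl,
    const std::function<int64_t(unsigned)> &queryLvlSize) {
  // Translate the level back to the dimension it stores (permutations only).
  const unsigned dim = dimOrdering.empty() ? lvl : dimOrdering[lvl];
  const int64_t size = dimShape[dim];
  if (size != kDynamic)
    return size;            // fold: the size is statically known.
  return queryLvlSize(lvl); // query: only the runtime knows the size.
}

int main() {
  // CSC-like storage of tensor<?x4xf64>: level 0 stores dimension 1.
  const std::vector<int64_t> dimShape = {kDynamic, 4};
  const std::vector<unsigned> dimOrdering = {1, 0};
  auto query = [](unsigned) -> int64_t { return 100; }; // pretend runtime answer
  assert(createOrFoldLvlSize(dimShape, dimOrdering, 0, query) == 4);   // folded
  assert(createOrFoldLvlSize(dimShape, dimOrdering, 1, query) == 100); // queried
  return 0;
}

The dimension-size lookup is the same pattern without the lvl-to-dim translation, plus a dense-tensor fallback to `linalg::createOrFoldDimOp`, as the diff below shows.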

Depends On D138365

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139165
Author: wren romano
Date: 2022-12-01 19:08:45 -08:00
parent 33bcb3dc79
commit 86f91e45a2
8 changed files with 196 additions and 137 deletions


@@ -57,62 +57,111 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                        operands);
 }
 
-/// Generates call to lookup a level-size.
-static Value genLvlSizeCall(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr &enc, Value src,
+/// Generates call to lookup a level-size. N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                             uint64_t lvl) {
-  // Generate the call.
   StringRef name = "sparseLvlSize";
-  SmallVector<Value, 2> params{ // just two
-      src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
+  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
   Type iTp = builder.getIndexType();
   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
       .getResult(0);
 }
 
-/// Compute the size from type (for static sizes) or from an already-converted
-/// opaque pointer source (for dynamic sizes) at the given dimension.
-//
-// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
-// match the naming convention used in the runtime library. However, it's
-// not entirely clear that all callsites of this function properly make the
-// "level"-vs-"dimension" distinction; so need to audit each callsite to
-// ensure this still does what they mean (possibly by having two separate
-// functions, one for levels and one for dimensions). That also means
-// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
-// they mean to be referring to level-sizes vs dimension-sizes.
-static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
-                              SparseTensorEncodingAttr &enc, ShapedType stp,
-                              Value src, unsigned i) {
-  auto shape = stp.getShape();
-  if (shape[i] == ShapedType::kDynamic)
-    return genLvlSizeCall(builder, loc, enc, src, i);
-  return constantIndex(builder, loc, shape[i]);
+/// Generates call to lookup a dimension-size. N.B., this only generates
+/// the raw function call, and therefore (intentionally) does not perform
+/// any dim<->lvl conversion or other logic.
+static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
+                            uint64_t dim) {
+  StringRef name = "sparseDimSize";
+  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
+  Type iTp = builder.getIndexType();
+  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
+      .getResult(0);
 }
 
-/// Populates given sizes array from type (for static sizes) and from
-/// an already-converted opaque pointer source (for dynamic sizes).
-static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
-                         Location loc, SparseTensorEncodingAttr &enc,
-                         ShapedType stp, Value src) {
-  unsigned rank = stp.getRank();
-  sizes.reserve(rank);
-  for (unsigned i = 0; i < rank; i++)
-    sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
+/// Looks up a level-size by returning a statically-computed constant
+/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
+static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
+                                 SparseTensorEncodingAttr &enc, ShapedType stp,
+                                 Value tensor, unsigned lvl) {
+  // Only sparse tensors have "levels" to query.
+  assert(enc);
+  auto dimOrder = enc.getDimOrdering();
+  // TODO: The following implementation only handles permutations;
+  // we'll need to generalize this to handle arbitrary AffineExpr.
+  //
+  // There's no need to assert `isPermutation` here: because
+  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
+  // which is all we care about (for supporting permutations).
+  unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl;
+  auto s = stp.getShape()[dim];
+  if (s != ShapedType::kDynamic)
+    return constantIndex(builder, loc, s);
+  // If we cannot statically compute the size from the shape, then we
+  // must dynamically query it. (In principle we could also dynamically
+  // compute it, but since we already did so to construct the `tensor`
+  // in the first place, we might as well query rather than recompute.)
+  return genLvlSizeCall(builder, loc, tensor, lvl);
 }
 
-/// Populates given sizes array from type.
-static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
-                          Location loc, ShapedType stp) {
+/// Looks up a dimension-size by returning a constant from the shape
+/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
+/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
+/// of dense tensors).
+static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
+                                 SparseTensorEncodingAttr &enc, ShapedType stp,
+                                 Value tensor, unsigned dim) {
+  auto s = stp.getShape()[dim];
+  if (s != ShapedType::kDynamic)
+    return constantIndex(builder, loc, s);
+  if (enc)
+    return genDimSizeCall(builder, loc, tensor, dim);
+  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
+}
+
+/// Populates the array with the dimension-sizes of the given tensor.
+static void fillDimSizes(OpBuilder &builder, Location loc,
+                         SparseTensorEncodingAttr &enc, ShapedType stp,
+                         Value tensor, SmallVectorImpl<Value> &out) {
+  unsigned dimRank = stp.getRank();
+  out.reserve(dimRank);
+  for (unsigned d = 0; d < dimRank; d++)
+    out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d));
+}
+
+/// Returns an array with the dimension-sizes of the given tensor.
+static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
+                                      SparseTensorEncodingAttr &enc,
+                                      ShapedType stp, Value tensor) {
+  SmallVector<Value> out;
+  fillDimSizes(builder, loc, enc, stp, tensor, out);
+  return out;
+}
+
+/// Populates the array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp,
+                         SmallVectorImpl<Value> &out) {
   auto shape = stp.getShape();
-  unsigned rank = stp.getRank();
-  sizes.reserve(rank);
-  for (unsigned i = 0; i < rank; i++) {
-    uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
-    sizes.push_back(constantIndex(builder, loc, s));
+  unsigned dimRank = stp.getRank();
+  out.reserve(dimRank);
+  for (unsigned d = 0; d < dimRank; d++) {
+    auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d];
+    out.push_back(constantIndex(builder, loc, s));
   }
 }
 
+/// Returns an array with the dimension-shape of the given `ShapedType`,
+/// where dynamic sizes are represented by zero.
+static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
+                                      ShapedType stp) {
+  SmallVector<Value> out;
+  fillDimShape(builder, loc, stp, out);
+  return out;
+}
+
 /// Populates the given sizes array for concatenation from type (for static
 /// sizes) and from an already-converted opaque pointer source (for dynamic
 /// sizes).
@@ -128,7 +177,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
   // compute the size of the concatenation dimension if necessary.
   if (srcEnc)
     // Reuses sizes from an arbitrary input tensor is fine.
-    sizesFromPtr(builder, sizes, loc, srcEnc, srcTp, srcs[0]);
+    fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes);
   else
     sizesFromSrc(builder, sizes, loc, srcs[0]);
@@ -142,8 +191,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
     auto srcTp = srcs[i].getType().cast<ShapedType>();
     auto encSrc = getSparseTensorEncoding(srcTp);
     Value srcSz =
-        encSrc ? sizeFromPtrAtDim(builder, loc, encSrc, srcTp, srcs[i], dim)
-               : linalg::createOrFoldDimOp(builder, loc, srcs[i], dim);
+        createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim);
     // Sum up all the sizes.
     sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
   }
@@ -489,9 +537,6 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   auto encDst = getSparseTensorEncoding(dstTp);
   if (!encDst || !encSrc)
     return failure();
-  unsigned srcRank = srcTp.getRank();
-  unsigned dstRank = dstTp.getRank();
   Type elemTp = srcTp.getElementType();
   assert(elemTp == dstTp.getElementType() &&
          "reshape should not change element type");
@@ -499,26 +544,26 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   auto noPerm = SparseTensorEncodingAttr::get(
       op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-  SmallVector<Value> srcSizes;
-  sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
+  SmallVector<Value> srcDimSizes =
+      getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
   NewCallParams params(rewriter, loc);
-  Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
+  Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp)
                    .genNewCall(Action::kToIterator, adaptor.getSrc());
   // Start a new COO for the destination tensor.
-  SmallVector<Value> dstSizes;
-  if (dstTp.hasStaticShape()) {
-    sizesFromType(rewriter, dstSizes, loc, dstTp);
-  } else {
-    ArrayRef<int64_t> dstShape = dstTp.getShape();
-    genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
-                       op.getReassociationIndices());
-  }
-  Value coo =
-      params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
+  SmallVector<Value> dstDimSizes;
+  if (dstTp.hasStaticShape())
+    // Static "shapes" are in fact "sizes".
+    fillDimShape(rewriter, loc, dstTp, dstDimSizes);
+  else
+    genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
+                       dstTp.getShape(), op.getReassociationIndices());
+  Value coo = params.genBuffers(encDst, dstDimSizes, dstTp)
+                  .genNewCall(Action::kEmptyCOO);
   Value dstPerm = params.getDim2LvlMap();
   // Construct a while loop over the iterator.
-  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
-  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
+  Type iTp = rewriter.getIndexType();
+  Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp);
+  Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp);
   Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
   SmallVector<Value> noArgs;
   SmallVector<Type> noTypes;
@@ -532,7 +577,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
   translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
-                   dstIdx, srcIdx, dstSizes, srcSizes);
+                   dstIdx, srcIdx, dstDimSizes, srcDimSizes);
   genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
@@ -566,10 +611,9 @@ static void genSparseCOOIterationLoop(
   auto noPerm = SparseTensorEncodingAttr::get(
       rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
       enc.getPointerBitWidth(), enc.getIndexBitWidth());
-  SmallVector<Value> sizes;
-  sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
+  SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
-                   .genBuffers(noPerm, sizes, tensorTp)
+                   .genBuffers(noPerm, dimSizes, tensorTp)
                    .genNewCall(Action::kToIterator, t);
 
   // Construct a while loop over the iterator.
@@ -664,7 +708,7 @@ public:
   }
 };
 
-/// Sparse conversion rule for dimension accesses.
+/// Sparse conversion rule for accessing dimension-sizes.
 class SparseTensorToDimSizeConverter
     : public OpConversionPattern<tensor::DimOp> {
 public:
@@ -672,18 +716,19 @@ public:
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(op.getSource().getType());
+    auto stp = op.getSource().getType().cast<ShapedType>();
+    // Only rewrite sparse DimOp.
+    auto enc = getSparseTensorEncoding(stp);
     if (!enc)
       return failure();
-    Optional<int64_t> index = op.getConstantIndex();
-    if (!index)
+    // Only rewrite DimOp with constant index.
+    Optional<int64_t> dim = op.getConstantIndex();
+    if (!dim)
       return failure();
     // Generate the call.
     Value src = adaptor.getOperands()[0];
-    int64_t idx = *index;
-    rewriter.replaceOp(op,
-                       genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
+    rewriter.replaceOp(
+        op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim));
     return success();
   }
 };
@@ -734,8 +779,7 @@ public:
     const unsigned lvlRank = enc.getDimLevelType().size();
     // Construct the dimShape.
    const auto dimShape = stp.getShape();
-    SmallVector<Value> dimShapeValues;
-    sizesFromType(rewriter, dimShapeValues, loc, stp);
+    SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
    Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
     // Allocate `SparseTensorReader` and perform all initial setup that
     // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
@@ -890,10 +934,10 @@ public:
       rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
       return success();
     }
-    SmallVector<Value> sizes;
     NewCallParams params(rewriter, loc);
     ShapedType stp = srcType.cast<ShapedType>();
-    sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
+    SmallVector<Value> dimSizes =
+        getDimSizes(rewriter, loc, encSrc, stp, src);
     bool useDirectConversion;
     switch (options.sparseToSparseStrategy) {
     case SparseToSparseConversionStrategy::kViaCOO:
@@ -909,7 +953,7 @@ public:
       break;
     }
     if (useDirectConversion) {
-      rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
+      rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp)
                                  .genNewCall(Action::kSparseToSparse, src));
     } else { // use via-COO conversion.
       // Set up encoding with right mix of src and dst so that the two
@@ -922,8 +966,8 @@ public:
       // TODO: This is the only place where `kToCOO` (or `kToIterator`)
      // is called with a non-identity permutation. Is there any clean
       // way to push the permutation over to the `kFromCOO` side instead?
-      Value coo =
-          params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
+      Value coo = params.genBuffers(enc, dimSizes, stp)
+                      .genNewCall(Action::kToCOO, src);
       Value dst = params.setTemplateTypes(encDst, stp)
                       .genNewCall(Action::kFromCOO, coo);
       genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
@@ -950,17 +994,17 @@ public:
         op->getContext(),
         SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
         AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-    SmallVector<Value> sizes;
-    sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
+    SmallVector<Value> dimSizes =
+        getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
     Value iter = NewCallParams(rewriter, loc)
-                     .genBuffers(encDst, sizes, dstTensorTp)
+                     .genBuffers(encDst, dimSizes, dstTensorTp)
                      .genNewCall(Action::kToIterator, src);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
     Block *insertionBlock = rewriter.getInsertionBlock();
     // TODO: Dense buffers should be allocated/deallocated via the callback
     // in BufferizationOptions.
-    Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
+    Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
     SmallVector<Value> noArgs;
     SmallVector<Type> noTypes;
     auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -1196,12 +1240,12 @@ public:
     Type idxType = rewriter.getIndexType();
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
-    // Determine the size for access expansion (always the innermost stored
-    // dimension size, translated back to original dimension).
-    auto enc = getSparseTensorEncoding(srcType);
-    unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
-    auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
-                               innerDim);
+    // Get the cardinality of valid coordinates for the innermost level.
+    auto srcEnc = getSparseTensorEncoding(srcType);
+    unsigned lvlRank =
+        srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank();
+    Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType,
+                                   adaptor.getTensor(), lvlRank - 1);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).
@@ -1377,10 +1421,8 @@ public:
       }
       // Accumulate offset.
       // TODO: avoid calling sparseDimSize multiple times by caching the result!
-      Value curDim = encSrc ? sizeFromPtrAtDim(rewriter, loc, encSrc, srcTp,
-                                               adaptedOp, concatDim)
-                            : linalg::createOrFoldDimOp(rewriter, loc,
-                                                        adaptedOp, concatDim);
+      Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp,
+                                         adaptedOp, concatDim);
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
@@ -1410,13 +1452,13 @@ public:
     // Convert to default permuted COO.
     Value src = adaptor.getOperands()[0];
     auto encSrc = getSparseTensorEncoding(srcType);
-    SmallVector<Value> sizes;
-    sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
+    SmallVector<Value> dimSizes =
+        getDimSizes(rewriter, loc, encSrc, srcType, src);
     auto enc = SparseTensorEncodingAttr::get(
         op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
         encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
     Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(enc, sizes, srcType)
+                    .genBuffers(enc, dimSizes, srcType)
                     .genNewCall(Action::kToCOO, src);
     // Then output the tensor to external file with indices in the externally
     // visible lexicographic index order. A sort is required if the source was