[mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion
This change cleans up the conversion pass re the "dim"-vs-"lvl" and "sizes"-vs-"shape" distinctions of the runtime. A quick synopsis includes:
* Adds new `SparseTensorStorageBase::getDimSize` method, with `sparseDimSize` wrapper in SparseTensorRuntime.h, and `genDimSizeCall` generator in SparseTensorConversion.cpp
* Changes `genLvlSizeCall` to perform no logic, just generate the function call.
* Adds `createOrFold{Dim,Lvl}Call` functions to handle the logic of replacing `gen{Dim,Lvl}SizeCall` with constants whenever possible. The `createOrFoldDimCall` function replaces the old `sizeFromPtrAtDim`.
* Adds `{get,fill}DimSizes` functions for iterating `createOrFoldDimCall` across the whole type. These functions replace the old `sizesFromPtr`.
* Adds `{get,fill}DimShape` functions for lowering a `ShapedType` into constants. These functions replace the old `sizesFromType`.
* Changes the `DimOp` rewrite to do the right thing.
* Changes the `ExpandOp` rewrite to compute the proper expansion size.
Depends On D138365
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D139165
This commit is contained in:
@@ -57,62 +57,111 @@ static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
|
||||
operands);
|
||||
}
|
||||
|
||||
/// Generates call to lookup a level-size.
|
||||
static Value genLvlSizeCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, Value src,
|
||||
/// Generates call to lookup a level-size. N.B., this only generates
|
||||
/// the raw function call, and therefore (intentionally) does not perform
|
||||
/// any dim<->lvl conversion or other logic.
|
||||
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
|
||||
uint64_t lvl) {
|
||||
// Generate the call.
|
||||
StringRef name = "sparseLvlSize";
|
||||
SmallVector<Value, 2> params{ // just two
|
||||
src, constantIndex(builder, loc, toStoredDim(enc, lvl))};
|
||||
SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
|
||||
Type iTp = builder.getIndexType();
|
||||
return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Compute the size from type (for static sizes) or from an already-converted
|
||||
/// opaque pointer source (for dynamic sizes) at the given dimension.
|
||||
//
|
||||
// FIXME: Need to rename this function to match `genLvlSizeCall` and hence
|
||||
// match the naming convention used in the runtime library. However, it's
|
||||
// not entirely clear that all callsites of this function properly make the
|
||||
// "level"-vs-"dimension" distinction; so need to audit each callsite to
|
||||
// ensure this still does what they mean (possibly by having two separate
|
||||
// functions, one for levels and one for dimensions). That also means
|
||||
// renaming `sizesFromPtr`, `sizesFromType`, etc, to make clear whether
|
||||
// they mean to be referring to level-sizes vs dimension-sizes.
|
||||
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value src, unsigned i) {
|
||||
auto shape = stp.getShape();
|
||||
if (shape[i] == ShapedType::kDynamic)
|
||||
return genLvlSizeCall(builder, loc, enc, src, i);
|
||||
return constantIndex(builder, loc, shape[i]);
|
||||
/// Generates call to lookup a dimension-size. N.B., this only generates
|
||||
/// the raw function call, and therefore (intentionally) does not perform
|
||||
/// any dim<->lvl conversion or other logic.
|
||||
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
|
||||
uint64_t dim) {
|
||||
StringRef name = "sparseDimSize";
|
||||
SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
|
||||
Type iTp = builder.getIndexType();
|
||||
return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Populates given sizes array from type (for static sizes) and from
|
||||
/// an already-converted opaque pointer source (for dynamic sizes).
|
||||
static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
|
||||
Location loc, SparseTensorEncodingAttr &enc,
|
||||
ShapedType stp, Value src) {
|
||||
unsigned rank = stp.getRank();
|
||||
sizes.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; i++)
|
||||
sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
|
||||
/// Looks up a level-size by returning a statically-computed constant
|
||||
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
|
||||
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value tensor, unsigned lvl) {
|
||||
// Only sparse tensors have "levels" to query.
|
||||
assert(enc);
|
||||
auto dimOrder = enc.getDimOrdering();
|
||||
// TODO: The following implementation only handles permutations;
|
||||
// we'll need to generalize this to handle arbitrary AffineExpr.
|
||||
//
|
||||
// There's no need to assert `isPermutation` here: because
|
||||
// `getDimPosition` checks that the expr isa `AffineDimExpr`,
|
||||
// which is all we care about (for supporting permutations).
|
||||
unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl;
|
||||
auto s = stp.getShape()[dim];
|
||||
if (s != ShapedType::kDynamic)
|
||||
return constantIndex(builder, loc, s);
|
||||
// If we cannot statically compute the size from the shape, then we
|
||||
// must dynamically query it. (In principle we could also dynamically
|
||||
// compute it, but since we already did so to construct the `tensor`
|
||||
// in the first place, we might as well query rather than recompute.)
|
||||
return genLvlSizeCall(builder, loc, tensor, lvl);
|
||||
}
|
||||
|
||||
/// Populates given sizes array from type.
|
||||
static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
|
||||
Location loc, ShapedType stp) {
|
||||
/// Looks up a dimension-size by returning a constant from the shape
|
||||
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
|
||||
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
|
||||
/// of dense tensors).
|
||||
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value tensor, unsigned dim) {
|
||||
auto s = stp.getShape()[dim];
|
||||
if (s != ShapedType::kDynamic)
|
||||
return constantIndex(builder, loc, s);
|
||||
if (enc)
|
||||
return genDimSizeCall(builder, loc, tensor, dim);
|
||||
return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
|
||||
}
|
||||
|
||||
/// Populates the array with the dimension-sizes of the given tensor.
|
||||
static void fillDimSizes(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value tensor, SmallVectorImpl<Value> &out) {
|
||||
unsigned dimRank = stp.getRank();
|
||||
out.reserve(dimRank);
|
||||
for (unsigned d = 0; d < dimRank; d++)
|
||||
out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d));
|
||||
}
|
||||
|
||||
/// Returns an array with the dimension-sizes of the given tensor.
|
||||
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr &enc,
|
||||
ShapedType stp, Value tensor) {
|
||||
SmallVector<Value> out;
|
||||
fillDimSizes(builder, loc, enc, stp, tensor, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Populates the array with the dimension-shape of the given `ShapedType`,
|
||||
/// where dynamic sizes are represented by zero.
|
||||
static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp,
|
||||
SmallVectorImpl<Value> &out) {
|
||||
auto shape = stp.getShape();
|
||||
unsigned rank = stp.getRank();
|
||||
sizes.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
|
||||
sizes.push_back(constantIndex(builder, loc, s));
|
||||
unsigned dimRank = stp.getRank();
|
||||
out.reserve(dimRank);
|
||||
for (unsigned d = 0; d < dimRank; d++) {
|
||||
auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d];
|
||||
out.push_back(constantIndex(builder, loc, s));
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an array with the dimension-shape of the given `ShapedType`,
|
||||
/// where dynamic sizes are represented by zero.
|
||||
static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
|
||||
ShapedType stp) {
|
||||
SmallVector<Value> out;
|
||||
fillDimShape(builder, loc, stp, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Populates the given sizes array for concatenation from type (for static
|
||||
/// sizes) and from an already-converted opaque pointer source (for dynamic
|
||||
/// sizes).
|
||||
@@ -128,7 +177,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
|
||||
// compute the size of the concatenation dimension if necessary.
|
||||
if (srcEnc)
|
||||
// Reuses sizes from an arbitrary input tensor is fine.
|
||||
sizesFromPtr(builder, sizes, loc, srcEnc, srcTp, srcs[0]);
|
||||
fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes);
|
||||
else
|
||||
sizesFromSrc(builder, sizes, loc, srcs[0]);
|
||||
|
||||
@@ -142,8 +191,7 @@ static void concatSizesFromInputs(OpBuilder &builder,
|
||||
auto srcTp = srcs[i].getType().cast<ShapedType>();
|
||||
auto encSrc = getSparseTensorEncoding(srcTp);
|
||||
Value srcSz =
|
||||
encSrc ? sizeFromPtrAtDim(builder, loc, encSrc, srcTp, srcs[i], dim)
|
||||
: linalg::createOrFoldDimOp(builder, loc, srcs[i], dim);
|
||||
createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim);
|
||||
// Sum up all the sizes.
|
||||
sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
|
||||
}
|
||||
@@ -489,9 +537,6 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
auto encDst = getSparseTensorEncoding(dstTp);
|
||||
if (!encDst || !encSrc)
|
||||
return failure();
|
||||
|
||||
unsigned srcRank = srcTp.getRank();
|
||||
unsigned dstRank = dstTp.getRank();
|
||||
Type elemTp = srcTp.getElementType();
|
||||
assert(elemTp == dstTp.getElementType() &&
|
||||
"reshape should not change element type");
|
||||
@@ -499,26 +544,26 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
auto noPerm = SparseTensorEncodingAttr::get(
|
||||
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value> srcSizes;
|
||||
sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
|
||||
SmallVector<Value> srcDimSizes =
|
||||
getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
|
||||
NewCallParams params(rewriter, loc);
|
||||
Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
|
||||
Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp)
|
||||
.genNewCall(Action::kToIterator, adaptor.getSrc());
|
||||
// Start a new COO for the destination tensor.
|
||||
SmallVector<Value> dstSizes;
|
||||
if (dstTp.hasStaticShape()) {
|
||||
sizesFromType(rewriter, dstSizes, loc, dstTp);
|
||||
} else {
|
||||
ArrayRef<int64_t> dstShape = dstTp.getShape();
|
||||
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
|
||||
op.getReassociationIndices());
|
||||
}
|
||||
Value coo =
|
||||
params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
|
||||
SmallVector<Value> dstDimSizes;
|
||||
if (dstTp.hasStaticShape())
|
||||
// Static "shapes" are in fact "sizes".
|
||||
fillDimShape(rewriter, loc, dstTp, dstDimSizes);
|
||||
else
|
||||
genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
|
||||
dstTp.getShape(), op.getReassociationIndices());
|
||||
Value coo = params.genBuffers(encDst, dstDimSizes, dstTp)
|
||||
.genNewCall(Action::kEmptyCOO);
|
||||
Value dstPerm = params.getDim2LvlMap();
|
||||
// Construct a while loop over the iterator.
|
||||
Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
|
||||
Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
|
||||
Type iTp = rewriter.getIndexType();
|
||||
Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp);
|
||||
Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp);
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
|
||||
SmallVector<Value> noArgs;
|
||||
SmallVector<Type> noTypes;
|
||||
@@ -532,7 +577,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
|
||||
rewriter.setInsertionPointToStart(after);
|
||||
translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
|
||||
dstIdx, srcIdx, dstSizes, srcSizes);
|
||||
dstIdx, srcIdx, dstDimSizes, srcDimSizes);
|
||||
genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
|
||||
rewriter.create<scf::YieldOp>(loc);
|
||||
// Final call to construct sparse tensor storage and free temporary resources.
|
||||
@@ -566,10 +611,9 @@ static void genSparseCOOIterationLoop(
|
||||
auto noPerm = SparseTensorEncodingAttr::get(
|
||||
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
enc.getPointerBitWidth(), enc.getIndexBitWidth());
|
||||
SmallVector<Value> sizes;
|
||||
sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
|
||||
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
|
||||
Value iter = NewCallParams(rewriter, loc)
|
||||
.genBuffers(noPerm, sizes, tensorTp)
|
||||
.genBuffers(noPerm, dimSizes, tensorTp)
|
||||
.genNewCall(Action::kToIterator, t);
|
||||
|
||||
// Construct a while loop over the iterator.
|
||||
@@ -664,7 +708,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for dimension accesses.
|
||||
/// Sparse conversion rule for accessing dimension-sizes.
|
||||
class SparseTensorToDimSizeConverter
|
||||
: public OpConversionPattern<tensor::DimOp> {
|
||||
public:
|
||||
@@ -672,18 +716,19 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only rewrite annotated DimOp with constant index.
|
||||
auto enc = getSparseTensorEncoding(op.getSource().getType());
|
||||
auto stp = op.getSource().getType().cast<ShapedType>();
|
||||
// Only rewrite sparse DimOp.
|
||||
auto enc = getSparseTensorEncoding(stp);
|
||||
if (!enc)
|
||||
return failure();
|
||||
Optional<int64_t> index = op.getConstantIndex();
|
||||
if (!index)
|
||||
// Only rewrite DimOp with constant index.
|
||||
Optional<int64_t> dim = op.getConstantIndex();
|
||||
if (!dim)
|
||||
return failure();
|
||||
// Generate the call.
|
||||
Value src = adaptor.getOperands()[0];
|
||||
int64_t idx = *index;
|
||||
rewriter.replaceOp(op,
|
||||
genLvlSizeCall(rewriter, op->getLoc(), enc, src, idx));
|
||||
rewriter.replaceOp(
|
||||
op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -734,8 +779,7 @@ public:
|
||||
const unsigned lvlRank = enc.getDimLevelType().size();
|
||||
// Construct the dimShape.
|
||||
const auto dimShape = stp.getShape();
|
||||
SmallVector<Value> dimShapeValues;
|
||||
sizesFromType(rewriter, dimShapeValues, loc, stp);
|
||||
SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
|
||||
Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
|
||||
// Allocate `SparseTensorReader` and perform all initial setup that
|
||||
// does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
|
||||
@@ -890,10 +934,10 @@ public:
|
||||
rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
|
||||
return success();
|
||||
}
|
||||
SmallVector<Value> sizes;
|
||||
NewCallParams params(rewriter, loc);
|
||||
ShapedType stp = srcType.cast<ShapedType>();
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
|
||||
SmallVector<Value> dimSizes =
|
||||
getDimSizes(rewriter, loc, encSrc, stp, src);
|
||||
bool useDirectConversion;
|
||||
switch (options.sparseToSparseStrategy) {
|
||||
case SparseToSparseConversionStrategy::kViaCOO:
|
||||
@@ -909,7 +953,7 @@ public:
|
||||
break;
|
||||
}
|
||||
if (useDirectConversion) {
|
||||
rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
|
||||
rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp)
|
||||
.genNewCall(Action::kSparseToSparse, src));
|
||||
} else { // use via-COO conversion.
|
||||
// Set up encoding with right mix of src and dst so that the two
|
||||
@@ -922,8 +966,8 @@ public:
|
||||
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
|
||||
// is called with a non-identity permutation. Is there any clean
|
||||
// way to push the permutation over to the `kFromCOO` side instead?
|
||||
Value coo =
|
||||
params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
|
||||
Value coo = params.genBuffers(enc, dimSizes, stp)
|
||||
.genNewCall(Action::kToCOO, src);
|
||||
Value dst = params.setTemplateTypes(encDst, stp)
|
||||
.genNewCall(Action::kFromCOO, coo);
|
||||
genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
|
||||
@@ -950,17 +994,17 @@ public:
|
||||
op->getContext(),
|
||||
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
|
||||
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value> sizes;
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
|
||||
SmallVector<Value> dimSizes =
|
||||
getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
|
||||
Value iter = NewCallParams(rewriter, loc)
|
||||
.genBuffers(encDst, sizes, dstTensorTp)
|
||||
.genBuffers(encDst, dimSizes, dstTensorTp)
|
||||
.genNewCall(Action::kToIterator, src);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
|
||||
Block *insertionBlock = rewriter.getInsertionBlock();
|
||||
// TODO: Dense buffers should be allocated/deallocated via the callback
|
||||
// in BufferizationOptions.
|
||||
Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
|
||||
Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
|
||||
SmallVector<Value> noArgs;
|
||||
SmallVector<Type> noTypes;
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
|
||||
@@ -1196,12 +1240,12 @@ public:
|
||||
Type idxType = rewriter.getIndexType();
|
||||
// All initialization should be done on entry of the loop nest.
|
||||
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
|
||||
// Determine the size for access expansion (always the innermost stored
|
||||
// dimension size, translated back to original dimension).
|
||||
auto enc = getSparseTensorEncoding(srcType);
|
||||
unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
|
||||
auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
|
||||
innerDim);
|
||||
// Get the cardinality of valid coordinates for the innermost level.
|
||||
auto srcEnc = getSparseTensorEncoding(srcType);
|
||||
unsigned lvlRank =
|
||||
srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank();
|
||||
Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType,
|
||||
adaptor.getTensor(), lvlRank - 1);
|
||||
// Allocate temporary buffers for values, filled-switch, and indices.
|
||||
// We do not use stack buffers for this, since the expanded size may
|
||||
// be rather large (as it envelops a single expanded dense dimension).
|
||||
@@ -1377,10 +1421,8 @@ public:
|
||||
}
|
||||
// Accumulate offset.
|
||||
// TODO: avoid calling sparseDimSize multiple times by caching the result!
|
||||
Value curDim = encSrc ? sizeFromPtrAtDim(rewriter, loc, encSrc, srcTp,
|
||||
adaptedOp, concatDim)
|
||||
: linalg::createOrFoldDimOp(rewriter, loc,
|
||||
adaptedOp, concatDim);
|
||||
Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp,
|
||||
adaptedOp, concatDim);
|
||||
|
||||
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
|
||||
}
|
||||
@@ -1410,13 +1452,13 @@ public:
|
||||
// Convert to default permuted COO.
|
||||
Value src = adaptor.getOperands()[0];
|
||||
auto encSrc = getSparseTensorEncoding(srcType);
|
||||
SmallVector<Value> sizes;
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
|
||||
SmallVector<Value> dimSizes =
|
||||
getDimSizes(rewriter, loc, encSrc, srcType, src);
|
||||
auto enc = SparseTensorEncodingAttr::get(
|
||||
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
Value coo = NewCallParams(rewriter, loc)
|
||||
.genBuffers(enc, sizes, srcType)
|
||||
.genBuffers(enc, dimSizes, srcType)
|
||||
.genNewCall(Action::kToCOO, src);
|
||||
// Then output the tensor to external file with indices in the externally
|
||||
// visible lexicographic index order. A sort is required if the source was
|
||||
|
||||
Reference in New Issue
Block a user