[mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp
The new class helps encapsulate the arguments to `_mlir_ciface_newSparseTensor` so that client code doesn't depend on the details of the API. (This makes way for the next differential which significantly alters the API.) Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137680
This commit is contained in:
@@ -70,16 +70,6 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call into the "swiss army knife" method of the sparse runtime
|
||||
/// support library for materializing sparse tensors into the computation.
|
||||
static Value genNewCall(OpBuilder &builder, Location loc,
|
||||
ArrayRef<Value> params) {
|
||||
StringRef name = "newSparseTensor";
|
||||
Type pTp = getOpaquePointerType(builder);
|
||||
return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Compute the size from type (for static sizes) or from an already-converted
|
||||
/// opaque pointer source (for dynamic sizes) at the given dimension.
|
||||
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
|
||||
@@ -168,41 +158,132 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// Populates parameters required to call the "swiss army knife" method of the
|
||||
/// sparse runtime support library for materializing sparse tensors into the
|
||||
/// computation.
|
||||
static void newParams(OpBuilder &builder, SmallVector<Value, 8> ¶ms,
|
||||
Location loc, ShapedType stp,
|
||||
SparseTensorEncodingAttr &enc, Action action,
|
||||
ValueRange szs, Value ptr = Value()) {
|
||||
ArrayRef<DimLevelType> dlt = enc.getDimLevelType();
|
||||
unsigned sz = dlt.size();
|
||||
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
|
||||
/// the "swiss army knife" method of the sparse runtime support library
|
||||
/// for materializing sparse tensors into the computation. This abstraction
|
||||
/// reduces the need to make modifications to client code whenever that
|
||||
/// API changes.
|
||||
class NewCallParams final {
|
||||
public:
|
||||
/// Allocates the `ValueRange` for the `func::CallOp` parameters,
|
||||
/// but does not initialize them.
|
||||
NewCallParams(OpBuilder &builder, Location loc)
|
||||
: builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
|
||||
|
||||
/// Initializes all static parameters (i.e., those which indicate
|
||||
/// type-level information such as the encoding and sizes), generating
|
||||
/// MLIR buffers as needed, and returning `this` for method chaining.
|
||||
/// This method does not set the action and pointer arguments, since
|
||||
/// those are handled by `genNewCall` instead.
|
||||
NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes,
|
||||
ShapedType stp);
|
||||
|
||||
/// (Re)sets the C++ template type parameters, and returns `this`
|
||||
/// for method chaining. This is already done as part of `genBuffers`,
|
||||
/// but is factored out so that it can also be called independently
|
||||
/// whenever subsequent `genNewCall` calls want to reuse the same
|
||||
/// buffers but different type parameters.
|
||||
//
|
||||
// TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
|
||||
// is there a better way to handle that than this one-off setter method?
|
||||
NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc,
|
||||
ShapedType stp) {
|
||||
params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc);
|
||||
params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc);
|
||||
params[kParamValTp] =
|
||||
constantPrimaryTypeEncoding(builder, loc, stp.getElementType());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Checks whether all the static parameters have been initialized.
|
||||
bool isInitialized() const {
|
||||
for (unsigned i = 0; i < kNumStaticParams; ++i)
|
||||
if (!params[i])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Gets the dimension-to-level mapping.
|
||||
//
|
||||
// TODO: This is only ever used for passing into `genAddEltCall`;
|
||||
// is there a better way to encapsulate that pattern (both to avoid
|
||||
// this one-off getter, and to avoid potential mixups)?
|
||||
Value getDim2LvlMap() const {
|
||||
assert(isInitialized() && "Must initialize before getDim2LvlMap");
|
||||
return params[kParamDim2Lvl];
|
||||
}
|
||||
|
||||
/// Generates a function call, with the current static parameters
|
||||
/// and the given dynamic arguments.
|
||||
Value genNewCall(Action action, Value ptr = Value()) {
|
||||
assert(isInitialized() && "Must initialize before genNewCall");
|
||||
StringRef name = "newSparseTensor";
|
||||
params[kParamAction] = constantAction(builder, loc, action);
|
||||
params[kParamPtr] = ptr ? ptr : builder.create<LLVM::NullOp>(loc, pTp);
|
||||
return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr unsigned kNumStaticParams = 6;
|
||||
static constexpr unsigned kNumDynamicParams = 2;
|
||||
static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
|
||||
static constexpr unsigned kParamLvlTypes = 0;
|
||||
static constexpr unsigned kParamDimSizes = 1;
|
||||
static constexpr unsigned kParamDim2Lvl = 2;
|
||||
static constexpr unsigned kParamPtrTp = 3;
|
||||
static constexpr unsigned kParamIndTp = 4;
|
||||
static constexpr unsigned kParamValTp = 5;
|
||||
static constexpr unsigned kParamAction = 6;
|
||||
static constexpr unsigned kParamPtr = 7;
|
||||
|
||||
OpBuilder &builder;
|
||||
Location loc;
|
||||
Type pTp;
|
||||
Value params[kNumParams];
|
||||
};
|
||||
|
||||
// TODO: see the note at `_mlir_ciface_newSparseTensor` about how
|
||||
// the meaning of the various arguments (e.g., "sizes" vs "shapes")
|
||||
// is inconsistent between the different actions.
|
||||
NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
|
||||
ValueRange dimSizes, ShapedType stp) {
|
||||
const unsigned lvlRank = enc.getDimLevelType().size();
|
||||
const unsigned dimRank = stp.getRank();
|
||||
// Sparsity annotations.
|
||||
SmallVector<Value, 4> attrs;
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
|
||||
params.push_back(genBuffer(builder, loc, attrs));
|
||||
// Dimension sizes array of the enveloping tensor. Useful for either
|
||||
SmallVector<Value, 4> lvlTypes;
|
||||
for (auto dlt : enc.getDimLevelType())
|
||||
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
|
||||
assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
|
||||
params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
|
||||
// Dimension-sizes array of the enveloping tensor. Useful for either
|
||||
// verification of external data, or for construction of internal data.
|
||||
params.push_back(genBuffer(builder, loc, szs));
|
||||
// Dimension order permutation array. This is the "identity" permutation by
|
||||
// default, or otherwise the "reverse" permutation of a given ordering, so
|
||||
// that indices can be mapped quickly to the right position.
|
||||
SmallVector<Value, 4> rev(sz);
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
rev[toOrigDim(enc, i)] = constantIndex(builder, loc, i);
|
||||
params.push_back(genBuffer(builder, loc, rev));
|
||||
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
|
||||
params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
|
||||
// The dimension-to-level mapping. We must preinitialize `dim2lvl`
|
||||
// so that the true branch below can perform random-access `operator[]`
|
||||
// assignment.
|
||||
SmallVector<Value, 4> dim2lvl(dimRank);
|
||||
auto dimOrder = enc.getDimOrdering();
|
||||
if (dimOrder) {
|
||||
assert(dimOrder.isPermutation());
|
||||
for (unsigned l = 0; l < lvlRank; l++) {
|
||||
// The `d`th source variable occurs in the `l`th result position.
|
||||
uint64_t d = dimOrder.getDimPosition(l);
|
||||
dim2lvl[d] = constantIndex(builder, loc, l);
|
||||
}
|
||||
} else {
|
||||
assert(dimRank == lvlRank && "Rank mismatch");
|
||||
for (unsigned i = 0; i < lvlRank; i++)
|
||||
dim2lvl[i] = constantIndex(builder, loc, i);
|
||||
}
|
||||
params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
|
||||
// Secondary and primary types encoding.
|
||||
Type elemTp = stp.getElementType();
|
||||
params.push_back(constantPointerTypeEncoding(builder, loc, enc));
|
||||
params.push_back(constantIndexTypeEncoding(builder, loc, enc));
|
||||
params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
|
||||
// User action.
|
||||
params.push_back(constantAction(builder, loc, action));
|
||||
// Payload pointer.
|
||||
if (!ptr)
|
||||
ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
|
||||
params.push_back(ptr);
|
||||
setTemplateTypes(enc, stp);
|
||||
// Finally, make note that initialization is complete.
|
||||
assert(isInitialized() && "Initialization failed");
|
||||
// And return `this` for method chaining.
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Generates a call to obtain the values array.
|
||||
@@ -387,14 +468,12 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value, 4> srcSizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
|
||||
newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes,
|
||||
adaptor.getSrc());
|
||||
Value iter = genNewCall(rewriter, loc, params);
|
||||
NewCallParams params(rewriter, loc);
|
||||
Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
|
||||
.genNewCall(Action::kToIterator, adaptor.getSrc());
|
||||
// Start a new COO for the destination tensor.
|
||||
SmallVector<Value, 4> dstSizes;
|
||||
params.clear();
|
||||
if (dstTp.hasStaticShape()) {
|
||||
sizesFromType(rewriter, dstSizes, loc, dstTp);
|
||||
} else {
|
||||
@@ -402,9 +481,9 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
|
||||
op.getReassociationIndices());
|
||||
}
|
||||
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes);
|
||||
Value coo = genNewCall(rewriter, loc, params);
|
||||
Value dstPerm = params[2];
|
||||
Value coo =
|
||||
params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
|
||||
Value dstPerm = params.getDim2LvlMap();
|
||||
// Construct a while loop over the iterator.
|
||||
Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
|
||||
Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
|
||||
@@ -426,9 +505,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
rewriter.create<scf::YieldOp>(loc);
|
||||
// Final call to construct sparse tensor storage and free temporary resources.
|
||||
rewriter.setInsertionPointAfter(whileOp);
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
params[7] = coo;
|
||||
Value dst = genNewCall(rewriter, loc, params);
|
||||
Value dst = params.genNewCall(Action::kFromCOO, coo);
|
||||
genDelCOOCall(rewriter, loc, elemTp, coo);
|
||||
genDelIteratorCall(rewriter, loc, elemTp, iter);
|
||||
rewriter.replaceOp(op, dst);
|
||||
@@ -458,11 +535,10 @@ static void genSparseCOOIterationLoop(
|
||||
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
enc.getPointerBitWidth(), enc.getIndexBitWidth());
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
|
||||
newParams(rewriter, params, loc, tensorTp, noPerm, Action::kToIterator, sizes,
|
||||
t);
|
||||
Value iter = genNewCall(rewriter, loc, params);
|
||||
Value iter = NewCallParams(rewriter, loc)
|
||||
.genBuffers(noPerm, sizes, tensorTp)
|
||||
.genNewCall(Action::kToIterator, t);
|
||||
|
||||
// Construct a while loop over the iterator.
|
||||
Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
@@ -611,12 +687,12 @@ public:
|
||||
// Generate the call to construct tensor from ptr. The sizes are
|
||||
// inferred from the result type of the new operator.
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
ShapedType stp = resType.cast<ShapedType>();
|
||||
sizesFromType(rewriter, sizes, loc, stp);
|
||||
Value ptr = adaptor.getOperands()[0];
|
||||
newParams(rewriter, params, loc, stp, enc, Action::kFromFile, sizes, ptr);
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
|
||||
rewriter.replaceOp(op, NewCallParams(rewriter, loc)
|
||||
.genBuffers(enc, sizes, stp)
|
||||
.genNewCall(Action::kFromFile, ptr));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -650,10 +726,10 @@ public:
|
||||
}
|
||||
// Generate the call to construct empty tensor. The sizes are
|
||||
// explicitly defined by the arguments to the alloc operator.
|
||||
SmallVector<Value, 8> params;
|
||||
ShapedType stp = resType.cast<ShapedType>();
|
||||
newParams(rewriter, params, loc, stp, enc, Action::kEmpty, sizes);
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
|
||||
rewriter.replaceOp(op,
|
||||
NewCallParams(rewriter, loc)
|
||||
.genBuffers(enc, sizes, resType.cast<ShapedType>())
|
||||
.genNewCall(Action::kEmpty));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -690,7 +766,7 @@ public:
|
||||
return success();
|
||||
}
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
NewCallParams params(rewriter, loc);
|
||||
ShapedType stp = srcType.cast<ShapedType>();
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
|
||||
bool useDirectConversion;
|
||||
@@ -708,9 +784,8 @@ public:
|
||||
break;
|
||||
}
|
||||
if (useDirectConversion) {
|
||||
newParams(rewriter, params, loc, stp, encDst, Action::kSparseToSparse,
|
||||
sizes, src);
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
|
||||
rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
|
||||
.genNewCall(Action::kSparseToSparse, src));
|
||||
} else { // use via-COO conversion.
|
||||
// Set up encoding with right mix of src and dst so that the two
|
||||
// method calls can share most parameters, while still providing
|
||||
@@ -719,13 +794,13 @@ public:
|
||||
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
|
||||
encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
|
||||
encSrc.getIndexBitWidth());
|
||||
newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src);
|
||||
Value coo = genNewCall(rewriter, loc, params);
|
||||
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
|
||||
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
params[7] = coo;
|
||||
Value dst = genNewCall(rewriter, loc, params);
|
||||
// TODO: This is the only place where `kToCOO` (or `kToIterator`)
|
||||
// is called with a non-identity permutation. Is there any clean
|
||||
// way to push the permutation over to the `kFromCOO` side instead?
|
||||
Value coo =
|
||||
params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
|
||||
Value dst = params.setTemplateTypes(encDst, stp)
|
||||
.genNewCall(Action::kFromCOO, coo);
|
||||
genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
|
||||
rewriter.replaceOp(op, dst);
|
||||
}
|
||||
@@ -743,7 +818,7 @@ public:
|
||||
RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
|
||||
unsigned rank = dstTensorTp.getRank();
|
||||
Type elemTp = dstTensorTp.getElementType();
|
||||
// Fabricate a no-permutation encoding for newParams().
|
||||
// Fabricate a no-permutation encoding for NewCallParams
|
||||
// The pointer/index types must be those of `src`.
|
||||
// The dimLevelTypes aren't actually used by Action::kToIterator.
|
||||
encDst = SparseTensorEncodingAttr::get(
|
||||
@@ -751,11 +826,10 @@ public:
|
||||
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
|
||||
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
|
||||
newParams(rewriter, params, loc, dstTensorTp, encDst, Action::kToIterator,
|
||||
sizes, src);
|
||||
Value iter = genNewCall(rewriter, loc, params);
|
||||
Value iter = NewCallParams(rewriter, loc)
|
||||
.genBuffers(encDst, sizes, dstTensorTp)
|
||||
.genNewCall(Action::kToIterator, src);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
|
||||
Block *insertionBlock = rewriter.getInsertionBlock();
|
||||
@@ -817,12 +891,12 @@ public:
|
||||
ShapedType stp = resType.cast<ShapedType>();
|
||||
unsigned rank = stp.getRank();
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromSrc(rewriter, sizes, loc, src);
|
||||
newParams(rewriter, params, loc, stp, encDst, Action::kEmptyCOO, sizes);
|
||||
Value coo = genNewCall(rewriter, loc, params);
|
||||
NewCallParams params(rewriter, loc);
|
||||
Value coo =
|
||||
params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value perm = params[2];
|
||||
Value perm = params.getDim2LvlMap();
|
||||
Type eltType = stp.getElementType();
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
|
||||
genDenseTensorOrSparseConstantIterLoop(
|
||||
@@ -836,9 +910,7 @@ public:
|
||||
genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm);
|
||||
});
|
||||
// Final call to construct sparse tensor storage.
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
params[7] = coo;
|
||||
Value dst = genNewCall(rewriter, loc, params);
|
||||
Value dst = params.genNewCall(Action::kFromCOO, coo);
|
||||
genDelCOOCall(rewriter, loc, eltType, coo);
|
||||
rewriter.replaceOp(op, dst);
|
||||
return success();
|
||||
@@ -1117,15 +1189,15 @@ public:
|
||||
Value offset = constantIndex(rewriter, loc, 0);
|
||||
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
NewCallParams params(rewriter, loc);
|
||||
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
|
||||
concatDim);
|
||||
|
||||
if (encDst) {
|
||||
// Start a new COO for the destination tensor.
|
||||
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
|
||||
dst = genNewCall(rewriter, loc, params);
|
||||
dstPerm = params[2];
|
||||
dst =
|
||||
params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
|
||||
dstPerm = params.getDim2LvlMap();
|
||||
elemPtr = genAllocaScalar(rewriter, loc, elemTp);
|
||||
dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
} else {
|
||||
@@ -1188,11 +1260,9 @@ public:
|
||||
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
|
||||
}
|
||||
if (encDst) {
|
||||
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
|
||||
// In sparse output case, the destination holds the COO.
|
||||
Value coo = dst;
|
||||
params[7] = coo;
|
||||
dst = genNewCall(rewriter, loc, params);
|
||||
dst = params.genNewCall(Action::kFromCOO, coo);
|
||||
// Release resources.
|
||||
genDelCOOCall(rewriter, loc, elemTp, coo);
|
||||
rewriter.replaceOp(op, dst);
|
||||
@@ -1216,27 +1286,25 @@ public:
|
||||
Value src = adaptor.getOperands()[0];
|
||||
auto encSrc = getSparseTensorEncoding(srcType);
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
|
||||
auto enc = SparseTensorEncodingAttr::get(
|
||||
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src);
|
||||
Value coo = genNewCall(rewriter, loc, params);
|
||||
Value coo = NewCallParams(rewriter, loc)
|
||||
.genBuffers(enc, sizes, srcType)
|
||||
.genNewCall(Action::kToCOO, src);
|
||||
// Then output the tensor to external file with indices in the externally
|
||||
// visible lexicographic index order. A sort is required if the source was
|
||||
// not in that order yet (note that the sort can be dropped altogether if
|
||||
// external format does not care about the order at all, but here we assume
|
||||
// it does).
|
||||
bool sort =
|
||||
encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
|
||||
params.clear();
|
||||
params.push_back(coo);
|
||||
params.push_back(adaptor.getOperands()[1]);
|
||||
params.push_back(constantI1(rewriter, loc, sort));
|
||||
Value sort = constantI1(rewriter, loc,
|
||||
encSrc.getDimOrdering() &&
|
||||
!encSrc.getDimOrdering().isIdentity());
|
||||
SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
|
||||
Type eltType = srcType.getElementType();
|
||||
SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
|
||||
createFuncCall(rewriter, loc, name, {}, params, EmitCInterface::Off);
|
||||
createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
|
||||
genDelCOOCall(rewriter, loc, eltType, coo);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user