[mlir][sparse] Introducing options for the SparseTensorConversion pass
This is work towards: https://github.com/llvm/llvm-project/issues/51652 This differential sets up the options and threads them through everywhere, but doesn't actually use them yet. The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D122054
This commit is contained in:
@@ -73,6 +73,13 @@ public:
|
||||
|
||||
struct SparseTensorConversionPass
|
||||
: public SparseTensorConversionBase<SparseTensorConversionPass> {
|
||||
|
||||
SparseTensorConversionPass() = default;
|
||||
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
|
||||
SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
|
||||
sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
@@ -106,11 +113,14 @@ struct SparseTensorConversionPass
|
||||
target
|
||||
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
|
||||
memref::MemRefDialect, scf::SCFDialect>();
|
||||
// Translate strategy flags to strategy options.
|
||||
SparseTensorConversionOptions options(
|
||||
sparseToSparseConversionStrategy(sparseToSparse));
|
||||
// Populate with rules and apply rewriting rules.
|
||||
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
|
||||
converter);
|
||||
populateCallOpTypeConversionPattern(patterns, converter);
|
||||
populateSparseTensorConversionPatterns(converter, patterns);
|
||||
populateSparseTensorConversionPatterns(converter, patterns, options);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
@@ -146,6 +156,18 @@ SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
|
||||
}
|
||||
}
|
||||
|
||||
SparseToSparseConversionStrategy
|
||||
mlir::sparseToSparseConversionStrategy(int32_t flag) {
|
||||
switch (flag) {
|
||||
default:
|
||||
return SparseToSparseConversionStrategy::kAuto;
|
||||
case 1:
|
||||
return SparseToSparseConversionStrategy::kViaCOO;
|
||||
case 2:
|
||||
return SparseToSparseConversionStrategy::kDirect;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
||||
return std::make_unique<SparsificationPass>();
|
||||
}
|
||||
@@ -158,3 +180,8 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
|
||||
return std::make_unique<SparseTensorConversionPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
|
||||
const SparseTensorConversionOptions &options) {
|
||||
return std::make_unique<SparseTensorConversionPass>(options);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user