[mlir][sparse] Introducing options for the SparseTensorConversion pass

This is work towards: https://github.com/llvm/llvm-project/issues/51652

This differential sets up the options and threads them through everywhere, but doesn't actually use them yet.  The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122054
This commit is contained in:
wren romano
2022-03-18 19:10:40 -07:00
parent 110295ebb7
commit c7e24db412
6 changed files with 124 additions and 15 deletions

View File

@@ -73,6 +73,13 @@ public:
struct SparseTensorConversionPass
: public SparseTensorConversionBase<SparseTensorConversionPass> {
SparseTensorConversionPass() = default;
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
@@ -106,11 +113,14 @@ struct SparseTensorConversionPass
target
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
memref::MemRefDialect, scf::SCFDialect>();
// Translate strategy flags to strategy options.
SparseTensorConversionOptions options(
sparseToSparseConversionStrategy(sparseToSparse));
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateSparseTensorConversionPatterns(converter, patterns);
populateSparseTensorConversionPatterns(converter, patterns, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -146,6 +156,18 @@ SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
}
}
SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
switch (flag) {
default:
return SparseToSparseConversionStrategy::kAuto;
case 1:
return SparseToSparseConversionStrategy::kViaCOO;
case 2:
return SparseToSparseConversionStrategy::kDirect;
}
}
std::unique_ptr<Pass> mlir::createSparsificationPass() {
return std::make_unique<SparsificationPass>();
}
@@ -158,3 +180,8 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
const SparseTensorConversionOptions &options) {
return std::make_unique<SparseTensorConversionPass>(options);
}