[mlir][sparse] introduce vectorization pass for sparse loops
This brings back previous SIMD functionality, but in a separate pass. The idea is to improve this new pass incrementally, going beyond for-loops to while-loops for co-iteration as welll (masking), while introducing new abstractions to make the lowering more progressive. The separation of sparsification and vectorization is a very good first step on this journey. Also brings back ArmSVE support Still to be fine-tuned: + use of "index" in SIMD loop (viz. a[i] = i) + check that all ops really have SIMD support + check all forms of reductions + chain reduction SIMD values Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D138236
This commit is contained in:
@@ -27,6 +27,7 @@ namespace mlir {
|
||||
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
||||
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
||||
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
|
||||
#define GEN_PASS_DEF_SPARSEVECTORIZATION
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
@@ -67,10 +68,9 @@ struct SparsificationPass
|
||||
auto *ctx = &getContext();
|
||||
// Translate strategy flags to strategy options.
|
||||
SparsificationOptions options(parallelization);
|
||||
// Apply sparsification and vector cleanup rewriting.
|
||||
// Apply sparsification and cleanup rewriting.
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparsificationPatterns(patterns, options);
|
||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
||||
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
@@ -250,6 +250,27 @@ struct SparseBufferRewritePass
|
||||
}
|
||||
};
|
||||
|
||||
struct SparseVectorizationPass
|
||||
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
|
||||
|
||||
SparseVectorizationPass() = default;
|
||||
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
|
||||
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
|
||||
vectorLength = vl;
|
||||
enableVLAVectorization = vla;
|
||||
enableSIMDIndex32 = sidx32;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparseVectorizationPatterns(
|
||||
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -322,3 +343,15 @@ std::unique_ptr<Pass>
|
||||
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
|
||||
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
|
||||
return std::make_unique<SparseVectorizationPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createSparseVectorizationPass(unsigned vectorLength,
|
||||
bool enableVLAVectorization,
|
||||
bool enableSIMDIndex32) {
|
||||
return std::make_unique<SparseVectorizationPass>(
|
||||
vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user