[mlir][sparse] introduce vectorization pass for sparse loops

This brings back the previous SIMD functionality, but in a separate pass.
The idea is to improve this new pass incrementally, going beyond for-loops
to the while-loops used for co-iteration (with masking), while introducing
new abstractions to make the lowering more progressive. Separating
sparsification from vectorization is a good first step on this journey.
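
For illustration, a minimal sketch of how the two passes can now be chained
in a pass pipeline; the pipeline wiring, function name, and option values
below are illustrative, not part of this change:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Illustrative pipeline: sparsify first, then vectorize the generated loops.
  void buildSparseLoopPipeline(mlir::OpPassManager &pm) {
    // Lower sparse tensor operations to explicit loops.
    pm.addPass(mlir::createSparsificationPass());
    // Vectorize the generated for-loops in a separate pass, e.g. with a fixed
    // vector length of 16, no scalable (VLA) vectors, and 64-bit SIMD indices.
    pm.addPass(mlir::createSparseVectorizationPass(
        /*vectorLength=*/16,
        /*enableVLAVectorization=*/false,
        /*enableSIMDIndex32=*/false));
  }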

Also brings back ArmSVE support.

Still to be fine-tuned:
  + use of "index" in SIMD loop (viz. a[i] = i)
  + check that all ops really have SIMD support
  + check all forms of reductions
  + chaining of reduction SIMD values

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138236

commit 99b3849d89 (parent 9df8ba631d)
Author: Aart Bik
Date:   2022-11-18 12:18:00 -08:00
7 changed files with 1016 additions and 89 deletions

@@ -27,6 +27,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -67,10 +68,9 @@ struct SparsificationPass
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
SparsificationOptions options(parallelization);
// Apply sparsification and vector cleanup rewriting.
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
@@ -250,6 +250,27 @@ struct SparseBufferRewritePass
}
};
struct SparseVectorizationPass
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
SparseVectorizationPass() = default;
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
vectorLength = vl;
enableVLAVectorization = vla;
enableSIMDIndex32 = sidx32;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseVectorizationPatterns(
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -322,3 +343,15 @@ std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
return std::make_unique<SparseVectorizationPass>();
}
std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
bool enableVLAVectorization,
bool enableSIMDIndex32) {
return std::make_unique<SparseVectorizationPass>(
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}
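
For completeness, a hedged usage sketch of the new programmatic entry point;
the driver function name and the option values (scalable vectors, 32-bit SIMD
indices) are illustrative, not part of this change:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Support/LogicalResult.h"

  // Illustrative only: run just the vectorization pass on an already
  // sparsified module, requesting scalable (VLA) vectors as used for ArmSVE
  // and 32-bit SIMD index arithmetic.
  mlir::LogicalResult vectorizeSparseLoops(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::createSparseVectorizationPass(
        /*vectorLength=*/4,
        /*enableVLAVectorization=*/true,
        /*enableSIMDIndex32=*/true));
    return pm.run(module);
  }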