//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" using namespace mlir; using namespace vector; namespace { /// A pass converting MLIR Linalg ops into Vector ops. class TestConvVectorization : public PassWrapper> { public: StringRef getArgument() const final { return "test-conv-vectorization"; } StringRef getDescription() const final { return "Test vectorization of convolutions"; } TestConvVectorization() = default; TestConvVectorization(const TestConvVectorization &) {} explicit TestConvVectorization(ArrayRef tileSizesParam) { tileSizes = tileSizesParam; } void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); registry.insert(); registry.insert(); registry.insert(); registry.insert(); } ListOption tileSizes{ *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace void TestConvVectorization::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); ConversionTarget target(*context); target.addLegalDialect(); target.addLegalOp(); target.addLegalOp(); SmallVector stage1Patterns; linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); SmallVector frozenStage1Patterns; llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); RewritePatternSet stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); auto stage3Transforms = [](Operation *op) { PassManager pm(op->getContext()); pm.addPass(createLoopInvariantCodeMotionPass()); if (failed(pm.run(cast(op)))) llvm_unreachable("Unexpected failure in cleanup pass pipeline."); op->walk([](FuncOp func) { promoteSingleIterationLoops(func); linalg::hoistRedundantVectorTransfers(func); }); return success(); }; (void)linalg::applyStagedPatterns(module, frozenStage1Patterns, std::move(stage2Patterns), stage3Transforms); //===--------------------------------------------------------------------===// // Post staged patterns transforms //===--------------------------------------------------------------------===// VectorTransformsOptions vectorTransformsOptions{ VectorContractLowering::Dot, VectorTransposeLowering::EltWise}; RewritePatternSet vectorTransferPatterns(context); // Pattern is not applied because rank-reducing vector transfer is not yet // supported as can be seen in splitFullAndPartialTransferPrecondition, // VectorTransforms.cpp vectorTransferPatterns.add( context, vectorTransformsOptions); (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns)); // Programmatic controlled lowering of linalg.copy and linalg.fill. PassManager pm(context); pm.addNestedPass(createConvertLinalgToLoopsPass()); if (failed(pm.run(module))) llvm_unreachable("Unexpected failure in linalg to loops pass."); // Programmatic controlled lowering of vector.contract only. RewritePatternSet vectorContractLoweringPatterns(context); populateVectorContractLoweringPatterns(vectorContractLoweringPatterns, vectorTransformsOptions); populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns, vectorTransformsOptions); (void)applyPatternsAndFoldGreedily(module, std::move(vectorContractLoweringPatterns)); // Programmatic controlled lowering of vector.transfer only. RewritePatternSet vectorToLoopsPatterns(context); populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, VectorTransferToSCFOptions()); (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); // Ensure we drop the marker in the end. module.walk([](linalg::LinalgOp op) { op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); }); } namespace mlir { namespace test { void registerTestConvVectorization() { PassRegistration(); } } // namespace test } // namespace mlir