// NOTE(review): This file appears to have been mangled in transit.
//   1. All source lines have been collapsed into three physical lines. Each of
//      those lines begins with `//`, so as written the compiler sees the whole
//      file as one long comment and compiles nothing.
//   2. Every template argument list (`<...>`) has been stripped. Examples:
//      `PassWrapper`, `Option testPatterns`, every `patterns.insert>(`,
//      `ArrayRef{1, 2, 0}`, `ArrayRef{0}`, and `PassRegistration
//      testTransformPatternsPass`.
// TODO(review): restore the original line breaks and template arguments from
// version control. Do not reconstruct them by hand: the op types each pattern
// is instantiated for cannot be recovered from the remaining text.
//
// Intended contents, as far as the remaining text shows:
//
//  * TestLinalgTransforms: a pass that overrides runOnFunction().
//    - It has one option, "test-patterns" ("Test a mixed set of patterns"),
//      which defaults to false.
//    - Its copy constructor deliberately copies no state. Presumably the pass
//      infrastructure copies option values when cloning; confirm.
//
//  * applyPatterns(funcOp):
//    - Fills an OwningRewritePatternList.
//    - Runs it with applyPatternsAndFoldGreedily(funcOp, patterns).
//    - Walks every LinalgOp in funcOp and removes the
//      LinalgTransforms::kLinalgTransformMarker attribute ("Drop the marker").
//    Patterns are sequenced through marker strings passed to
//    LinalgMarker(filters, replacement). Presumably a pattern only fires on
//    ops whose marker is in `filters`, and it tags the rewritten op with
//    `replacement`; confirm against LinalgMarker in Transforms.h.
//    A `{}` entry in a filter list presumably means "no marker".
//
//    - Tiling chain:
//        {"MEM", {}} -> "L3"   tile sizes {2000, 3000, 4000}
//        "L3"        -> "L2"   tile sizes {200, 300, 400}
//        "L2"        -> "L1"   tile sizes {20, 30, 40}
//        "L1"        -> "REG"  tile sizes {2, 3, 4}
//      Plus two further tiling patterns:
//        {}                        -> "L1"   tile sizes {5, 6},
//                                            LinalgTilingLoopType::ParallelLoops
//        {"MEM", "L3", "L2", {}}   -> "REG"  tile size 8000
//
//    - Tiling + interchange chain:
//        "__with_perm__"   -> "L2__with_perm__"   tile sizes {2000, 3000, 4000},
//                                                 interchange {1, 2, 0}
//        "L2__with_perm__" -> "L1__with_perm__"   tile sizes {200, 300, 400},
//                                                 interchange {1, 0, 2}
//        "L1__with_perm__" -> "REG__with_perm__"  tile sizes {20, 30, 40}
//      Plus a second "__with_perm__" -> "L1__with_perm__" pattern with tile
//      sizes {5, 6} and interchange {1, 0}. Its op type was stripped, so it
//      cannot be checked here whether it competes with the first
//      "__with_perm__" pattern.
//
//    - Lowering: ops marked "REG" are lowered with LinalgLoweringType::Loops.
//
//    - Vectorization: three LinalgVectorizationPattern instantiations, all
//      filtered on "VECTORIZE".
//
//    - Interchange: two interchange patterns with interchange vector
//      {1, 2, 0}. Each uses an empty filter list and tags the result
//      "PERMUTED".
//
//    - Promotion:
//        "_promote_views_"         -> "_views_promoted_"
//        "_promote_first_view_"    -> "_first_view_promoted_"
//                                     (operandsToPromote = {0})
//        "_promote_views_aligned_" -> "_views_aligned_promoted_"
//                                     (operandsToPromote = {0}, alignment = 32)
//
//  * TestLinalgTransforms::runOnFunction(): calls applyPatterns on
//    getFunction() only when "test-patterns" is set; otherwise it does nothing.
//
//  * mlir::registerTestLinalgTransforms(): registers the pass under the
//    argument "test-linalg-transform-patterns".
//
// The next physical line holds: the file header, includes, using-directives,
// the TestLinalgTransforms pass struct, and the start of applyPatterns.
//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SetVector.h" using namespace mlir; using namespace mlir::linalg; namespace { struct TestLinalgTransforms : public PassWrapper { TestLinalgTransforms() = default; TestLinalgTransforms(const TestLinalgTransforms &pass) {} void runOnFunction() override; Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; }; } // end anonymous namespace static void applyPatterns(FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; //===--------------------------------------------------------------------===// // Linalg tiling patterns. 
// The next physical line continues applyPatterns. It holds the tiling,
// tiling+interchange, and lowering-to-loops pattern registrations.
// NOTE(review): every `patterns.insert>(` has lost its template arguments.
//===--------------------------------------------------------------------===// patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), LinalgMarker({"MEM", {}}, "L3")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), LinalgMarker({"L3"}, "L2")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgMarker({"L2"}, "L1")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), LinalgMarker({"L1"}, "REG")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( LinalgTilingLoopType::ParallelLoops), LinalgMarker({}, "L1")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes(8000), LinalgMarker({"MEM", "L3", "L2", {}}, "REG")); //===--------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===--------------------------------------------------------------------===// patterns.insert>( ctx, LinalgTilingOptions() .setTileSizes({2000, 3000, 4000}) .setInterchange({1, 2, 0}), LinalgMarker({"__with_perm__"}, "L2__with_perm__")); patterns.insert>( ctx, LinalgTilingOptions() .setTileSizes({200, 300, 400}) .setInterchange({1, 0, 2}), LinalgMarker({"L2__with_perm__"}, "L1__with_perm__")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgMarker({"L1__with_perm__"}, "REG__with_perm__")); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), LinalgMarker({"__with_perm__"}, "L1__with_perm__")); //===--------------------------------------------------------------------===// // Linalg to loops patterns. //===--------------------------------------------------------------------===// patterns.insert>( ctx, /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"})); //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. 
// The next physical line holds, in order:
//   - the rest of applyPatterns: vectorization, interchange and promotion
//     patterns, the greedy application, and the marker cleanup walk;
//   - TestLinalgTransforms::runOnFunction;
//   - mlir::registerTestLinalgTransforms.
// NOTE(review): the `ArrayRef{...}` and `PassRegistration` uses there have
// also lost their template arguments.
//===--------------------------------------------------------------------===// patterns.insert, LinalgVectorizationPattern, LinalgVectorizationPattern>( ctx, LinalgMarker({"VECTORIZE"})); //===--------------------------------------------------------------------===// // Linalg generic permutation patterns. //===--------------------------------------------------------------------===// patterns.insert>( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, LinalgMarker({}, "PERMUTED")); patterns.insert>( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, LinalgMarker({}, "PERMUTED")); //===--------------------------------------------------------------------===// // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.insert>( ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_")); patterns.insert>( ctx, /*operandsToPromote=*/ArrayRef{0}, LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); patterns.insert>( ctx, /*operandsToPromote=*/ArrayRef{0}, /*alignment=*/32, LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. funcOp.walk([](LinalgOp op) { op.removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { if (testPatterns) return applyPatterns(getFunction()); } namespace mlir { void registerTestLinalgTransforms() { PassRegistration testTransformPatternsPass( "test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily."); } } // namespace mlir