//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" using namespace mlir; using namespace mlir::linalg; namespace { struct TestLinalgTransforms : public PassWrapper { TestLinalgTransforms() = default; TestLinalgTransforms(const TestLinalgTransforms &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert(); // clang-format on } StringRef getArgument() const final { return "test-linalg-transform-patterns"; } StringRef getDescription() const final { return "Test Linalg transformation patterns by applying them greedily."; } void runOnFunction() override; Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; Option testMatmulToVectorPatterns1dTiling{ *this, "test-matmul-to-vector-patterns-tile-1d", llvm::cl::desc( "Test a fused pass that applies patterns from matmul to vectors via " "1-d tiling"), llvm::cl::init(false)}; Option testMatmulToVectorPatterns2dTiling{ *this, "test-matmul-to-vector-patterns-tile-2d", llvm::cl::desc( "Test a fused pass that applies patterns from matmul to vectors via " "2-d tiling"), llvm::cl::init(false)}; Option testPromotionOptions{*this, "test-linalg-promotion-options", llvm::cl::desc("Test promotion options"), llvm::cl::init(false)}; Option testTileAndDistributionOptions{ *this, "test-tile-and-distribute-options", llvm::cl::desc("Test tile and distribute options"), llvm::cl::init(false)}; Option testVectorTransferForwardingPatterns{ *this, "test-vector-transfer-forwarding-patterns", llvm::cl::desc( "Test a fused pass that forwards linalg.copy to vector.transfer"), llvm::cl::init(false)}; Option testGenericToVectorPattern{ *this, "test-linalg-to-vector-patterns", llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " "in vector.contract form"), llvm::cl::init(false)}; Option testTileAndPadPattern{ *this, "test-tile-and-pad-pattern", llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)}; Option testHoistPadding{*this, "test-hoist-padding", llvm::cl::desc("Test hoist padding"), llvm::cl::init(0)}; Option testTransformPadTensor{ *this, "test-transform-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; Option testGeneralizePadTensor{ *this, "test-generalize-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; Option testSwapSubTensorPadTensor{ *this, "test-swap-subtensor-padtensor", llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " "pad_tensor(subtensor)"), llvm::cl::init(false)}; ListOption tileSizesForPadding{ *this, "tile-sizes-for-padding", llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption testInterchangePattern{ *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Test the interchange pattern.")}; }; } // end anonymous namespace static void applyPatterns(FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); //===--------------------------------------------------------------------===// // Linalg tiling patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), LinalgTransformationFilter(Identifier::get("MEM", ctx), Identifier::get("L3", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), LinalgTransformationFilter(Identifier::get("L3", ctx), Identifier::get("L2", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(Identifier::get("L2", ctx), Identifier::get("L1", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), LinalgTransformationFilter(Identifier::get("L1", ctx), Identifier::get("REG", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( LinalgTilingLoopType::ParallelLoops), LinalgTransformationFilter(ArrayRef{}, Identifier::get("L1", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes(8000), LinalgTransformationFilter( ArrayRef{Identifier::get("MEM", ctx), Identifier::get("L3", ctx), Identifier::get("L2", ctx)}, Identifier::get("REG", ctx))); //===--------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({2000, 3000, 4000}) .setInterchange({1, 2, 0}), LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), Identifier::get("L2__with_perm__", ctx))); patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({200, 300, 400}) .setInterchange({1, 0, 2}), LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx), Identifier::get("L1__with_perm__", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx), Identifier::get("REG__with_perm__", ctx))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), LinalgTransformationFilter(Identifier::get("__with_perm__", ctx), Identifier::get("L1__with_perm__", ctx))); patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({16, 8, 4}) .setInterchange({1, 2, 0}) .setLoopType(LinalgTilingLoopType::ParallelLoops), LinalgTransformationFilter( Identifier::get("par__with_perm__", ctx), Identifier::get("after_par__with_perm__", ctx))); //===--------------------------------------------------------------------===// // Linalg to loops patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, /*loweringType=*/LinalgLoweringType::Loops, LinalgTransformationFilter(Identifier::get("REG", ctx))); //===--------------------------------------------------------------------===// // Linalg distribution patterns. //===--------------------------------------------------------------------===// LinalgLoopDistributionOptions distributionOptions; //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===--------------------------------------------------------------------===// patterns.add( ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)) .addOpFilter()); //===--------------------------------------------------------------------===// // Linalg generic interchange pattern. //===--------------------------------------------------------------------===// patterns.add( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, LinalgTransformationFilter(ArrayRef{}, Identifier::get("PERMUTED", ctx))); //===--------------------------------------------------------------------===// // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), LinalgTransformationFilter(Identifier::get("_promote_views_", ctx), Identifier::get("_views_promoted_", ctx))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0}) .setUseFullTileBuffersByDefault(true), LinalgTransformationFilter( Identifier::get("_promote_first_view_", ctx), Identifier::get("_first_view_promoted_", ctx))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({1}) .setUseFullTileBuffers({false, true}) .setAlignment(32), LinalgTransformationFilter( Identifier::get("_promote_views_aligned_", ctx), Identifier::get("_views_aligned_promoted_", ctx))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. funcOp.walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } static void fillL1TilingAndMatmulToVectorPatterns( FuncOp funcOp, StringRef startMarker, SmallVectorImpl &patternsVector) { MLIRContext *ctx = funcOp.getContext(); patternsVector.emplace_back( ctx, std::make_unique>( ctx, LinalgTilingOptions() .setTileSizes({8, 12, 16}) .setInterchange({1, 0, 2}), LinalgTransformationFilter(Identifier::get(startMarker, ctx), Identifier::get("L1", ctx)))); patternsVector.emplace_back( ctx, std::make_unique>( ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), LinalgTransformationFilter(Identifier::get("L1", ctx), Identifier::get("VEC", ctx)))); patternsVector.emplace_back( ctx, std::make_unique( MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), LinalgTransformationFilter(Identifier::get("VEC", ctx)))); patternsVector.back().add( ctx, LinalgTransformationFilter().addFilter( [](Operation *op) { return success(isa(op)); })); } //===----------------------------------------------------------------------===// // Test promotion callbacks //===----------------------------------------------------------------------===// // Allocation call back static Optional allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, ArrayRef boundingSubViewSize, DataLayout &layout) { SmallVector shape(boundingSubViewSize.size(), -1); return b .create( subView.getLoc(), MemRefType::get(shape, subView.getType().getElementType(), /*affineMapComposition =*/{}, 3), boundingSubViewSize) .getResult(); } // Deallocation callback static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { b.create(buffer.getLoc(), buffer); return success(); } // Copy in call back static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, bool isOutput) { auto floatType = src.getType().cast().getElementType(); if (!floatType.isa()) return failure(); if (!isOutput) { Value cst = b.create(src.getLoc(), FloatAttr::get(floatType, 42.0)); b.create(src.getLoc(), cst, dst); } b.create(src.getLoc(), src, dst); return success(); } static void fillPromotionCallBackPatterns(MLIRContext *ctx, RewritePatternSet &patterns) { patterns.add>( ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), LinalgTransformationFilter(Identifier::get("START", ctx), Identifier::get("PROMOTE", ctx))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0, 2}) .setUseFullTileBuffers({false, false}) .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) .setCopyInOutFns( [](OpBuilder &b, Value src, Value dst) -> LogicalResult { return copyCallBackFn(b, src, dst, false); }, [](OpBuilder &b, Value src, Value dst) -> LogicalResult { return copyCallBackFn(b, src, dst, true); }), LinalgTransformationFilter(Identifier::get("PROMOTE", ctx))); } template static SmallVector getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges) { size_t count = std::min(3, parallelLoopRanges.size()); SmallVector procInfo(count); const char *xyz[] = {"x", "y", "z"}; Type indexType = b.getIndexType(); for (unsigned i = 0; i < count; ++i) { procInfo[count - 1 - i] = { b.create(loc, indexType, b.getStringAttr(xyz[i])), b.create(loc, indexType, b.getStringAttr(xyz[i]))}; } return procInfo; } static void fillTileAndDistributePatterns(MLIRContext *context, RewritePatternSet &patterns) { { LinalgLoopDistributionOptions cyclicNprocsEqNiters; cyclicNprocsEqNiters.distributionMethod.resize( 2, DistributionMethod::CyclicNumProcsEqNumIters); cyclicNprocsEqNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsEqNiters), LinalgTransformationFilter( Identifier::get("distribute1", context), Identifier::get("after_distribute1", context))); } { LinalgLoopDistributionOptions cyclicNprocsGeNiters; cyclicNprocsGeNiters.distributionMethod.resize( 2, DistributionMethod::CyclicNumProcsGeNumIters); cyclicNprocsGeNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsGeNiters), LinalgTransformationFilter( Identifier::get("distribute2", context), Identifier::get("after_distribute2", context))); } { LinalgLoopDistributionOptions cyclicNprocsDefault; cyclicNprocsDefault.distributionMethod.resize(2, DistributionMethod::Cyclic); cyclicNprocsDefault.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsDefault), LinalgTransformationFilter( Identifier::get("distribute3", context), Identifier::get("after_distribute3", context))); } { LinalgLoopDistributionOptions cyclicNprocsMixed1; cyclicNprocsMixed1.distributionMethod = { DistributionMethod::CyclicNumProcsEqNumIters, DistributionMethod::CyclicNumProcsGeNumIters}; cyclicNprocsMixed1.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed1), LinalgTransformationFilter( Identifier::get("distribute4", context), Identifier::get("after_distribute4", context))); } { LinalgLoopDistributionOptions cyclicNprocsMixed2; cyclicNprocsMixed2.distributionMethod = { DistributionMethod::CyclicNumProcsGeNumIters, DistributionMethod::Cyclic}; cyclicNprocsMixed2.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed2), LinalgTransformationFilter( Identifier::get("distribute5", context), Identifier::get("after_distribute5", context))); } { LinalgLoopDistributionOptions cyclicNprocsMixed3; cyclicNprocsMixed3.distributionMethod = { DistributionMethod::Cyclic, DistributionMethod::CyclicNumProcsEqNumIters}; cyclicNprocsMixed3.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed3), LinalgTransformationFilter( Identifier::get("distribute6", context), Identifier::get("after_distribute6", context))); } { LinalgLoopDistributionOptions cyclicNprocsEqNiters; cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); cyclicNprocsEqNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::Loops) .setDistributionOptions(cyclicNprocsEqNiters), LinalgTransformationFilter( Identifier::get("tensors_distribute1", context), Identifier::get("tensors_after_distribute1", context))); } } static void applyMatmulToVectorPatterns(FuncOp funcOp, bool testMatmulToVectorPatterns1dTiling, bool testMatmulToVectorPatterns2dTiling) { MLIRContext *ctx = funcOp.getContext(); SmallVector stage1Patterns; if (testMatmulToVectorPatterns1dTiling) { fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { stage1Patterns.emplace_back( ctx, std::make_unique>( ctx, LinalgTilingOptions() .setTileSizes({768, 264, 768}) .setInterchange({1, 2, 0}), LinalgTransformationFilter(Identifier::get("START", ctx), Identifier::get("L2", ctx)))); fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), stage1Patterns); } SmallVector frozenStage1Patterns; llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); FrozenRewritePatternSet stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx); (void)applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns)); } static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); } static void applyLinalgToVectorPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add( funcOp.getContext(), LinalgTransformationFilter() .addOpFilter()); populatePadTensorOpVectorizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyPadTensorToGenericPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } // For now, just assume it is the zero of type. // In the future, it should be the zero of type + op. static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { auto t = getElementTypeOrSelf(op.get()); return b.create(op.getOwner()->getLoc(), t, b.getZeroAttr(t)); } static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef tileSizes) { MLIRContext *context = funcOp.getContext(); RewritePatternSet tilingPattern(context); auto linalgTilingOptions = linalg::LinalgTilingOptions() .setTileSizes(tileSizes) .setPaddingValueComputationFunction(getNeutralOfLinalgOp); tilingPattern.add, linalg::LinalgTilingPattern>( context, linalgTilingOptions, linalg::LinalgTransformationFilter( Identifier::get("tile-and-pad", context))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } static void applyInterchangePattern(FuncOp funcOp, ArrayRef interchangeVector) { MLIRContext *context = funcOp.getContext(); RewritePatternSet interchangePattern(context); interchangePattern.add( context, interchangeVector, LinalgTransformationFilter(ArrayRef{}, Identifier::get("interchange", context))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); } /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { getFunction().walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); }; std::unique_ptr cleanupGuard{(void *)1, lambda}; if (testPromotionOptions) { RewritePatternSet patterns(&getContext()); fillPromotionCallBackPatterns(&getContext(), patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testTileAndDistributionOptions) { RewritePatternSet patterns(&getContext()); fillTileAndDistributePatterns(&getContext(), patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testPatterns) return applyPatterns(getFunction()); if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) return applyMatmulToVectorPatterns(getFunction(), testMatmulToVectorPatterns1dTiling, testMatmulToVectorPatterns2dTiling); if (testVectorTransferForwardingPatterns) return applyVectorTransferForwardingPatterns(getFunction()); if (testGenericToVectorPattern) return applyLinalgToVectorPatterns(getFunction()); if (testTransformPadTensor) return applyPadTensorToGenericPatterns(getFunction()); if (testGeneralizePadTensor) return applyGeneralizePadTensorPatterns(getFunction()); if (testSwapSubTensorPadTensor) return applyExtractSliceOfPadTensorSwapPattern(getFunction()); if (testTileAndPadPattern) return applyTileAndPadPattern(getFunction(), tileSizesForPadding); if (testHoistPadding) { getFunction().walk([&](linalg::PadTensorOp padTensorOp) { (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); }); } if (testInterchangePattern.hasValue()) return applyInterchangePattern(getFunction(), testInterchangePattern); } namespace mlir { namespace test { void registerTestLinalgTransforms() { PassRegistration(); } } // namespace test } // namespace mlir