//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg fusion patterns. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" using namespace mlir; using namespace mlir::linalg; template static void fillFusionPatterns(MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, RewritePatternSet &patterns) { patterns.add, LinalgTileAndFusePattern>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({2}), LinalgTransformationFilter( Identifier::get("basic_fusion", context), Identifier::get("after_basic_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_basic_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_basic_fusion_original", context))); patterns.add>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({0}), LinalgTransformationFilter(Identifier::get("lhs_fusion", context), Identifier::get("after_lhs_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_lhs_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_lhs_fusion_original", context))); patterns.add>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({2}), LinalgTransformationFilter(Identifier::get("out_fusion", context), Identifier::get("after_out_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_out_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_out_fusion_original", context))); patterns.add>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({1}), LinalgTransformationFilter(Identifier::get("rhs_fusion", context), Identifier::get("after_rhs_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_rhs_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_rhs_fusion_original", context))); patterns.add>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({0, 2}), LinalgTransformationFilter( Identifier::get("two_operand_fusion", context), Identifier::get("after_two_operand_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_two_operand_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_two_operand_fusion_original", context))); patterns.add>( context, dependenceGraph, LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({0, 1}), LinalgTransformationFilter( Identifier::get("transpose_fusion", context), Identifier::get("after_transpose_fusion", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_transpose_fusion_producer", context)), LinalgTransformationFilter( ArrayRef(), Identifier::get("after_transpose_fusion_original", context))); } namespace { template struct TestLinalgFusionTransforms : public PassWrapper, FunctionPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } TestLinalgFusionTransforms() = default; TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} void runOnFunction() override { MLIRContext *context = &this->getContext(); FuncOp funcOp = this->getFunction(); RewritePatternSet fusionPatterns(context); Aliases alias; LinalgDependenceGraph dependenceGraph = LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); fillFusionPatterns(context, dependenceGraph, fusionPatterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); } }; struct TestLinalgFusionTransformsParallelLoops : public TestLinalgFusionTransforms { StringRef getArgument() const final { return "test-linalg-fusion-transform-patterns"; } StringRef getDescription() const final { return "Test Linalg fusion transformation patterns by applying them " "greedily."; } }; struct TestLinalgFusionTransformsLoops : public TestLinalgFusionTransforms { StringRef getArgument() const final { return "test-linalg-tensor-fusion-transform-patterns"; } StringRef getDescription() const final { return "Test Linalg on tensor fusion transformation " "patterns by applying them greedily."; } }; struct TestLinalgFusionTransformsTiledLoops : public TestLinalgFusionTransforms { StringRef getArgument() const final { return "test-linalg-tiled-loop-fusion-transform-patterns"; } StringRef getDescription() const final { return "Test Linalg on tensor fusion transformation " "patterns by applying them greedily."; } }; } // namespace static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { OpBuilder b(f); DenseSet eraseSet; // Save original Linalg ops, we only want to make a pass over those. SmallVector linalgOps; f.walk([&](LinalgOp op) { // TODO: support multi-results. if (op->getNumResults() <= 1) linalgOps.push_back(op); }); // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { if (opOperand->get().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. // The current naive and expensive reconstruction of the graph should be // removed. linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) { auto *originalOp = info->originalProducer.getOperation(); eraseSet.insert(originalOp); auto *originalOpInLinalgOpsVector = std::find(linalgOps.begin(), linalgOps.end(), originalOp); *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); changed = true; } } else if (opOperand->get().getType().isa()) { // Tile and Fuse tensor input. if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) continue; if (auto info = fuseProducerOfTensor(b, *opOperand)) { auto *originalOp = info->originalProducer.getOperation(); auto *originalOpInLinalgOpsVector = std::find(linalgOps.begin(), linalgOps.end(), originalOp); *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); // Don't mark for erasure in the tensor case, let DCE handle this. changed = true; } } } } // The `fuseProducerOfBuffer` function performs structural checks and in // particular that no covering read or write exist between the consumer and // the producer. As a consequence, the only fusions that may occur preserve // subsequent dependences and are guaranteed by construction to produce the // whole view. We may thus erase the producer once it is fused. for (auto *e : eraseSet) e->erase(); return changed ? success() : failure(); } namespace { struct TestLinalgGreedyFusion : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } StringRef getDescription() const final { return "Test Linalg fusion by applying a greedy test transformation."; } void runOnFunction() override { MLIRContext *context = &getContext(); RewritePatternSet patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); patterns.add(context); scf::populateSCFForLoopCanonicalizationPatterns(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); do { (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); PassManager pm(context); pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); LogicalResult res = pm.run(getFunction()->getParentOfType()); if (failed(res)) this->signalPassFailure(); } while (succeeded(fuseLinalgOpsGreedily(getFunction()))); } }; /// Pass to test tile and fuse of sequence of operations. Intended only for /// testing. struct TestLinalgTileAndFuseSequencePass : public PassWrapper { StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } StringRef getDescription() const final { return "Test Linalg tiling and fusion of a sequence of Linalg operations."; } TestLinalgTileAndFuseSequencePass() = default; TestLinalgTileAndFuseSequencePass( const TestLinalgTileAndFuseSequencePass &pass){}; ListOption tileSizes{ *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnFunction() override { FuncOp funcOp = getOperation(); auto &blocks = funcOp.getBody().getBlocks(); if (!llvm::hasSingleElement(blocks)) { return; } SmallVector linalgOps = llvm::to_vector<2>(blocks.front().getOps()); Aliases aliases; LinalgDependenceGraph dependenceGraph(aliases, linalgOps); OpBuilder builder(funcOp.getContext()); linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { return linalgOp.hasTensorSemantics(); })) loopType = LinalgTilingLoopType::Loops; Optional tileAndFuseOps = tileAndFuseLinalgOps( builder, linalgOps, dependenceGraph, LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); if (!tileAndFuseOps) return signalPassFailure(); if (linalgOps.back().hasTensorSemantics()) { linalgOps.back().getOperation()->replaceAllUsesWith( tileAndFuseOps->fusedLoops.front()); } for (auto op : linalgOps) if (op.hasBufferSemantics()) op.erase(); } }; } // namespace namespace mlir { namespace test { void registerTestLinalgFusionTransforms() { PassRegistration(); } void registerTestLinalgTensorFusionTransforms() { PassRegistration(); } void registerTestLinalgTiledLoopFusionTransforms() { PassRegistration(); } void registerTestLinalgGreedyFusion() { PassRegistration(); } void registerTestLinalgTileAndFuseSequencePass() { PassRegistration(); } } // namespace test } // namespace mlir