From aeec94500a5dbd576e5d2d16895fe00fa0b1e154 Mon Sep 17 00:00:00 2001 From: junfengd-nv Date: Sat, 5 Apr 2025 13:56:55 -0700 Subject: [PATCH] [mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner Config (#131226) The current inliner disables inlining when the caller is in a region with the SingleBlock trait while the callee function contains multiple blocks. The SingleBlock trait is used in operations such as do/while loops, for example fir.do_loop, fir.iterate_while and fir.if. Typically, calls within loops are good candidates for inlining. However, functions with multiple blocks are also common. For example, any function with "if () then return" will result in multiple blocks in MLIR. This change gives a customized inliner the flexibility to handle such cases. doClone: clones instructions and other information from the callee function into the caller function. canHandleMultipleBlocks: checks if functions with multiple blocks can be inlined into a region with the SingleBlock trait. The default behavior of the inliner remains unchanged.
--------- Co-authored-by: jeanPerier Co-authored-by: Mehdi Amini --- mlir/include/mlir/Transforms/Inliner.h | 36 +++++ mlir/include/mlir/Transforms/InliningUtils.h | 61 ++++--- mlir/lib/Transforms/Utils/Inliner.cpp | 31 ++-- mlir/lib/Transforms/Utils/InliningUtils.cpp | 116 +++++++------- .../Transforms/test-inlining-callback.mlir | 24 +++ mlir/test/lib/Transforms/CMakeLists.txt | 1 + mlir/test/lib/Transforms/TestInlining.cpp | 14 +- .../lib/Transforms/TestInliningCallback.cpp | 151 ++++++++++++++++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 9 files changed, 333 insertions(+), 103 deletions(-) create mode 100644 mlir/test/Transforms/test-inlining-callback.mlir create mode 100644 mlir/test/lib/Transforms/TestInliningCallback.cpp diff --git a/mlir/include/mlir/Transforms/Inliner.h b/mlir/include/mlir/Transforms/Inliner.h index ec77319d6ac8..506b4455af64 100644 --- a/mlir/include/mlir/Transforms/Inliner.h +++ b/mlir/include/mlir/Transforms/Inliner.h @@ -27,6 +27,11 @@ class InlinerConfig { public: using DefaultPipelineTy = std::function; using OpPipelinesTy = llvm::StringMap; + using CloneCallbackSigTy = void(OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion); + using CloneCallbackTy = std::function; InlinerConfig() = default; InlinerConfig(DefaultPipelineTy defaultPipeline, @@ -39,6 +44,9 @@ public: } const OpPipelinesTy &getOpPipelines() const { return opPipelines; } unsigned getMaxInliningIterations() const { return maxInliningIterations; } + const CloneCallbackTy &getCloneCallback() const { return cloneCallback; } + bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; } + void setDefaultPipeline(DefaultPipelineTy pipeline) { defaultPipeline = std::move(pipeline); } @@ -46,6 +54,12 @@ public: opPipelines = std::move(pipelines); } void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; } + void setCloneCallback(CloneCallbackTy callback) { + 
cloneCallback = std::move(callback); + } + void setCanHandleMultipleBlocks(bool value = true) { + canHandleMultipleBlocks = value; + } private: /// An optional function that constructs an optimization pipeline for @@ -60,6 +74,28 @@ private: /// For SCC-based inlining algorithms, specifies maximum number of iterations /// when inlining within an SCC. unsigned maxInliningIterations{0}; + /// Callback for cloning operations during inlining + CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion) { + // Check to see if the region is being cloned, or moved inline. In + // either case, move the new blocks after the 'insertBlock' to improve + // IR readability. + Region *insertRegion = inlineBlock->getParent(); + if (shouldCloneInlinedRegion) + src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); + else + insertRegion->getBlocks().splice(postInsertBlock->getIterator(), + src->getBlocks(), src->begin(), + src->end()); + }; + /// Determine if the inliner can inline a function containing multiple + /// blocks into a region that requires a single block. By default, it is + /// not allowed. If it is true, cloneCallback should perform the extra + /// transformation. 
see the example in + /// mlir/test/lib/Transforms/TestInliningCallback.cpp + bool canHandleMultipleBlocks{false}; }; /// This is an implementation of the inliner diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index becfe9b047ef..552030983d72 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -18,6 +18,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Region.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/Inliner.h" #include namespace mlir { @@ -253,33 +254,39 @@ public: /// provided, will be used to update the inlined operations' location /// information. 'shouldCloneInlinedRegion' corresponds to whether the source /// region should be cloned into the 'inlinePoint' or spliced directly. -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, Block::iterator inlinePoint, - IRMapping &mapper, ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Operation *inlinePoint, IRMapping &mapper, + ValueRange resultsToReplace, TypeRange regionResultTypes, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); /// This function 
is an overload of the above 'inlineRegion' that allows for /// providing the set of operands ('inlinedOperands') that should be used /// in-favor of the region arguments when inlining. -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, Block::iterator inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Operation *inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + ValueRange inlinedOperands, ValueRange resultsToReplace, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); /// This function inlines a given region, 'src', of a callable operation, /// 'callable', into the location defined by the given call operation. This @@ -287,9 +294,11 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' /// corresponds to whether the source region should be cloned into the 'call' or /// spliced directly. 
-LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, - CallableOpInterface callable, Region *src, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineCall(InlinerInterface &interface, + function_ref cloneCallback, + CallOpInterface call, CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion = true); } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index f511504594cf..54b5c788a352 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -652,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode); LogicalResult inlineResult = - inlineCall(inlinerIface, call, + inlineCall(inlinerIface, inliner.config.getCloneCallback(), call, cast(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { @@ -730,19 +730,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { // Don't allow inlining if the callee has multiple blocks (unstructured // control flow) but we cannot be sure that the caller region supports that. - bool calleeHasMultipleBlocks = - llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); - // If both parent ops have the same type, it is safe to inline. Otherwise, - // decide based on whether the op has the SingleBlock trait or not. - // Note: This check does currently not account for SizedRegion/MaxSizedRegion. 
- auto callerRegionSupportsMultipleBlocks = [&]() { - return callableRegion->getParentOp()->getName() == - resolvedCall.call->getParentOp()->getName() || - !resolvedCall.call->getParentOp() - ->mightHaveTrait(); - }; - if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) - return false; + if (!inliner.config.getCanHandleMultipleBlocks()) { + bool calleeHasMultipleBlocks = + llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); + // If both parent ops have the same type, it is safe to inline. Otherwise, + // decide based on whether the op has the SingleBlock trait or not. + // Note: This check does currently not account for + // SizedRegion/MaxSizedRegion. + auto callerRegionSupportsMultipleBlocks = [&]() { + return callableRegion->getParentOp()->getName() == + resolvedCall.call->getParentOp()->getName() || + !resolvedCall.call->getParentOp() + ->mightHaveTrait(); + }; + if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) + return false; + } if (!inliner.isProfitableToInline(resolvedCall)) return false; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index e113389b26ae..3dd95d284571 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/Inliner.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" @@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, } static LogicalResult -inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, - Block::iterator inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, TypeRange regionResultTypes, - std::optional inlineLoc, +inlineRegionImpl(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator 
inlinePoint, + IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, std::optional inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call = {}) { assert(resultsToReplace.size() == regionResultTypes.size()); // We expect the region to have at least one block. @@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, if (call && callable) handleArgumentImpl(interface, builder, call, callable, mapper); - // Check to see if the region is being cloned, or moved inline. In either - // case, move the new blocks after the 'insertBlock' to improve IR - // readability. + // Clone the callee's source into the caller. Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); - if (shouldCloneInlinedRegion) - src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); - else - insertRegion->getBlocks().splice(postInsertBlock->getIterator(), - src->getBlocks(), src->begin(), - src->end()); + cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper, + shouldCloneInlinedRegion); // Get the range of newly inserted blocks. auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), @@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, } static LogicalResult -inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, - Block::iterator inlinePoint, ValueRange inlinedOperands, - ValueRange resultsToReplace, std::optional inlineLoc, +inlineRegionImpl(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + ValueRange inlinedOperands, ValueRange resultsToReplace, + std::optional inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call = {}) { // We expect the region to have at least one block. 
if (src->empty()) @@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, } // Call into the main region inliner function. - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, - resultsToReplace, resultsToReplace.getTypes(), - inlineLoc, shouldCloneInlinedRegion, call); + return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, + inlinePoint, mapper, resultsToReplace, + resultsToReplace.getTypes(), inlineLoc, + shouldCloneInlinedRegion, call); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegion(interface, src, inlinePoint->getBlock(), +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), ++inlinePoint->getIterator(), mapper, resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, - Block::iterator inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, - resultsToReplace, regionResultTypes, inlineLoc, - shouldCloneInlinedRegion); + +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, + ValueRange resultsToReplace, TypeRange regionResultTypes, + std::optional inlineLoc, 
bool shouldCloneInlinedRegion) { + return inlineRegionImpl( + interface, cloneCallback, src, inlineBlock, inlinePoint, mapper, + resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegion(interface, src, inlinePoint->getBlock(), +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Operation *inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), ++inlinePoint->getIterator(), inlinedOperands, resultsToReplace, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, - Block::iterator inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, - inlinedOperands, resultsToReplace, inlineLoc, - shouldCloneInlinedRegion); + +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, + inlinePoint, inlinedOperands, resultsToReplace, + inlineLoc, shouldCloneInlinedRegion); } /// Utility function used to generate a cast operation from the given interface, @@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface, /// failure, no changes are made to the module. 
'shouldCloneInlinedRegion' /// corresponds to whether the source region should be cloned into the 'call' or /// spliced directly. -LogicalResult mlir::inlineCall(InlinerInterface &interface, - CallOpInterface call, - CallableOpInterface callable, Region *src, - bool shouldCloneInlinedRegion) { +LogicalResult +mlir::inlineCall(InlinerInterface &interface, + function_ref cloneCallback, + CallOpInterface call, CallableOpInterface callable, + Region *src, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. if (src->empty()) return failure(); @@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, return cleanupState(); // Attempt to inline the call. - if (failed(inlineRegionImpl(interface, src, call->getBlock(), + if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(), ++call->getIterator(), mapper, callResults, callableResultTypes, call.getLoc(), shouldCloneInlinedRegion, call))) diff --git a/mlir/test/Transforms/test-inlining-callback.mlir b/mlir/test/Transforms/test-inlining-callback.mlir new file mode 100644 index 000000000000..c012c31e7e49 --- /dev/null +++ b/mlir/test/Transforms/test-inlining-callback.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -test-inline-callback | FileCheck %s + +// Test inlining with multiple blocks and scf.execute_region transformation +// CHECK-LABEL: func @test_inline_multiple_blocks +func.func @test_inline_multiple_blocks(%arg0: i32) -> i32 { + // CHECK: %[[RES:.*]] = scf.execute_region -> i32 + // CHECK-NEXT: %[[ADD1:.*]] = arith.addi %arg0, %arg0 + // CHECK-NEXT: cf.br ^bb1(%[[ADD1]] : i32) + // CHECK: ^bb1(%[[ARG:.*]]: i32): + // CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[ARG]], %[[ARG]] + // CHECK-NEXT: scf.yield %[[ADD2]] + // CHECK: return %[[RES]] + %fn = "test.functional_region_op"() ({ + ^bb0(%a : i32): + %b = arith.addi %a, %a : i32 + cf.br ^bb1(%b: i32) + ^bb1(%c: i32): + %d = arith.addi %c, %c : i32 + "test.return"(%d) : 
(i32) -> () + }) : () -> ((i32) -> i32) + + %0 = call_indirect %fn(%arg0) : (i32) -> i32 + return %0 : i32 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index c053fd4b2047..76041cd6cd79 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -29,6 +29,7 @@ add_mlir_library(MLIRTestTransforms TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp + TestInliningCallback.cpp TestMakeIsolatedFromAbove.cpp TestTransformsOps.cpp ${MLIRTestTransformsPDLSrc} diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp index 223cc78dd1e2..ae904a92a5d6 100644 --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Inliner.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringSet.h" @@ -25,8 +26,9 @@ using namespace mlir; using namespace test; namespace { -struct Inliner : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Inliner) +struct InlinerTest + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerTest) StringRef getArgument() const final { return "test-inline"; } StringRef getDescription() const final { @@ -34,6 +36,8 @@ struct Inliner : public PassWrapper> { } void runOnOperation() override { + InlinerConfig config; + auto function = getOperation(); // Collect each of the direct function calls within the module. @@ -54,8 +58,8 @@ struct Inliner : public PassWrapper> { // Inline the functional region operation, but only clone the internal // region if there is more than one use. 
if (failed(inlineRegion( - interface, &callee.getBody(), caller, caller.getArgOperands(), - caller.getResults(), caller.getLoc(), + interface, config.getCloneCallback(), &callee.getBody(), caller, + caller.getArgOperands(), caller.getResults(), caller.getLoc(), /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse()))) continue; @@ -71,6 +75,6 @@ struct Inliner : public PassWrapper> { namespace mlir { namespace test { -void registerInliner() { PassRegistration(); } +void registerInliner() { PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp new file mode 100644 index 000000000000..012d62b7b1b4 --- /dev/null +++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp @@ -0,0 +1,151 @@ +//===- TestInliningCallback.cpp - Pass to inline calls in the test dialect +//--------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file implements a pass to test inlining callbacks including +// canHandleMultipleBlocks and doClone. 
+//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestOps.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Inliner.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSet.h" + +using namespace mlir; +using namespace test; + +namespace { +struct InlinerCallback + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerCallback) + + StringRef getArgument() const final { return "test-inline-callback"; } + StringRef getDescription() const final { + return "Test inlining region calls with call back functions"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline, + Operation *op) { + return mlir::cast(pass).runPipeline(pipeline, op); + } + + // Customize the implementation of Inliner::doClone + // Wrap the callee into scf.execute_region operation + static void testDoClone(OpBuilder &builder, Region *src, Block *inlineBlock, + Block *postInsertBlock, IRMapping &mapper, + bool shouldCloneInlinedRegion) { + // Create a new scf.execute_region operation + mlir::Operation &call = inlineBlock->back(); + builder.setInsertionPointAfter(&call); + + auto executeRegionOp = builder.create( + call.getLoc(), call.getResultTypes()); + mlir::Region ®ion = executeRegionOp.getRegion(); + + // Move the inlined blocks into the region + src->cloneInto(®ion, mapper); + + // Split block before scf operation. 
+ Block *continueBlock = + inlineBlock->splitBlock(executeRegionOp.getOperation()); + + // Replace all test.return with scf.yield + for (mlir::Block &block : region) { + + for (mlir::Operation &op : llvm::make_early_inc_range(block)) { + if (test::TestReturnOp returnOp = + llvm::dyn_cast(&op)) { + mlir::OpBuilder returnBuilder(returnOp); + returnBuilder.create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + } + } + } + + // Add test.return after scf.execute_region + builder.setInsertionPointAfter(executeRegionOp); + builder.create(executeRegionOp.getLoc(), + executeRegionOp.getResults()); + } + + void runOnOperation() override { + InlinerConfig config; + CallGraph &cg = getAnalysis(); + + func::FuncOp function = getOperation(); + + // By default, assume that any inlining is profitable. + auto profitabilityCb = [&](const mlir::Inliner::ResolvedCall &call) { + return true; + }; + + // Set the clone callback in the config + config.setCloneCallback([](OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion) { + return testDoClone(builder, src, inlineBlock, postInsertBlock, mapper, + shouldCloneInlinedRegion); + }); + + // Set canHandleMultipleBlocks to true in the config + config.setCanHandleMultipleBlocks(); + + // Get an instance of the inliner. + Inliner inliner(function, cg, *this, getAnalysisManager(), + runPipelineHelper, config, profitabilityCb); + + // Collect each of the direct function calls within the module. + SmallVector callers; + function.walk( + [&](func::CallIndirectOp caller) { callers.push_back(caller); }); + + // Build the inliner interface. + InlinerInterface interface(&getContext()); + + // Try to inline each of the call operations. 
+ for (auto caller : callers) { + auto callee = dyn_cast_or_null( + caller.getCallee().getDefiningOp()); + if (!callee) + continue; + + // Inline the functional region operation, but only clone the internal + // region if there is more than one use. + if (failed(inlineRegion( + interface, config.getCloneCallback(), &callee.getBody(), caller, + caller.getArgOperands(), caller.getResults(), caller.getLoc(), + /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse()))) + continue; + + // If the inlining was successful then erase the call and callee if + // possible. + caller.erase(); + if (callee.use_empty()) + callee.erase(); + } + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerInlinerCallback() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d06ff8070e7c..ca4706e96787 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -73,6 +73,7 @@ void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerConvertFuncOpPass(); void registerInliner(); +void registerInlinerCallback(); void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); @@ -215,6 +216,7 @@ void registerTestPasses() { mlir::test::registerConvertCallOpPass(); mlir::test::registerConvertFuncOpPass(); mlir::test::registerInliner(); + mlir::test::registerInlinerCallback(); mlir::test::registerMemRefBoundCheck(); mlir::test::registerPatternsTestPass(); mlir::test::registerSimpleParametricTilingPass();