diff --git a/mlir/include/mlir/Transforms/Inliner.h b/mlir/include/mlir/Transforms/Inliner.h index ec77319d6ac8..506b4455af64 100644 --- a/mlir/include/mlir/Transforms/Inliner.h +++ b/mlir/include/mlir/Transforms/Inliner.h @@ -27,6 +27,11 @@ class InlinerConfig { public: using DefaultPipelineTy = std::function; using OpPipelinesTy = llvm::StringMap; + using CloneCallbackSigTy = void(OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion); + using CloneCallbackTy = std::function; InlinerConfig() = default; InlinerConfig(DefaultPipelineTy defaultPipeline, @@ -39,6 +44,9 @@ public: } const OpPipelinesTy &getOpPipelines() const { return opPipelines; } unsigned getMaxInliningIterations() const { return maxInliningIterations; } + const CloneCallbackTy &getCloneCallback() const { return cloneCallback; } + bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; } + void setDefaultPipeline(DefaultPipelineTy pipeline) { defaultPipeline = std::move(pipeline); } @@ -46,6 +54,12 @@ public: opPipelines = std::move(pipelines); } void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; } + void setCloneCallback(CloneCallbackTy callback) { + cloneCallback = std::move(callback); + } + void setCanHandleMultipleBlocks(bool value = true) { + canHandleMultipleBlocks = value; + } private: /// An optional function that constructs an optimization pipeline for @@ -60,6 +74,28 @@ private: /// For SCC-based inlining algorithms, specifies maximum number of iterations /// when inlining within an SCC. unsigned maxInliningIterations{0}; + /// Callback for cloning operations during inlining + CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion) { + // Check to see if the region is being cloned, or moved inline. In + // either case, move the new blocks after the 'insertBlock' to improve + // IR readability. + Region *insertRegion = inlineBlock->getParent(); + if (shouldCloneInlinedRegion) + src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); + else + insertRegion->getBlocks().splice(postInsertBlock->getIterator(), + src->getBlocks(), src->begin(), + src->end()); + }; + /// Determine if the inliner can inline a function containing multiple + /// blocks into a region that requires a single block. By default, it is + /// not allowed. If it is true, cloneCallback should perform the extra + /// transformation. see the example in + /// mlir/test/lib/Transforms/TestInliningCallback.cpp + bool canHandleMultipleBlocks{false}; }; /// This is an implementation of the inliner diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index becfe9b047ef..552030983d72 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -18,6 +18,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Region.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/Inliner.h" #include namespace mlir { @@ -253,33 +254,39 @@ public: /// provided, will be used to update the inlined operations' location /// information. 'shouldCloneInlinedRegion' corresponds to whether the source /// region should be cloned into the 'inlinePoint' or spliced directly. -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, Block::iterator inlinePoint, - IRMapping &mapper, ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Operation *inlinePoint, IRMapping &mapper, + ValueRange resultsToReplace, TypeRange regionResultTypes, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); /// This function is an overload of the above 'inlineRegion' that allows for /// providing the set of operands ('inlinedOperands') that should be used /// in-favor of the region arguments when inlining. -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); -LogicalResult inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, Block::iterator inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc = std::nullopt, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Operation *inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); +LogicalResult +inlineRegion(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + ValueRange inlinedOperands, ValueRange resultsToReplace, + std::optional inlineLoc = std::nullopt, + bool shouldCloneInlinedRegion = true); /// This function inlines a given region, 'src', of a callable operation, /// 'callable', into the location defined by the given call operation. This @@ -287,9 +294,11 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' /// corresponds to whether the source region should be cloned into the 'call' or /// spliced directly. -LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, - CallableOpInterface callable, Region *src, - bool shouldCloneInlinedRegion = true); +LogicalResult +inlineCall(InlinerInterface &interface, + function_ref cloneCallback, + CallOpInterface call, CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion = true); } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index f511504594cf..54b5c788a352 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -652,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode); LogicalResult inlineResult = - inlineCall(inlinerIface, call, + inlineCall(inlinerIface, inliner.config.getCloneCallback(), call, cast(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { @@ -730,19 +730,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { // Don't allow inlining if the callee has multiple blocks (unstructured // control flow) but we cannot be sure that the caller region supports that. - bool calleeHasMultipleBlocks = - llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); - // If both parent ops have the same type, it is safe to inline. Otherwise, - // decide based on whether the op has the SingleBlock trait or not. - // Note: This check does currently not account for SizedRegion/MaxSizedRegion. - auto callerRegionSupportsMultipleBlocks = [&]() { - return callableRegion->getParentOp()->getName() == - resolvedCall.call->getParentOp()->getName() || - !resolvedCall.call->getParentOp() - ->mightHaveTrait(); - }; - if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) - return false; + if (!inliner.config.getCanHandleMultipleBlocks()) { + bool calleeHasMultipleBlocks = + llvm::hasNItemsOrMore(*callableRegion, /*N=*/2); + // If both parent ops have the same type, it is safe to inline. Otherwise, + // decide based on whether the op has the SingleBlock trait or not. + // Note: This check does currently not account for + // SizedRegion/MaxSizedRegion. + auto callerRegionSupportsMultipleBlocks = [&]() { + return callableRegion->getParentOp()->getName() == + resolvedCall.call->getParentOp()->getName() || + !resolvedCall.call->getParentOp() + ->mightHaveTrait(); + }; + if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) + return false; + } if (!inliner.isProfitableToInline(resolvedCall)) return false; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index e113389b26ae..3dd95d284571 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/Inliner.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" @@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, } static LogicalResult -inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, - Block::iterator inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, TypeRange regionResultTypes, - std::optional inlineLoc, +inlineRegionImpl(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, std::optional inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call = {}) { assert(resultsToReplace.size() == regionResultTypes.size()); // We expect the region to have at least one block. @@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, if (call && callable) handleArgumentImpl(interface, builder, call, callable, mapper); - // Check to see if the region is being cloned, or moved inline. In either - // case, move the new blocks after the 'insertBlock' to improve IR - // readability. + // Clone the callee's source into the caller. Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); - if (shouldCloneInlinedRegion) - src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); - else - insertRegion->getBlocks().splice(postInsertBlock->getIterator(), - src->getBlocks(), src->begin(), - src->end()); + cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper, + shouldCloneInlinedRegion); // Get the range of newly inserted blocks. auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), @@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, } static LogicalResult -inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, - Block::iterator inlinePoint, ValueRange inlinedOperands, - ValueRange resultsToReplace, std::optional inlineLoc, +inlineRegionImpl(InlinerInterface &interface, + function_ref cloneCallback, + Region *src, Block *inlineBlock, Block::iterator inlinePoint, + ValueRange inlinedOperands, ValueRange resultsToReplace, + std::optional inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call = {}) { // We expect the region to have at least one block. if (src->empty()) @@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, } // Call into the main region inliner function. - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, - resultsToReplace, resultsToReplace.getTypes(), - inlineLoc, shouldCloneInlinedRegion, call); + return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, + inlinePoint, mapper, resultsToReplace, + resultsToReplace.getTypes(), inlineLoc, + shouldCloneInlinedRegion, call); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegion(interface, src, inlinePoint->getBlock(), +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, + TypeRange regionResultTypes, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), ++inlinePoint->getIterator(), mapper, resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, - Block::iterator inlinePoint, IRMapping &mapper, - ValueRange resultsToReplace, - TypeRange regionResultTypes, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, - resultsToReplace, regionResultTypes, inlineLoc, - shouldCloneInlinedRegion); + +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, + ValueRange resultsToReplace, TypeRange regionResultTypes, + std::optional inlineLoc, bool shouldCloneInlinedRegion) { + return inlineRegionImpl( + interface, cloneCallback, src, inlineBlock, inlinePoint, mapper, + resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Operation *inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegion(interface, src, inlinePoint->getBlock(), +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Operation *inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), ++inlinePoint->getIterator(), inlinedOperands, resultsToReplace, inlineLoc, shouldCloneInlinedRegion); } -LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, - Block *inlineBlock, - Block::iterator inlinePoint, - ValueRange inlinedOperands, - ValueRange resultsToReplace, - std::optional inlineLoc, - bool shouldCloneInlinedRegion) { - return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, - inlinedOperands, resultsToReplace, inlineLoc, - shouldCloneInlinedRegion); + +LogicalResult mlir::inlineRegion( + InlinerInterface &interface, + function_ref cloneCallback, Region *src, + Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands, + ValueRange resultsToReplace, std::optional inlineLoc, + bool shouldCloneInlinedRegion) { + return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, + inlinePoint, inlinedOperands, resultsToReplace, + inlineLoc, shouldCloneInlinedRegion); } /// Utility function used to generate a cast operation from the given interface, @@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface, /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' /// corresponds to whether the source region should be cloned into the 'call' or /// spliced directly. -LogicalResult mlir::inlineCall(InlinerInterface &interface, - CallOpInterface call, - CallableOpInterface callable, Region *src, - bool shouldCloneInlinedRegion) { +LogicalResult +mlir::inlineCall(InlinerInterface &interface, + function_ref cloneCallback, + CallOpInterface call, CallableOpInterface callable, + Region *src, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. if (src->empty()) return failure(); @@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, return cleanupState(); // Attempt to inline the call. - if (failed(inlineRegionImpl(interface, src, call->getBlock(), + if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(), ++call->getIterator(), mapper, callResults, callableResultTypes, call.getLoc(), shouldCloneInlinedRegion, call))) diff --git a/mlir/test/Transforms/test-inlining-callback.mlir b/mlir/test/Transforms/test-inlining-callback.mlir new file mode 100644 index 000000000000..c012c31e7e49 --- /dev/null +++ b/mlir/test/Transforms/test-inlining-callback.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -test-inline-callback | FileCheck %s + +// Test inlining with multiple blocks and scf.execute_region transformation +// CHECK-LABEL: func @test_inline_multiple_blocks +func.func @test_inline_multiple_blocks(%arg0: i32) -> i32 { + // CHECK: %[[RES:.*]] = scf.execute_region -> i32 + // CHECK-NEXT: %[[ADD1:.*]] = arith.addi %arg0, %arg0 + // CHECK-NEXT: cf.br ^bb1(%[[ADD1]] : i32) + // CHECK: ^bb1(%[[ARG:.*]]: i32): + // CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[ARG]], %[[ARG]] + // CHECK-NEXT: scf.yield %[[ADD2]] + // CHECK: return %[[RES]] + %fn = "test.functional_region_op"() ({ + ^bb0(%a : i32): + %b = arith.addi %a, %a : i32 + cf.br ^bb1(%b: i32) + ^bb1(%c: i32): + %d = arith.addi %c, %c : i32 + "test.return"(%d) : (i32) -> () + }) : () -> ((i32) -> i32) + + %0 = call_indirect %fn(%arg0) : (i32) -> i32 + return %0 : i32 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index c053fd4b2047..76041cd6cd79 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -29,6 +29,7 @@ add_mlir_library(MLIRTestTransforms TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp + TestInliningCallback.cpp TestMakeIsolatedFromAbove.cpp TestTransformsOps.cpp ${MLIRTestTransformsPDLSrc} diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp index 223cc78dd1e2..ae904a92a5d6 100644 --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Inliner.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringSet.h" @@ -25,8 +26,9 @@ using namespace mlir; using namespace test; namespace { -struct Inliner : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Inliner) +struct InlinerTest + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerTest) StringRef getArgument() const final { return "test-inline"; } StringRef getDescription() const final { @@ -34,6 +36,8 @@ struct Inliner : public PassWrapper> { } void runOnOperation() override { + InlinerConfig config; + auto function = getOperation(); // Collect each of the direct function calls within the module. @@ -54,8 +58,8 @@ struct Inliner : public PassWrapper> { // Inline the functional region operation, but only clone the internal // region if there is more than one use. if (failed(inlineRegion( - interface, &callee.getBody(), caller, caller.getArgOperands(), - caller.getResults(), caller.getLoc(), + interface, config.getCloneCallback(), &callee.getBody(), caller, + caller.getArgOperands(), caller.getResults(), caller.getLoc(), /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse()))) continue; @@ -71,6 +75,6 @@ struct Inliner : public PassWrapper> { namespace mlir { namespace test { -void registerInliner() { PassRegistration(); } +void registerInliner() { PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp new file mode 100644 index 000000000000..012d62b7b1b4 --- /dev/null +++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp @@ -0,0 +1,151 @@ +//===- TestInliningCallback.cpp - Pass to inline calls in the test dialect +//--------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file implements a pass to test inlining callbacks including +// canHandleMultipleBlocks and doClone. +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestOps.h" +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Inliner.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSet.h" + +using namespace mlir; +using namespace test; + +namespace { +struct InlinerCallback + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerCallback) + + StringRef getArgument() const final { return "test-inline-callback"; } + StringRef getDescription() const final { + return "Test inlining region calls with call back functions"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline, + Operation *op) { + return mlir::cast(pass).runPipeline(pipeline, op); + } + + // Customize the implementation of Inliner::doClone + // Wrap the callee into scf.execute_region operation + static void testDoClone(OpBuilder &builder, Region *src, Block *inlineBlock, + Block *postInsertBlock, IRMapping &mapper, + bool shouldCloneInlinedRegion) { + // Create a new scf.execute_region operation + mlir::Operation &call = inlineBlock->back(); + builder.setInsertionPointAfter(&call); + + auto executeRegionOp = builder.create( + call.getLoc(), call.getResultTypes()); + mlir::Region ®ion = executeRegionOp.getRegion(); + + // Move the inlined blocks into the region + src->cloneInto(®ion, mapper); + + // Split block before scf operation. + Block *continueBlock = + inlineBlock->splitBlock(executeRegionOp.getOperation()); + + // Replace all test.return with scf.yield + for (mlir::Block &block : region) { + + for (mlir::Operation &op : llvm::make_early_inc_range(block)) { + if (test::TestReturnOp returnOp = + llvm::dyn_cast(&op)) { + mlir::OpBuilder returnBuilder(returnOp); + returnBuilder.create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + } + } + } + + // Add test.return after scf.execute_region + builder.setInsertionPointAfter(executeRegionOp); + builder.create(executeRegionOp.getLoc(), + executeRegionOp.getResults()); + } + + void runOnOperation() override { + InlinerConfig config; + CallGraph &cg = getAnalysis(); + + func::FuncOp function = getOperation(); + + // By default, assume that any inlining is profitable. + auto profitabilityCb = [&](const mlir::Inliner::ResolvedCall &call) { + return true; + }; + + // Set the clone callback in the config + config.setCloneCallback([](OpBuilder &builder, Region *src, + Block *inlineBlock, Block *postInsertBlock, + IRMapping &mapper, + bool shouldCloneInlinedRegion) { + return testDoClone(builder, src, inlineBlock, postInsertBlock, mapper, + shouldCloneInlinedRegion); + }); + + // Set canHandleMultipleBlocks to true in the config + config.setCanHandleMultipleBlocks(); + + // Get an instance of the inliner. + Inliner inliner(function, cg, *this, getAnalysisManager(), + runPipelineHelper, config, profitabilityCb); + + // Collect each of the direct function calls within the module. + SmallVector callers; + function.walk( + [&](func::CallIndirectOp caller) { callers.push_back(caller); }); + + // Build the inliner interface. + InlinerInterface interface(&getContext()); + + // Try to inline each of the call operations. + for (auto caller : callers) { + auto callee = dyn_cast_or_null( + caller.getCallee().getDefiningOp()); + if (!callee) + continue; + + // Inline the functional region operation, but only clone the internal + // region if there is more than one use. + if (failed(inlineRegion( + interface, config.getCloneCallback(), &callee.getBody(), caller, + caller.getArgOperands(), caller.getResults(), caller.getLoc(), + /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse()))) + continue; + + // If the inlining was successful then erase the call and callee if + // possible. + caller.erase(); + if (callee.use_empty()) + callee.erase(); + } + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerInlinerCallback() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d06ff8070e7c..ca4706e96787 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -73,6 +73,7 @@ void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerConvertFuncOpPass(); void registerInliner(); +void registerInlinerCallback(); void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); @@ -215,6 +216,7 @@ void registerTestPasses() { mlir::test::registerConvertCallOpPass(); mlir::test::registerConvertFuncOpPass(); mlir::test::registerInliner(); + mlir::test::registerInlinerCallback(); mlir::test::registerMemRefBoundCheck(); mlir::test::registerPatternsTestPass(); mlir::test::registerSimpleParametricTilingPass();