Files
clang-p2996/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
Mahesh Ravishankar da784e77da [mlir] Add a utility function to make a region isolated from above.
The utility function takes a region and makes it isolated from above
by appending to the entry block arguments that represent the captured
values and replacing all uses of the captured values within the region
with the newly added arguments. The captured values are returned.

The utility function also takes an optional callback that allows
cloning operations that define the captured values into the region
during the process of making it isolated from above. The cloned value
is no longer a captured value. The operands of the operation are then
captured values. This is done transitively, allowing cloning of a DAG of
operations into the region based on the callback.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D148684
2023-04-20 16:40:25 +00:00

157 lines
5.3 KiB
C++

//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
/// Shared rewrite used by the test patterns to turn a `test.one_region_op`
/// into a `test.isolated_one_region_op`.
///
/// `makeRegionIsolatedFromAbove` is invoked on the region of `regionOp` with
/// `callBack` forwarded unchanged. A `test::IsolatedOneRegionOp` is then
/// created whose operands are the operands of `regionOp` followed by the
/// values returned by `makeRegionIsolatedFromAbove`; the region is moved into
/// the new operation and `regionOp` is erased. Always returns `success()`.
static LogicalResult
makeIsolatedFromAboveImpl(RewriterBase &rewriter,
                          test::OneRegionWithOperandsOp regionOp,
                          llvm::function_ref<bool(Operation *)> callBack) {
  Region &sourceRegion = regionOp.getRegion();
  SmallVector<Value> captures =
      makeRegionIsolatedFromAbove(rewriter, sourceRegion, callBack);

  // Operand order of the new op: original operands first, captures after.
  SmallVector<Value> newOperands = regionOp.getOperands();
  newOperands.append(captures.begin(), captures.end());

  Location loc = regionOp.getLoc();
  auto isolatedOp =
      rewriter.create<test::IsolatedOneRegionOp>(loc, newOperands);

  // Move the (now rewritten) body to the front of the new op's region before
  // dropping the op that used to own it.
  Region &destRegion = isolatedOp.getRegion();
  rewriter.inlineRegionBefore(sourceRegion, destRegion, destRegion.begin());
  rewriter.eraseOp(regionOp);
  return success();
}
namespace {
/// Pattern covering the plain case: the region is made isolated from above
/// and the callback rejects every operation, so nothing is cloned.
struct SimpleMakeIsolatedFromAbove
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    // Callback that answers "no" for every operation it is asked about.
    auto neverClone = [](Operation *) { return false; };
    return makeIsolatedFromAboveImpl(rewriter, regionOp, neverClone);
  }
};
/// Pattern that makes the region isolated from above while cloning only
/// operations that have no operands.
struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    // Callback that accepts an operation only when it has zero operands.
    auto hasNoOperands = [](Operation *candidate) {
      return candidate->getNumOperands() == 0;
    };
    return makeIsolatedFromAboveImpl(rewriter, regionOp, hasNoOperands);
  }
};
/// Pattern that makes the region isolated from above while cloning
/// operations regardless of whether they have operands: the callback accepts
/// every operation it is asked about.
struct MakeIsolatedFromAboveAndCloneOpsWithOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    // Callback that answers "yes" for every operation.
    auto alwaysClone = [](Operation *) { return true; };
    return makeIsolatedFromAboveImpl(rewriter, regionOp, alwaysClone);
  }
};
/// Test pass for testing the `makeIsolatedFromAbove` function.
///
/// The pass operates on `func::FuncOp` and is registered under the
/// `test-make-isolated-from-above` argument. Its boolean options (all `false`
/// by default) select which test pattern `runOnOperation` applies.
struct TestMakeIsolatedFromAbovePass
    : public PassWrapper<TestMakeIsolatedFromAbovePass,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)

  TestMakeIsolatedFromAbovePass() = default;
  // Only the `PassWrapper` base is copy-constructed here; the option members
  // are not named in the initializer list, so each one is built by its
  // default member initializer and bound to the new `*this`.
  TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
      : PassWrapper(pass) {}

  /// Command-line argument used to select this pass.
  StringRef getArgument() const final {
    return "test-make-isolated-from-above";
  }

  /// One-line description of the pass.
  StringRef getDescription() const final {
    return "Test making a region isolated from above";
  }

  /// Selects the pattern that never clones operations.
  Option<bool> simple{
      *this, "simple",
      llvm::cl::desc("Test simple case with no cloning of operations"),
      llvm::cl::init(false)};

  /// Selects the pattern that clones only operations without operands.
  Option<bool> cloneOpsWithNoOperands{
      *this, "clone-ops-with-no-operands",
      llvm::cl::desc("Test case with cloning of operations with no operands"),
      llvm::cl::init(false)};

  /// Selects the pattern whose callback accepts every operation. The
  /// description previously duplicated the "no operands" text of the option
  /// above; it now matches the option name.
  Option<bool> cloneOpsWithOperands{
      *this, "clone-ops-with-operands",
      llvm::cl::desc("Test case with cloning of operations with operands"),
      llvm::cl::init(false)};

  void runOnOperation() override;
};
} // namespace
void TestMakeIsolatedFromAbovePass::runOnOperation() {
MLIRContext *context = &getContext();
func::FuncOp funcOp = getOperation();
if (simple) {
RewritePatternSet patterns(context);
patterns.insert<SimpleMakeIsolatedFromAbove>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
if (cloneOpsWithNoOperands) {
RewritePatternSet patterns(context);
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
if (cloneOpsWithOperands) {
RewritePatternSet patterns(context);
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
}
namespace mlir::test {
/// Registers `TestMakeIsolatedFromAbovePass` by constructing a
/// `PassRegistration` object for it.
void registerTestMakeIsolatedFromAbovePass() {
  PassRegistration<TestMakeIsolatedFromAbovePass>();
}
} // namespace mlir::test