The utility functions takes a region and makes it isolated from above by appending to the entry block arguments that represent the captured values and replacing all uses of the captured values within the region with the newly added arguments. The captures values are returned. The utility function also takes an optional callback that allows cloning operations that define the captured values into the region during the process of making it isolated from above. The cloned value is no longer a captured values. The operands of the operation are then captured values. This is done transitively allow cloning of a DAG of operations into the region based on the callback. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D148684
157 lines
5.3 KiB
C++
157 lines
5.3 KiB
C++
//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestDialect.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Helper function to call the `makeRegionIsolatedFromAbove` to convert
|
|
/// `test.one_region_op` to `test.isolated_one_region_op`.
|
|
static LogicalResult
|
|
makeIsolatedFromAboveImpl(RewriterBase &rewriter,
|
|
test::OneRegionWithOperandsOp regionOp,
|
|
llvm::function_ref<bool(Operation *)> callBack) {
|
|
Region ®ion = regionOp.getRegion();
|
|
SmallVector<Value> capturedValues =
|
|
makeRegionIsolatedFromAbove(rewriter, region, callBack);
|
|
SmallVector<Value> operands = regionOp.getOperands();
|
|
operands.append(capturedValues);
|
|
auto isolatedRegionOp =
|
|
rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
|
|
rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
|
|
isolatedRegionOp.getRegion().begin());
|
|
rewriter.eraseOp(regionOp);
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Simple test for making region isolated from above without cloning any
|
|
/// operations.
|
|
struct SimpleMakeIsolatedFromAbove
|
|
: OpRewritePattern<test::OneRegionWithOperandsOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
return makeIsolatedFromAboveImpl(rewriter, regionOp,
|
|
[](Operation *) { return false; });
|
|
}
|
|
};
|
|
|
|
/// Test for making region isolated from above while clong operations
|
|
/// with no operands.
|
|
struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
|
|
: OpRewritePattern<test::OneRegionWithOperandsOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
|
|
return op->getNumOperands() == 0;
|
|
});
|
|
}
|
|
};
|
|
|
|
/// Test for making region isolated from above while clong operations
|
|
/// with no operands.
|
|
struct MakeIsolatedFromAboveAndCloneOpsWithOperands
|
|
: OpRewritePattern<test::OneRegionWithOperandsOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
return makeIsolatedFromAboveImpl(rewriter, regionOp,
|
|
[](Operation *op) { return true; });
|
|
}
|
|
};
|
|
|
|
/// Test pass for testing the `makeIsolatedFromAbove` function.
|
|
struct TestMakeIsolatedFromAbovePass
|
|
: public PassWrapper<TestMakeIsolatedFromAbovePass,
|
|
OperationPass<func::FuncOp>> {
|
|
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)
|
|
|
|
TestMakeIsolatedFromAbovePass() = default;
|
|
TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-make-isolated-from-above";
|
|
}
|
|
|
|
StringRef getDescription() const final {
|
|
return "Test making a region isolated from above";
|
|
}
|
|
|
|
Option<bool> simple{
|
|
*this, "simple",
|
|
llvm::cl::desc("Test simple case with no cloning of operations"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> cloneOpsWithNoOperands{
|
|
*this, "clone-ops-with-no-operands",
|
|
llvm::cl::desc("Test case with cloning of operations with no operands"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> cloneOpsWithOperands{
|
|
*this, "clone-ops-with-operands",
|
|
llvm::cl::desc("Test case with cloning of operations with no operands"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void TestMakeIsolatedFromAbovePass::runOnOperation() {
|
|
MLIRContext *context = &getContext();
|
|
func::FuncOp funcOp = getOperation();
|
|
|
|
if (simple) {
|
|
RewritePatternSet patterns(context);
|
|
patterns.insert<SimpleMakeIsolatedFromAbove>(context);
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
|
|
return signalPassFailure();
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (cloneOpsWithNoOperands) {
|
|
RewritePatternSet patterns(context);
|
|
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
|
|
return signalPassFailure();
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (cloneOpsWithOperands) {
|
|
RewritePatternSet patterns(context);
|
|
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
|
|
return signalPassFailure();
|
|
}
|
|
return;
|
|
}
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestMakeIsolatedFromAbovePass() {
|
|
PassRegistration<TestMakeIsolatedFromAbovePass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|