Files
clang-p2996/mlir/test/lib/IR/TestClone.cpp
Matthias Springer 5cc0f76d34 [mlir][IR] Add rewriter API for moving operations (#78988)
The pattern rewriter documentation states that "*all* IR mutations [...]
are required to be performed via the `PatternRewriter`." This commit
adds two functions that were missing from the rewriter API:
`moveOpBefore` and `moveOpAfter`.

After an operation has been moved, the `notifyOperationInserted` callback is
triggered. This allows listeners such as the greedy pattern rewrite
driver to react to IR changes.

This commit narrows the discrepancy between the kinds of IR modifications
that can be performed and the kinds of IR modifications that can be
listened to.
2024-01-25 11:01:28 +01:00

76 lines
2.5 KiB
C++

//===- TestClone.cpp - Pass to test operation cloning --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Builder listener that writes one line to `llvm::outs()` for every
/// operation-insertion notification it receives. The previous insertion
/// point passed along with the notification is ignored.
struct DumpNotifications : public OpBuilder::Listener {
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint /*previous*/) override {
    llvm::raw_ostream &os = llvm::outs();
    os << "notifyOperationInserted: " << op->getName() << "\n";
  }
};
/// Test pass that duplicates the body of a single-block function.
///
/// A function computing `f(x)` is rewritten to compute `f(f(x))`:
/// - Every operation of the entry block, including the terminator, is cloned
///   at the end of that block.
/// - While cloning, each block argument is remapped to the matching operand of
///   the original terminator, so the second copy consumes the values the first
///   copy would have returned.
/// - The original terminator is then erased, leaving the cloned terminator in
///   its place.
///
/// Functions with more than one region or block are skipped. So are functions
/// whose terminator operands do not match the entry block arguments in count
/// and type. Insertions made by the builder are reported through
/// `DumpNotifications`.
struct ClonePass
    : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClonePass)

  StringRef getArgument() const final { return "test-clone"; }
  StringRef getDescription() const final { return "Test clone of op"; }

  void runOnOperation() override {
    FunctionOpInterface funcOp = getOperation();

    // Restrict the test to functions made of exactly one region with exactly
    // one block.
    if (funcOp->getNumRegions() != 1)
      return;
    Region &body = funcOp->getRegion(0);
    if (!body.hasOneBlock())
      return;

    Block &entry = body.front();
    Operation *oldTerminator = entry.getTerminator();

    // The returned values become the inputs of the cloned body, so they must
    // line up one-to-one (count and type) with the block arguments.
    if (oldTerminator->getNumOperands() != entry.getNumArguments())
      return;
    IRMapping mapping;
    for (auto [returned, argument] :
         llvm::zip(oldTerminator->getOperands(), entry.getArguments())) {
      if (returned.getType() != argument.getType())
        return;
      mapping.map(argument, returned);
    }

    // Declared before the builder so the listener outlives the builder that
    // points at it.
    DumpNotifications notificationPrinter;
    OpBuilder builder(funcOp->getContext());
    builder.setListener(&notificationPrinter);
    builder.setInsertionPointToEnd(&entry);

    // Snapshot the original operations first: cloning appends to `entry`,
    // so walking the block directly would also visit the fresh clones.
    SmallVector<Operation *> originals;
    for (Operation &original : entry)
      originals.push_back(&original);
    for (Operation *original : originals)
      builder.clone(*original, mapping);

    // The clone of the terminator now ends the block; drop the old one.
    oldTerminator->erase();
  }
};
} // namespace
namespace mlir {
void registerCloneTestPasses() { PassRegistration<ClonePass>(); }
} // namespace mlir