Files
clang-p2996/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
Oleksandr "Alex" Zinenko e4384149b5 [mlir] use transform-interpreter in test passes (#70040)
Update most test passes to use the transform-interpreter pass instead of
the test-transform-dialect-interpreter-pass. The new "main" interpreter
pass has a named entry point instead of looking up the top-level op with
`PossibleTopLevelOpTrait`, which is arguably a more understandable
interface. The change is mechanical, rewriting an unnamed sequence into
a named one and wrapping the transform IR into a module when necessary.

Add an option to the transform-interpreter pass to target a tagged
payload op instead of the root anchor op, which is also useful for repro
generation.

Only the tests in the transform dialect proper and the examples have not
been updated yet. These will be updated separately after a more careful
consideration of testing coverage of the transform interpreter logic.
2023-10-24 16:12:34 +02:00

88 lines
2.9 KiB
C++

//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
using namespace mlir;
namespace mlir {
namespace transform {
// Defining GEN_PASS_DEF_INTERPRETERPASS before including Passes.h.inc selects
// the definition section for this pass from the included file. This is
// presumably where `transform::impl::InterpreterPassBase`, the base class of
// the pass defined in this file, comes from (the .inc file is not visible
// here; confirm against the generated header).
#define GEN_PASS_DEF_INTERPRETERPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir
/// Returns the payload operation to be used as payload root:
///   - the `passRoot` itself if the tag is empty;
///   - otherwise, the operation found by walking `passRoot` whose
///     `TransformDialect::kTargetTagAttrName` string attribute equals `tag`;
///     this operation must be unique.
///
/// Returns nullptr after emitting an error when the tag is non-empty and
/// either several operations carry it or no operation carries it. Callers
/// must check the result before using it.
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
  // Fast return: no tag means the pass anchor itself is the payload root.
  if (tag.empty())
    return passRoot;

  // Walk to do a lookup.
  Operation *target = nullptr;
  auto tagAttrName = StringAttr::get(
      passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
  WalkResult walkResult = passRoot->walk([&](Operation *op) {
    auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
    if (!attr || attr.getValue() != tag)
      return WalkResult::advance();

    // First match: remember it and keep walking to detect duplicates.
    if (!target) {
      target = op;
      return WalkResult::advance();
    }

    // Second match: the tag is ambiguous, report both locations and stop.
    InFlightDiagnostic diag = op->emitError()
                              << "repeated operation with the target tag '"
                              << tag << "'";
    diag.attachNote(target->getLoc()) << "previously seen operation";
    return WalkResult::interrupt();
  });
  if (walkResult.wasInterrupted())
    return nullptr;

  // A non-empty tag that matches nothing is a user error; report it instead of
  // silently handing a null payload root back to the caller.
  if (!target) {
    passRoot->emitError()
        << "could not find an operation with the target tag '" << tag << "'";
    return nullptr;
  }
  return target;
}
namespace {
/// Pass that applies a named transform sequence to a payload operation.
///
/// The payload root is selected by `findPayloadRoot` from the operation the
/// pass runs on and `debugPayloadRootTag`. The transform entry point named
/// `entryPoint` is looked up through
/// `transform::detail::findTransformEntryPoint`, which is given both the
/// operation the pass runs on and the preloaded transform module.
/// `debugPayloadRootTag`, `entryPoint` and `disableExpensiveChecks` are not
/// declared in this file; they are presumably pass options provided by the
/// generated base class (confirm against the pass definition).
class InterpreterPass
    : public transform::impl::InterpreterPassBase<InterpreterPass> {
public:
  using Base::Base;

  /// Resolves the payload root and the transform entry point, then runs the
  /// named sequence. Signals pass failure if either cannot be resolved or if
  /// applying the sequence fails.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp transformModule =
        transform::detail::getPreloadedTransformModule(context);

    Operation *payloadRoot =
        findPayloadRoot(getOperation(), debugPayloadRootTag);
    // findPayloadRoot returns null when the tagged payload root cannot be
    // determined; bail out rather than handing a null payload root to the
    // interpreter.
    if (!payloadRoot)
      return signalPassFailure();

    Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
        getOperation(), transformModule, entryPoint);
    if (!transformEntryPoint) {
      getOperation()->emitError()
          << "could not find transform entry point: " << entryPoint
          << " in either payload or transform module";
      return signalPassFailure();
    }

    // Expensive checks are on unless explicitly disabled through the option.
    if (failed(transform::applyTransformNamedSequence(
            payloadRoot, transformEntryPoint, transformModule,
            options.enableExpensiveChecks(!disableExpensiveChecks)))) {
      return signalPassFailure();
    }
  }

private:
  /// Transform interpreter options.
  transform::TransformOptions options;
};
} // namespace