Introduce a transform dialect op that allows one to attempt different transformation sequences on the same piece of payload IR until one of them succeeds. This op fundamentally expands the scope of possibilities in the transform dialect, which until now could only propagate transformation failure, at least using in-tree operations. This requires a more detailed specification of the execution model for the transform dialect that now indicates how failure is handled and propagated. Transformations described by transform operations now have tri-state results, with some errors being fundamentally irrecoverable (e.g., generating malformed IR) and some others being recoverable by containing ops. Existing transform ops directly implementing the `apply` interface method are updated to produce such tri-state results directly. Transform ops with the `TransformEachTransformOpTrait` are currently considered to produce only irrecoverable failures and will be updated separately. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D127724
70 lines
2.3 KiB
C++
70 lines
2.3 KiB
C++
//===- TestTransformDialectInterpreter.cpp --------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a test pass that interprets Transform dialect operations in
|
|
// the module.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple pass that applies transform dialect ops directly contained in a
|
|
/// module.
|
|
class TestTransformDialectInterpreterPass
|
|
: public PassWrapper<TestTransformDialectInterpreterPass,
|
|
OperationPass<ModuleOp>> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestTransformDialectInterpreterPass)
|
|
|
|
TestTransformDialectInterpreterPass() = default;
|
|
TestTransformDialectInterpreterPass(
|
|
const TestTransformDialectInterpreterPass &) {}
|
|
|
|
StringRef getArgument() const override {
|
|
return "test-transform-dialect-interpreter";
|
|
}
|
|
|
|
StringRef getDescription() const override {
|
|
return "apply transform dialect operations one by one";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
ModuleOp module = getOperation();
|
|
transform::TransformState state(
|
|
module.getBodyRegion(), module,
|
|
transform::TransformOptions().enableExpensiveChecks(
|
|
enableExpensiveChecks));
|
|
for (auto op :
|
|
module.getBody()->getOps<transform::TransformOpInterface>()) {
|
|
if (failed(state.applyTransform(op).checkAndReport()))
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
|
|
Option<bool> enableExpensiveChecks{
|
|
*this, "enable-expensive-checks", llvm::cl::init(false),
|
|
llvm::cl::desc("perform expensive checks to better report errors in the "
|
|
"transform IR")};
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
/// Registers the test pass for applying transform dialect ops.
|
|
void registerTestTransformDialectInterpreterPass() {
|
|
PassRegistration<TestTransformDialectInterpreterPass> reg;
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|