Sequence is an important transform combination primitive that just indicates transform ops being applied in a row. The simplest version fails immediately if any transformation in the sequence fails. Introducing this operation allows one to start placing transform IR within other IR. Depends On D123135 Reviewed By: Mogball, rriddle Differential Revision: https://reviews.llvm.org/D123664
58 lines
1.8 KiB
C++
58 lines
1.8 KiB
C++
//===- TestTransformDialectInterpreter.cpp --------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a test pass that interprets Transform dialect operations in
|
|
// the module.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple pass that applies transform dialect ops directly contained in a
|
|
/// module.
|
|
class TestTransformDialectInterpreterPass
|
|
: public PassWrapper<TestTransformDialectInterpreterPass,
|
|
OperationPass<ModuleOp>> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestTransformDialectInterpreterPass)
|
|
|
|
StringRef getArgument() const override {
|
|
return "test-transform-dialect-interpreter";
|
|
}
|
|
|
|
StringRef getDescription() const override {
|
|
return "apply transform dialect operations one by one";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
ModuleOp module = getOperation();
|
|
transform::TransformState state(module.getBodyRegion(), module);
|
|
for (auto op :
|
|
module.getBody()->getOps<transform::TransformOpInterface>()) {
|
|
if (failed(state.applyTransform(op)))
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
/// Registers the test pass for applying transform dialect ops.
|
|
void registerTestTransformDialectInterpreterPass() {
|
|
PassRegistration<TestTransformDialectInterpreterPass> reg;
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|