Files
clang-p2996/mlir/test/lib/Transforms/TestTransformsOps.cpp
MaheshRavishankar 205c5325b3 [mlir] Add a utility method to move operation dependencies. (#129975)
The added utility method moves all SSA values that an operation depends
upon before an insertion point. This is useful during transformations
where such movements might make transformations (like fusion) more
powerful.

To test the utility, add a transform dialect op that calls it. To capture
the `notifyMatchFailure` messages from the transformation and report/check
them in the test, modify the `ErrorCheckingTrackingListener` to record the
last match failure notification.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
2025-03-10 20:23:08 -07:00

67 lines
2.1 KiB
C++

//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations for testing MLIR
// transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#define GET_OP_CLASSES
#include "TestTransformsOps.h.inc"
using namespace mlir;
using namespace mlir::transform;
#define GET_OP_CLASSES
#include "TestTransformsOps.cpp.inc"
/// Test driver for `moveOperationDependencies`.
///
/// Takes the first payload op associated with the handle returned by
/// `getOp()` and the first payload op associated with the handle returned by
/// `getInsertionPoint()`, and calls
/// `moveOperationDependencies(rewriter, op, moveBefore)` on them.
///
/// If that call fails, the latest match-failure message recorded by the
/// rewriter's `ErrorCheckingTrackingListener` is emitted as a remark on this
/// op, and the transform still reports success so the test can check the
/// remark text.
///
/// Returns a silenceable error when either handle has no payload op, since
/// there is nothing to dereference in that case.
DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
                                      TransformResults &results,
                                      TransformState &state) {
  // Guard against empty handles: dereferencing `begin()` of an empty payload
  // range is undefined behaviour.
  auto targetOps = state.getPayloadOps(getOp());
  if (targetOps.empty())
    return emitSilenceableError()
           << "expected the op handle to map to at least one payload op";

  auto insertionPointOps = state.getPayloadOps(getInsertionPoint());
  if (insertionPointOps.empty())
    return emitSilenceableError() << "expected the insertion point handle to "
                                     "map to at least one payload op";

  // Only the first payload op of each handle is used.
  Operation *op = *targetOps.begin();
  Operation *moveBefore = *insertionPointOps.begin();

  if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
    // Surface the reason reported through the listener as a remark instead of
    // failing the transform.
    auto *listener =
        cast<ErrorCheckingTrackingListener>(rewriter.getListener());
    std::string errorMsg = listener->getLatestMatchFailureMessage();
    (void)emitRemark(errorMsg);
  }
  return DiagnosedSilenceableFailure::success();
}
namespace {
/// Transform dialect extension that registers the transform ops enumerated by
/// the `GET_OP_LIST` section of `TestTransformsOps.cpp.inc`.
struct TestTransformsDialectExtension
    : transform::TransformDialectExtension<TestTransformsDialectExtension> {
  // Inherit the constructors of the extension base class.
  using Base::Base;

  // Gives this file-local class an explicit TypeID.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

  /// Registers the listed ops through `registerTransformOps`.
  void init() {
    // The op list is spliced in as the template argument list.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
        >();
  }
};
} // namespace
namespace test {
/// Adds `TestTransformsDialectExtension` to `registry` via
/// `DialectRegistry::addExtensions`.
void registerTestTransformsTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<TestTransformsDialectExtension>();
}
} // namespace test