The added utility method moves the operations that define all SSA values an operation depends upon so that they appear before an insertion point. This is useful during transformations where such movements can make transformations (like fusion) more powerful. To test the utility, add a transform dialect op that calls the move method. To capture the `notifyMatchFailure` messages from the transformation and report/check them in the test, modify the `ErrorCheckingTrackingListener` to capture the last match failure notification. --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
67 lines
2.1 KiB
C++
67 lines
2.1 KiB
C++
//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines transform dialect operations for testing MLIR
|
|
// transformations
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.h.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.cpp.inc"
|
|
|
|
/// Moves the dependencies of the first payload op associated with `getOp()`
/// so that they are placed before the first payload op associated with
/// `getInsertionPoint()`, by calling `moveOperationDependencies`.
///
/// A failed move does NOT make this op fail. Instead, the latest
/// `notifyMatchFailure` message recorded by the rewriter's
/// `ErrorCheckingTrackingListener` is reported as a remark on this op, so that
/// tests can check it. If either handle has no associated payload op, a
/// silenceable error is returned. Payload ops beyond the first one of each
/// handle are ignored.
DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
  auto opPayload = state.getPayloadOps(getOp());
  auto insertionPointPayload = state.getPayloadOps(getInsertionPoint());

  // A handle may legitimately be associated with zero payload ops;
  // dereferencing `begin()` of an empty range would be undefined behavior.
  if (llvm::empty(opPayload)) {
    return emitSilenceableError()
           << "expected a payload op for the operation handle";
  }
  if (llvm::empty(insertionPointPayload)) {
    return emitSilenceableError()
           << "expected a payload op for the insertion point handle";
  }

  Operation *op = *opPayload.begin();
  Operation *moveBefore = *insertionPointPayload.begin();
  if (succeeded(moveOperationDependencies(rewriter, op, moveBefore)))
    return DiagnosedSilenceableFailure::success();

  // The failure reason is only captured when the rewriter is driven by an
  // ErrorCheckingTrackingListener. Fall back to a generic message otherwise,
  // rather than asserting inside `cast`.
  std::string errorMsg = "failed to move operand dependencies";
  if (auto *listener = dyn_cast_or_null<ErrorCheckingTrackingListener>(
          rewriter.getListener()))
    errorMsg = listener->getLatestMatchFailureMessage();
  (void)emitRemark(errorMsg);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
namespace {

/// Transform dialect extension that registers every op in the generated
/// `GET_OP_LIST` of TestTransformsOps.cpp.inc. It is constructed through the
/// inherited `TransformDialectExtension` constructors.
struct TestTransformsDialectExtension
    : transform::TransformDialectExtension<TestTransformsDialectExtension> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

  // Reuse the base-class constructors unchanged.
  using Base::Base;

  /// Adds the generated test transform ops to the transform dialect.
  void init() {
    registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
        >();
  }
};

} // namespace
|
|
|
|
namespace test {
|
|
void registerTestTransformsTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformsDialectExtension>();
|
|
}
|
|
} // namespace test
|