Files
clang-p2996/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Matthias Springer c63d2b2c71 [mlir][transform] Add TransformRewriter
All `apply` functions now have a `TransformRewriter &` parameter. This rewriter should be used to modify the IR. It has a `TrackingListener` attached and updates the internal handle-payload mappings based on rewrites.

Implementations no longer need to create their own `TrackingListener` and `IRRewriter`. Error checking is integrated into `applyTransform`. Tracking listener errors are reported only for ops with the `ReportTrackingListenerFailuresOpTrait` trait attached, allowing for a gradual migration. Furthermore, errors can be silenced with an op attribute.

Additional API will be added to `TransformRewriter` in subsequent revisions. This revision just adds an "empty" `TransformRewriter` class and updates all `apply` implementations.

Differential Revision: https://reviews.llvm.org/D152427
2023-06-20 10:49:59 +02:00

862 lines
32 KiB
C++

//===- TestTransformDialectExtension.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an extension of the MLIR Transform dialect for testing
// purposes.
//
//===----------------------------------------------------------------------===//
#include "TestTransformDialectExtension.h"
#include "TestTransformStateExtension.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
/// Transform op written by hand in C++ rather than generated by ODS. It checks
/// that C++-defined ops can be injected into the Transform dialect as well.
/// Applying it only emits a remark, optionally followed by the "message"
/// attribute.
class TestTransformOp
    : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
                MemoryEffectOpInterface::Trait> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)

  using Op::Op;

  /// The op declares no inherent attribute names.
  static ArrayRef<StringRef> getAttributeNames() { return {}; }

  static constexpr llvm::StringLiteral getOperationName() {
    return llvm::StringLiteral("transform.test_transform_op");
  }

  /// Emits the remark "applying transformation", appending the "message"
  /// attribute when it is present. Always succeeds.
  DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
    InFlightDiagnostic diag = emitRemark() << "applying transformation";
    Attribute message = getMessage();
    if (message)
      diag << " " << message;
    return DiagnosedSilenceableFailure::success();
  }

  /// Returns the "message" attribute of the op; null when it is not set.
  Attribute getMessage() { return getOperation()->getAttr("message"); }

  /// Parses an optional string attribute and stores it as "message".
  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
    StringAttr message;
    OptionalParseResult parsed = parser.parseOptionalAttribute(message);
    // Nothing was present in the input: the op simply has no message.
    if (!parsed.has_value())
      return success();

    ParseResult status = parsed.value();
    if (status.succeeded())
      state.addAttribute("message", message);
    return status;
  }

  /// Prints the message attribute, preceded by a space, when there is one.
  void print(OpAsmPrinter &printer) {
    Attribute message = getMessage();
    if (!message)
      return;
    printer << " " << message;
  }

  // No side effects.
  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
};
/// Test op used to exercise the verifier of PossibleTopLevelTransformOpTrait
/// when the trait is attached to an op that does not satisfy the trait's
/// requirements. It is written in C++ because ODS generates strict verifiers
/// that overlap with those in the trait and run earlier.
class TestTransformUnrestrictedOpNoInterface
    : public Op<TestTransformUnrestrictedOpNoInterface,
                transform::PossibleTopLevelTransformOpTrait,
                transform::TransformOpInterface::Trait,
                MemoryEffectOpInterface::Trait> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTransformUnrestrictedOpNoInterface)

  using Op::Op;

  /// The op declares no inherent attribute names.
  static ArrayRef<StringRef> getAttributeNames() { return {}; }

  static constexpr llvm::StringLiteral getOperationName() {
    return llvm::StringLiteral(
        "transform.test_transform_unrestricted_op_no_interface");
  }

  /// Does nothing and reports success.
  DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
    return DiagnosedSilenceableFailure::success();
  }

  // No side effects.
  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
};
} // namespace
/// Maps the result handle to a single operation: the op defining the first
/// operand when this transform op has operands, otherwise the transform op
/// itself.
DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Operation *self = getOperation();
  Operation *mapped =
      self->getNumOperands() == 0 ? self : self->getOperand(0).getDefiningOp();
  results.set(cast<OpResult>(getResult()), {mapped});
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the operand handle, when present, is only read, and that the
/// result handle is produced.
void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto operand = getOperand();
  if (operand)
    transform::onlyReadsHandle(operand, effects);

  transform::producesHandle(getRes(), effects);
}
/// Associates the `out` value handle with this op's own `in` operand.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult outHandle = llvm::cast<OpResult>(getOut());
  results.setValues(outHandle, getIn());
  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of the op: the `in` handle is only read, the `out`
/// handle is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Appends the result of `target` at the index given by `getNumber()` to
/// `results`. Fails silenceably when `target` has no result with that index.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToResult::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  const auto resultNumber = getNumber();
  if (resultNumber < target->getNumResults()) {
    results.push_back(target->getResult(resultNumber));
    return DiagnosedSilenceableFailure::success();
  }
  return emitSilenceableError() << "payload has no result #" << resultNumber;
}
/// Declares the effects of the op: the `in` handle is only read, the `out`
/// handle is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToResult::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Appends the argument of the block containing `target`, at the index given
/// by `getNumber()`, to `results`. Fails silenceably when `target` is not in a
/// block or the block has no argument with that index.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Block *parentBlock = target->getBlock();
  if (!parentBlock)
    return emitSilenceableError() << "payload has no parent block";

  const auto argNumber = getNumber();
  if (argNumber < parentBlock->getNumArguments()) {
    results.push_back(parentBlock->getArgument(argNumber));
    return DiagnosedSilenceableFailure::success();
  }
  return emitSilenceableError()
         << "parent of the payload has no argument #" << argNumber;
}
/// Declares the effects of the op: the `in` handle is only read, the `out`
/// handle is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Returns the value of the op's `getAllowRepeatedHandles()` accessor.
bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
return getAllowRepeatedHandles();
}
/// Performs no action and reports success; the handle consumption of this op
/// is declared in its `getEffects`.
DiagnosedSilenceableFailure
mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Declares that the operand handle is consumed, that the second operand
/// handle is consumed as well when present, and that the payload is modified.
void mlir::test::TestConsumeOperand::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getOperand(), effects);

  auto second = getSecondOperand();
  if (second)
    transform::consumesHandle(second, effects);

  transform::modifiesPayload(effects);
}
/// Compares the name of the single payload op associated with the operand
/// against `getOpKind()`. Emits the remark "succeeded" on a match and a
/// silenceable error otherwise.
DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getOperand());
  assert(llvm::hasSingleElement(payloadOps) && "expected a single target op");

  StringRef actualKind = (*payloadOps.begin())->getName().getStringRef();
  if (actualKind == getOpKind()) {
    emitRemark() << "succeeded";
    return DiagnosedSilenceableFailure::success();
  }

  return emitSilenceableError()
         << "op expected the operand to be associated a payload op of kind "
         << getOpKind() << " got " << actualKind;
}
/// Declares that the operand handle is consumed and the payload is modified.
void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getOperand(), effects);
transform::modifiesPayload(effects);
}
/// Succeeds when the name of `op` equals `getOpKind()`; otherwise produces a
/// silenceable error naming both the expected and the actual kind.
DiagnosedSilenceableFailure
mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef actualKind = op->getName().getStringRef();
  if (actualKind == getOpKind())
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError()
         << "op expected the operand to be associated with a payload op of "
            "kind "
         << getOpKind() << " got " << actualKind;
}
/// Declares that both the operand handle and the payload are only read.
void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// Emits a remark carrying `getMessage()` at every payload op associated with
/// the operand handle.
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  for (Operation *payloadOp : state.getPayloadOps(getOperand())) {
    payloadOp->emitRemark() << getMessage();
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that both the operand handle and the payload are only read.
void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// For every payload value associated with `in`, emits a remark carrying
/// `getMessage()` at the value's location, with a note describing whether the
/// value is a block argument (with argument, block and region numbers) or an
/// op result (with its result number).
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  for (Value payloadValue : state.getPayloadValues(getIn())) {
    std::string description;
    llvm::raw_string_ostream stream(description);

    if (auto blockArg = llvm::dyn_cast<BlockArgument>(payloadValue)) {
      Block *owner = blockArg.getOwner();
      Region *region = owner->getParent();
      // The block number is the position of the owner block in its region.
      stream << "a block argument #" << blockArg.getArgNumber()
             << " in block #"
             << std::distance(region->begin(), owner->getIterator())
             << " in region #" << region->getRegionNumber();
    } else {
      stream << "an op result #"
             << llvm::cast<OpResult>(payloadValue).getResultNumber();
    }

    InFlightDiagnostic remark =
        ::emitRemark(payloadValue.getLoc()) << getMessage();
    remark.attachNote() << "value handle points to " << stream.str();
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that both the `in` handle and the payload are only read.
void mlir::test::TestPrintRemarkAtOperandValue::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::onlyReadsPayload(effects);
}
/// Adds a `TestTransformStateExtension`, constructed from the op's message
/// attribute, to the transform state and reports success.
DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
state.addExtension<TestTransformStateExtension>(getMessageAttr());
return DiagnosedSilenceableFailure::success();
}
/// Reports through a remark whether a `TestTransformStateExtension` is
/// attached to the transform state. When it is, the remark includes the
/// extension message and one note per payload op associated with the operand.
/// In builds with assertions, also checks that each such payload op maps back
/// to the operand handle.
DiagnosedSilenceableFailure
mlir::test::TestCheckIfTestExtensionPresentOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto *stateExtension = state.getExtension<TestTransformStateExtension>();
  if (stateExtension == nullptr) {
    emitRemark() << "extension absent";
    return DiagnosedSilenceableFailure::success();
  }

  InFlightDiagnostic remark =
      emitRemark() << "extension present, " << stateExtension->getMessage();
  for (Operation *payloadOp : state.getPayloadOps(getOperand())) {
    remark.attachNote(payloadOp->getLoc()) << "associated payload op";
#ifndef NDEBUG
    SmallVector<Value> mappedHandles;
    assert(succeeded(state.getHandlesForPayloadOp(payloadOp, mappedHandles)));
    assert(llvm::is_contained(mappedHandles, getOperand()) &&
           "inconsistent mapping between transform IR handles and payload IR "
           "operations");
#endif // NDEBUG
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that both the operand handle and the payload are only read.
void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// Asks the `TestTransformStateExtension` to update its mapping from the first
/// payload op of the operand to this transform op. Fails definitely when the
/// extension is missing or the update fails. When the op has results, the
/// first result handle is associated with this transform op.
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto *stateExtension = state.getExtension<TestTransformStateExtension>();
  if (!stateExtension)
    return emitDefiniteFailure("TestTransformStateExtension missing");

  Operation *firstPayload = *state.getPayloadOps(getOperand()).begin();
  if (failed(stateExtension->updateMapping(firstPayload, getOperation())))
    return DiagnosedSilenceableFailure::definiteFailure();

  if (getNumResults() > 0) {
    results.set(cast<OpResult>(getResult(0)), {getOperation()});
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of the op: the operand handle is only read, the `out`
/// handle is produced, and the payload is only read.
void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Removes the `TestTransformStateExtension` from the transform state and
/// reports success.
DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
state.removeExtension<TestTransformStateExtension>();
return DiagnosedSilenceableFailure::success();
}
/// Associates the result handle with the payload ops of `target`, listed in
/// reverse order.
DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto targets = state.getPayloadOps(getTarget());
  auto backToFront = llvm::to_vector(llvm::reverse(targets));

  OpResult resultHandle = llvm::cast<OpResult>(getResult());
  results.set(resultHandle, backToFront);
  return DiagnosedSilenceableFailure::success();
}
/// Performs no action and reports success.
DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Adds no effects to the list.
void mlir::test::TestTransformOpWithRegions::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
/// Performs no action and reports success.
DiagnosedSilenceableFailure
mlir::test::TestBranchingTransformOpTerminator::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Adds no effects to the list.
void mlir::test::TestBranchingTransformOpTerminator::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
/// Emits `getRemark()` as a remark, then erases every payload op associated
/// with `target`. Afterwards produces a silenceable error when
/// `getFailAfterErase()` is set and succeeds otherwise.
DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  emitRemark() << getRemark();

  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
    payloadOp->erase();
  }

  if (!getFailAfterErase())
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError() << "silenceable error";
}
/// Declares that the `target` handle is consumed and the payload is modified.
void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
transform::modifiesPayload(effects);
}
/// Creates one "foo" op before `target` and appends it to `results` as the
/// only entry.
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  OpBuilder builder(target);
  Operation *created = builder.create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
/// Appends a freshly created "foo" op to `results` only the first time this
/// function is called in the process (tracked by a function-local static
/// counter); later calls append nothing. Always succeeds.
DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  static int count = 0;
  const bool isFirstCall = (count++ == 0);
  if (!isFirstCall)
    return DiagnosedSilenceableFailure::success();

  OperationState fooState(target->getLoc(), "foo");
  results.push_back(OpBuilder(target).create(fooState));
  return DiagnosedSilenceableFailure::success();
}
/// Creates two "foo" ops before `target` and appends both to `results`.
DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  // Each op is created with a fresh builder positioned at `target`.
  for (int created = 0; created < 2; ++created) {
    results.push_back(OpBuilder(target).create(fooState));
  }
  return DiagnosedSilenceableFailure::success();
}
/// Appends a null entry followed by a freshly created "foo" op to `results`.
DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // First entry: null.
  results.push_back(nullptr);

  // Second entry: a new op inserted before `target`.
  OperationState fooState(target->getLoc(), "foo");
  results.push_back(OpBuilder(target).create(fooState));
  return DiagnosedSilenceableFailure::success();
}
/// Succeeds for targets carrying the "target_me" attribute and returns the
/// default silenceable failure for all other targets.
DiagnosedSilenceableFailure
mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasAttr("target_me"))
    return emitDefaultSilenceableFailure(target);
  return DiagnosedSilenceableFailure::success();
}
/// Emits a remark with the number of payload ops associated with the handle.
/// When the handle is null, emits `0` and returns without querying the
/// transform state for it.
DiagnosedSilenceableFailure
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  if (!getHandle()) {
    emitRemark() << 0;
    // Return here: falling through would emit a second remark and look up the
    // payload of a null handle.
    return DiagnosedSilenceableFailure::success();
  }
  emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the handle is only read.
void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getHandle(), effects);
}
/// Associates the `copy` result handle with the same payload ops as the input
/// handle.
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  OpResult copyHandle = llvm::cast<OpResult>(getCopy());
  results.set(copyHandle, state.getPayloadOps(getHandle()));
  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of the op: the input handle is only read, the `copy`
/// handle is produced, and the payload is only read.
void mlir::test::TestCopyPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getHandle(), effects);
transform::producesHandle(getCopy(), effects);
transform::onlyReadsPayload(effects);
}
/// Checks that every payload operation belongs to the 'test' dialect. An empty
/// payload is accepted. Produces a silenceable error at `loc` otherwise.
DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
    Location loc, ArrayRef<Operation *> payload) const {
  auto belongsToTestDialect = [](Operation *op) {
    return op->getName().getDialectNamespace() == "test";
  };
  if (llvm::all_of(payload, belongsToTestDialect))
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError(loc) << "expected the payload operation to "
                                      "belong to the 'test' dialect";
}
/// Checks that every payload attribute is an integer attribute whose type is a
/// signless 32-bit integer. Produces a silenceable error at `loc` for the
/// first attribute that is not.
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
    Location loc, ArrayRef<Attribute> payload) const {
  for (Attribute attr : payload) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
    const bool isI32 = intAttr && intAttr.getType().isSignlessInteger(32);
    if (!isI32) {
      return emitSilenceableError(loc)
             << "expected the parameter to be a i32 integer attribute";
    }
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the `target` handle is only read.
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
}
/// Walks every payload op associated with `target`, sums the number of handles
/// the transform state reports for each visited op, and emits the total as a
/// remark.
DiagnosedSilenceableFailure
mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  int64_t numHandles = 0;
  auto countHandles = [&](Operation *visited) {
    SmallVector<Value> handlesOfVisited;
    // The result is deliberately ignored; only the collected handles count.
    (void)state.getHandlesForPayloadOp(visited, handlesOfVisited);
    numHandles += handlesOfVisited.size();
  };

  for (Operation *root : state.getPayloadOps(getTarget()))
    root->walk(countHandles);

  emitRemark() << numHandles << " handles nested under";
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the parameter handle is only read, that the anchor handle is
/// only read when present, and that the payload is only read.
void mlir::test::TestPrintParamOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getParam(), effects);

  auto anchor = getAnchor();
  if (anchor)
    transform::onlyReadsHandle(anchor, effects);

  transform::onlyReadsPayload(effects);
}
/// Builds a string from the optional message followed by the comma-separated
/// parameters and emits it as a remark: at every payload op of the anchor
/// handle when an anchor is given, otherwise at this transform op.
DiagnosedSilenceableFailure
mlir::test::TestPrintParamOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  std::string text;
  llvm::raw_string_ostream stream(text);
  if (auto message = getMessage())
    stream << *message << " ";
  llvm::interleaveComma(state.getParams(getParam()), stream);

  if (getAnchor()) {
    for (Operation *anchorOp : state.getPayloadOps(getAnchor()))
      ::mlir::emitRemark(anchorOp->getLoc()) << stream.str();
  } else {
    emitRemark() << stream.str();
  }
  return DiagnosedSilenceableFailure::success();
}
/// Adds `getAddendum()` to every input parameter value and stores the sums as
/// i32 integer attributes in the result. Without a parameter operand, a single
/// input value of 0 is used. Input values are clamped to UINT32_MAX when read.
DiagnosedSilenceableFailure
mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  SmallVector<uint32_t> inputs;
  if (Value param = getParam()) {
    for (Attribute attr : state.getParams(param)) {
      inputs.push_back(static_cast<uint32_t>(
          llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
              UINT32_MAX)));
    }
  } else {
    inputs.push_back(0);
  }

  Builder builder(getContext());
  SmallVector<Attribute> sums;
  sums.reserve(inputs.size());
  for (uint32_t value : inputs)
    sums.push_back(builder.getI32IntegerAttr(value + getAddendum()));

  results.setParams(llvm::cast<OpResult>(getResult()), sums);
  return DiagnosedSilenceableFailure::success();
}
/// Produces one i32 parameter per payload op of the handle. Each parameter is
/// the number of ops from the "test" dialect visited when walking that
/// payload op.
DiagnosedSilenceableFailure
mlir::test::TestProduceParamWithNumberOfTestOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Builder builder(getContext());
  SmallVector<Attribute> counts;
  for (Operation *payloadOp : state.getPayloadOps(getHandle())) {
    int32_t numTestOps = 0;
    payloadOp->walk([&numTestOps](Operation *visited) {
      if (visited->getName().getDialectNamespace() == "test")
        ++numTestOps;
    });
    counts.push_back(builder.getI32IntegerAttr(numTestOps));
  }
  results.setParams(llvm::cast<OpResult>(getResult()), counts);
  return DiagnosedSilenceableFailure::success();
}
/// Sets the result parameter to an integer attribute of value 0 whose type is
/// `getType()`.
DiagnosedSilenceableFailure
mlir::test::TestProduceIntegerParamWithTypeOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult paramResult = llvm::cast<OpResult>(getResult());
  Attribute zero = IntegerAttr::get(getType(), 0);
  results.setParams(paramResult, zero);
  return DiagnosedSilenceableFailure::success();
}
/// Verifies that `getType()` is an integer type.
LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
  if (llvm::isa<IntegerType>(getType()))
    return success();
  return emitOpError() << "expects an integer type";
}
/// Declares that the `in` handle is only read and that both the `out` and
/// `param` results are produced.
void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::producesHandle(getParam(), effects);
}
/// Appends two entries to `results`, chosen by the op's flags.
/// First entry: an i64 attribute 0 when `getFirstResultIsParam()` is set, a
/// null entry when `getFirstResultIsNull()` is set, otherwise the first
/// payload op of `in`.
/// Second entry: the first payload op of `in` when `getSecondResultIsHandle()`
/// is set, otherwise an i64 attribute 42.
DiagnosedSilenceableFailure
mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ::transform::ApplyToEachResultList &results,
    ::transform::TransformState &state) {
  Builder builder(getContext());
  // Evaluated lazily so the state is only queried on the branches that need
  // the payload op.
  auto firstPayloadOfIn = [&]() -> Operation * {
    return *state.getPayloadOps(getIn()).begin();
  };

  if (getFirstResultIsParam())
    results.push_back(builder.getI64IntegerAttr(0));
  else if (getFirstResultIsNull())
    results.push_back(nullptr);
  else
    results.push_back(firstPayloadOfIn());

  if (getSecondResultIsHandle())
    results.push_back(firstPayloadOfIn());
  else
    results.push_back(builder.getI64IntegerAttr(42));

  return DiagnosedSilenceableFailure::success();
}
/// Declares that the `out` handle is produced.
void mlir::test::TestProduceNullPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Associates the `out` handle with a one-element payload list whose only
/// entry is a null operation pointer, then reports success.
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
SmallVector<Operation *, 1> null({nullptr});
results.set(llvm::cast<OpResult>(getOut()), null);
return DiagnosedSilenceableFailure::success();
}
/// Associates the `out` handle with an empty payload list and reports success.
DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.set(cast<OpResult>(getOut()), {});
return DiagnosedSilenceableFailure::success();
}
/// Declares that the `out` result is produced.
void mlir::test::TestProduceNullParamOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Sets the `out` parameter to a default-constructed (null) attribute and
/// reports success.
DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
return DiagnosedSilenceableFailure::success();
}
/// Declares that the `out` handle is produced.
void mlir::test::TestProduceNullValueOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Associates the `out` value handle with a default-constructed (null) value
/// and reports success.
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), Value());
return DiagnosedSilenceableFailure::success();
}
/// Declares effects according to the op's flags: `in` is consumed when
/// `getHasOperandEffect()` is set; `out` is produced when
/// `getHasResultEffect()` is set and only read otherwise; the payload is
/// modified when `getModifiesPayload()` is set.
void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getHasOperandEffect())
    transform::consumesHandle(getIn(), effects);

  if (!getHasResultEffect())
    transform::onlyReadsHandle(getOut(), effects);
  else
    transform::producesHandle(getOut(), effects);

  if (getModifiesPayload())
    transform::modifiesPayload(effects);
}
/// Associates the `out` handle with the payload ops of `in` and reports
/// success.
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
return DiagnosedSilenceableFailure::success();
}
/// Declares that the `in` handle is only read and the payload is modified.
void mlir::test::TestTrackedRewriteOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::modifiesPayload(effects);
}
/// Declares every result of the op as a produced handle.
void mlir::test::TestDummyPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (OpResult result : getResults())
transform::producesHandle(result, effects);
}
/// Erases payload ops marked "erase_me" and replaces payload ops marked
/// "replace_me" with a new op of the same name and result types, no operands
/// and a "new_op" unit attribute, all through `rewriter`. The work is done in
/// a loop nested inside another loop over the same payload; the number of
/// outer iterations is reported in a remark.
DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  int64_t outerIterations = 0;
  // `getPayloadOps` returns an iterator that skips ops that are erased in the
  // loop body. Replacement ops are not enumerated.
  for (Operation *visited : state.getPayloadOps(getIn())) {
    (void)visited;
    ++outerIterations;

    // Process all payload ops here. The outer loop should have only one
    // iteration as a consequence.
    for (Operation *payloadOp : state.getPayloadOps(getIn())) {
      rewriter.setInsertionPoint(payloadOp);
      if (payloadOp->hasAttr("erase_me")) {
        rewriter.eraseOp(payloadOp);
      } else if (payloadOp->hasAttr("replace_me")) {
        SmallVector<NamedAttribute> newAttrs;
        newAttrs.emplace_back(rewriter.getStringAttr("new_op"),
                              rewriter.getUnitAttr());
        OperationState newState(payloadOp->getLoc(),
                                payloadOp->getName().getIdentifier(),
                                /*operands=*/ValueRange(),
                                /*types=*/payloadOp->getResultTypes(),
                                newAttrs);
        Operation *replacement = rewriter.create(newState);
        rewriter.replaceOp(payloadOp, replacement->getResults());
      }
    }
  }
  emitRemark() << outerIterations << " iterations";
  return DiagnosedSilenceableFailure::success();
}
namespace {
// Test pattern that matches any op carrying a "replace_with_new_op" string
// attribute and replaces it with a new op named by that attribute, reusing
// the operands and result types of the original.
class ReplaceWithNewOp : public RewritePattern {
public:
  ReplaceWithNewOp(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto requestedName = op->getAttrOfType<StringAttr>("replace_with_new_op");
    // Ops without the attribute are not handled by this pattern.
    if (!requestedName)
      return failure();

    StringAttr newOpName =
        OperationName(requestedName, op->getContext()).getIdentifier();
    Operation *replacement = rewriter.create(
        op->getLoc(), newOpName, op->getOperands(), op->getResultTypes());
    rewriter.replaceOp(op, replacement->getResults());
    return success();
  }
};
// Test pattern rooted at "test.erase_op" that erases the matched op.
class EraseOp : public RewritePattern {
public:
  EraseOp(MLIRContext *context)
      : RewritePattern("test.erase_op", /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Unconditionally erase the op through the rewriter.
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
/// Adds the `ReplaceWithNewOp` and `EraseOp` test patterns to `patterns`.
void mlir::test::ApplyTestPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ReplaceWithNewOp, EraseOp>(context);
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
/// types for operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the PDL dependency, registers the test ops and types, and
  /// installs a "verbose_constraint" PDL constraint that emits a warning on
  /// every operation it receives and always succeeds.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    registerTransformOps<TestTransformOp,
                         TestTransformUnrestrictedOpNoInterface,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
    registerTypes<
#define GET_TYPEDEF_LIST
#include "TestTransformDialectExtensionTypes.cpp.inc"
        >();

    auto verboseConstraint = [](PatternRewriter &rewriter,
                                ArrayRef<PDLValue> pdlValues) {
      for (const PDLValue &pdlValue : pdlValues) {
        if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
          op->emitWarning() << "from PDL constraint";
        }
      }
      return success();
    };

    // Capture the constraint by value: `verboseConstraint` is a local of
    // `init`, so a by-reference capture would dangle if the initializer is
    // invoked after `init` has returned.
    addDialectDataInitializer<transform::PDLMatchHooks>(
        [verboseConstraint](transform::PDLMatchHooks &hooks) {
          llvm::StringMap<PDLConstraintFunction> constraints;
          constraints.try_emplace("verbose_constraint", verboseConstraint);
          hooks.mergeInPDLMatchHooks(std::move(constraints));
        });
  }
};
} // namespace
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are declared here with LLVM_ATTRIBUTE_UNUSED ahead of the generated
// code included below (presumably where they are defined - TODO confirm).
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
/// Adds `TestTransformDialectExtension` to the extensions of `registry`.
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<TestTransformDialectExtension>();
}