This patch adds more precise side effects to the existing ops that have memory effects. Each effect now identifies the specific OpOperand, OpResult, or BlockArgument that the operation reads or writes, rather than merely recording that a value is read or written. Analyses and optimizations can then use these precise side effects directly. Related discussion: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
927 lines
35 KiB
C++
927 lines
35 KiB
C++
//===- TestTransformDialectExtension.cpp ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines an extension of the MLIR Transform dialect for testing
|
|
// purposes.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestTransformDialectExtension.h"
|
|
#include "TestTransformStateExtension.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Compiler.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple transform op defined outside of the dialect. Just emits a remark when
|
|
/// applied. This op is defined in C++ to test that C++ definitions also work
|
|
/// for op injection into the Transform dialect.
|
|
class TestTransformOp
|
|
: public Op<TestTransformOp, transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral("transform.test_transform_op");
|
|
}
|
|
|
|
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
InFlightDiagnostic remark = emitRemark() << "applying transformation";
|
|
if (Attribute message = getMessage())
|
|
remark << " " << message;
|
|
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
Attribute getMessage() {
|
|
return getOperation()->getDiscardableAttr("message");
|
|
}
|
|
|
|
static ParseResult parse(OpAsmParser &parser, OperationState &state) {
|
|
StringAttr message;
|
|
OptionalParseResult result = parser.parseOptionalAttribute(message);
|
|
if (!result.has_value())
|
|
return success();
|
|
|
|
if (result.value().succeeded())
|
|
state.addAttribute("message", message);
|
|
return result.value();
|
|
}
|
|
|
|
void print(OpAsmPrinter &printer) {
|
|
if (getMessage())
|
|
printer << " " << getMessage();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
|
|
/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
|
|
/// in cases where it is attached to ops that do not comply with the trait
|
|
/// requirements. This op cannot be defined in ODS because ODS generates strict
|
|
/// verifiers that overalp with those in the trait and run earlier.
|
|
class TestTransformUnrestrictedOpNoInterface
|
|
: public Op<TestTransformUnrestrictedOpNoInterface,
|
|
transform::PossibleTopLevelTransformOpTrait,
|
|
transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestTransformUnrestrictedOpNoInterface)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral(
|
|
"transform.test_transform_unrestricted_op_no_interface");
|
|
}
|
|
|
|
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
} // namespace
|
|
|
|
/// Maps the result to the op defining operand #0 when an operand is present,
/// and to this op itself otherwise.
DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Operation *self = getOperation();
  Operation *payload = self->getNumOperands() == 0
                           ? self
                           : self->getOperand(0).getDefiningOp();
  results.set(cast<OpResult>(getResult()), {payload});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads the optional operand handle (when given) and produces all results.
void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  bool hasOperand = static_cast<bool>(getOperand());
  if (hasOperand)
    transform::onlyReadsHandle(getOperandMutable(), effects);
  Operation *self = getOperation();
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Maps the value-handle result to this op's own `in` operand value.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto out = llvm::cast<OpResult>(getOut());
  results.setValues(out, {getIn()});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `in`, produces every result handle, and only reads the payload.
void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(getInMutable(), effects);
  transform::producesHandle(self->getOpResults(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Produces a handle to result #`number` of each target payload op, or a
/// silenceable error when the target has no result at that index.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToResult::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool inRange = getNumber() < target->getNumResults();
  if (!inRange)
    return emitSilenceableError() << "payload has no result #" << getNumber();
  results.push_back(target->getResult(getNumber()));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `in`, produces every result handle, and only reads the payload.
void mlir::test::TestProduceValueHandleToResult::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(getInMutable(), effects);
  transform::producesHandle(self->getOpResults(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Produces a handle to argument #`number` of the block that contains each
/// target. Fails silenceably when the target is not in a block, or when that
/// block has too few arguments.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Block *parent = target->getBlock();
  if (parent == nullptr)
    return emitSilenceableError() << "payload has no parent block";
  if (!(getNumber() < parent->getNumArguments()))
    return emitSilenceableError()
           << "parent of the payload has no argument #" << getNumber();
  results.push_back(parent->getArgument(getNumber()));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `in`, produces every result handle, and only reads the payload.
void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(getInMutable(), effects);
  transform::producesHandle(self->getOpResults(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
|
|
return getAllowRepeatedHandles();
|
|
}
|
|
|
|
/// Intentionally a no-op that always succeeds; the op exists for the
/// effects it declares.
DiagnosedSilenceableFailure
mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,
                                      transform::TransformState &state) {
  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
  return status;
}
|
|
|
|
/// Declares that every operand handle is consumed and that the payload is
/// modified.
///
/// `getOpOperands()` already includes the optional second operand when it is
/// present. Consuming it again separately would record a duplicate set of
/// consume effects for the same OpOperand.
void mlir::test::TestConsumeOperand::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getOperation()->getOpOperands(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
/// Succeeds (emitting a "succeeded" remark) when the single payload op of the
/// operand handle has the expected `op_kind`. Otherwise reports a silenceable
/// error naming both kinds.
DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payload = state.getPayloadOps(getOperand());
  assert(llvm::hasSingleElement(payload) && "expected a single target op");
  Operation *target = *payload.begin();
  StringRef actualKind = target->getName().getStringRef();

  if (actualKind == getOpKind()) {
    emitRemark() << "succeeded";
    return DiagnosedSilenceableFailure::success();
  }
  return emitSilenceableError()
         << "op expected the operand to be associated a payload op of kind "
         << getOpKind() << " got " << actualKind;
}
|
|
|
|
/// Consumes all operand handles and modifies the payload.
void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::consumesHandle(self->getOpOperands(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
/// Matches payload ops whose name equals `op_kind`; any other op yields a
/// silenceable error naming both kinds.
DiagnosedSilenceableFailure
mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef actualKind = op->getName().getStringRef();
  if (actualKind == getOpKind())
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError()
         << "op expected the operand to be associated with a payload op of "
            "kind "
         << getOpKind() << " got " << actualKind;
}
|
|
|
|
/// Only reads operand handles and the payload.
void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(self->getOpOperands(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Registers a TestTransformStateExtension, built from this op's message
/// attribute, on the transform state.
DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  state.addExtension<TestTransformStateExtension>(getMessageAttr());
  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
  return status;
}
|
|
|
|
/// Emits a remark saying whether the test state extension is registered.
/// When it is, the remark includes the extension's message and carries one
/// note per payload op of the operand handle.
DiagnosedSilenceableFailure
mlir::test::TestCheckIfTestExtensionPresentOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  if (auto *extension = state.getExtension<TestTransformStateExtension>()) {
    InFlightDiagnostic diag = emitRemark()
                              << "extension present, " << extension->getMessage();
    for (Operation *payload : state.getPayloadOps(getOperand())) {
      diag.attachNote(payload->getLoc()) << "associated payload op";
#ifndef NDEBUG
      // Debug-only: the reverse (payload -> handles) lookup must contain the
      // operand handle.
      SmallVector<Value> handles;
      assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
      assert(llvm::is_contained(handles, getOperand()) &&
             "inconsistent mapping between transform IR handles and payload IR "
             "operations");
#endif // NDEBUG
    }
    return DiagnosedSilenceableFailure::success();
  }

  emitRemark() << "extension absent";
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Only reads operand handles and the payload.
void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(self->getOpOperands(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Uses the test state extension to remap the first payload op of the operand
/// handle to this op. When the op has a result, result #0 is mapped to this
/// op as well.
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto *extension = state.getExtension<TestTransformStateExtension>();
  if (!extension)
    return emitDefiniteFailure("TestTransformStateExtension missing");

  Operation *previous = *state.getPayloadOps(getOperand()).begin();
  if (failed(extension->updateMapping(previous, getOperation())))
    return DiagnosedSilenceableFailure::definiteFailure();

  if (getNumResults() != 0)
    results.set(cast<OpResult>(getResult(0)), {getOperation()});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads operand handles, produces result handles, and only reads the payload.
void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(self->getOpOperands(), effects);
  transform::producesHandle(self->getOpResults(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Unregisters the TestTransformStateExtension from the transform state.
DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  state.removeExtension<TestTransformStateExtension>();
  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
  return status;
}
|
|
|
|
/// Maps the result to the target's payload ops in reverse order.
DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  auto backwards = llvm::to_vector(llvm::reverse(targetOps));
  OpResult out = llvm::cast<OpResult>(getResult());
  results.set(out, backwards);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Intentionally a no-op; the op exists to carry regions in tests.
DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
  return status;
}
|
|
|
|
void mlir::test::TestTransformOpWithRegions::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Intentionally declares no effects.
}
|
|
|
|
/// Intentionally a no-op terminator that always succeeds.
DiagnosedSilenceableFailure
mlir::test::TestBranchingTransformOpTerminator::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
  return status;
}
|
|
|
|
void mlir::test::TestBranchingTransformOpTerminator::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Intentionally declares no effects.
}
|
|
|
|
/// Emits the configured remark, erases every payload op of `target`, and then
/// optionally reports a silenceable error (after the erasure has happened).
DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  emitRemark() << getRemark();
  for (Operation *victim : state.getPayloadOps(getTarget()))
    rewriter.eraseOp(victim);

  if (!getFailAfterErase())
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError() << "silenceable error";
}
|
|
|
|
/// Consumes the `target` handle and modifies the payload.
void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto &&targetOperand = getTargetMutable();
  transform::consumesHandle(targetOperand, effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
/// Builds one generic "foo" op with an OpBuilder positioned at the target and
/// pushes it as the only result for that target.
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OpBuilder builder(target);
  OperationState fooState(target->getLoc(), "foo");
  results.push_back(builder.create(fooState));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Only the first invocation ever (the counter is a function-local static)
/// builds a "foo" op and pushes it as a result. Later calls push nothing,
/// presumably to trigger a result-count mismatch.
DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  static int count = 0;
  bool isFirstCall = count++ == 0;
  if (!isFirstCall)
    return DiagnosedSilenceableFailure::success();

  OperationState fooState(target->getLoc(), "foo");
  results.push_back(OpBuilder(target).create(fooState));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Builds two generic "foo" ops at the target and pushes both as results.
DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  for (int i = 0; i < 2; ++i)
    results.push_back(OpBuilder(target).create(fooState));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Pushes a null first result and a freshly built "foo" op as the second.
DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  results.push_back(nullptr);
  Operation *created = OpBuilder(target).create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Succeeds on targets carrying a "target_me" attribute and fails with the
/// default silenceable failure on all others.
DiagnosedSilenceableFailure
mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasAttr("target_me"))
    return emitDefaultSilenceableFailure(target);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Maps the `copy` result to the same payload ops as the `handle` operand.
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  OpResult copy = llvm::cast<OpResult>(getCopy());
  results.set(copy, state.getPayloadOps(getHandle()));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `handle`, produces the result handles, and only reads the payload.
void mlir::test::TestCopyPayloadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(getHandleMutable(), effects);
  transform::producesHandle(self->getOpResults(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Accepts the payload only if every op belongs to the "test" dialect. An
/// empty payload is trivially accepted.
DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
    Location loc, ArrayRef<Operation *> payload) const {
  for (Operation *op : payload) {
    if (op->getName().getDialectNamespace() == "test")
      continue;
    return emitSilenceableError(loc) << "expected the payload operation to "
                                        "belong to the 'test' dialect";
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Accepts the payload only if every parameter is a signless i32
/// IntegerAttr.
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
    Location loc, ArrayRef<Attribute> payload) const {
  for (Attribute attr : payload) {
    auto asInteger = llvm::dyn_cast<IntegerAttr>(attr);
    bool isI32 = asInteger && asInteger.getType().isSignlessInteger(32);
    if (!isI32)
      return emitSilenceableError(loc)
             << "expected the parameter to be a i32 integer attribute";
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads the `target` handle. The op's `apply` walks the IR nested under each
/// payload op, so a read of the payload is also declared, consistent with the
/// other read-only test ops in this file.
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Counts the transform handles associated with every op nested under (and
/// including) each target payload op, then reports the total in a remark.
DiagnosedSilenceableFailure
mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  int64_t total = 0;
  for (Operation *root : state.getPayloadOps(getTarget())) {
    root->walk([&](Operation *nested) {
      // The lookup's status is deliberately ignored; only the size of the
      // handle list it fills is counted.
      SmallVector<Value> tracked;
      (void)state.getHandlesForPayloadOp(nested, tracked);
      total += tracked.size();
    });
  }
  emitRemark() << total << " handles nested under";
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Produces params equal to each input param plus `addendum`. The input
/// values are clamped to UINT32_MAX; when no `param` operand is given, a
/// single input value of 0 is used.
DiagnosedSilenceableFailure
mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  SmallVector<uint32_t> inputs;
  if (Value param = getParam()) {
    for (Attribute attr : state.getParams(param)) {
      uint32_t clamped =
          llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(UINT32_MAX);
      inputs.push_back(clamped);
    }
  } else {
    inputs.push_back(0);
  }

  Builder builder(getContext());
  SmallVector<Attribute> sums;
  for (uint32_t value : inputs)
    sums.push_back(builder.getI32IntegerAttr(value + getAddendum()));
  results.setParams(llvm::cast<OpResult>(getResult()), sums);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// For each payload op of `handle`, produces an i32 param holding the number
/// of "test"-dialect ops that its walk visits (the op itself included).
DiagnosedSilenceableFailure
mlir::test::TestProduceParamWithNumberOfTestOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Builder builder(getContext());
  SmallVector<Attribute> counts;
  for (Operation *payload : state.getPayloadOps(getHandle())) {
    int32_t numTestOps = 0;
    payload->walk([&numTestOps](Operation *op) {
      if (op->getName().getDialectNamespace() == "test")
        ++numTestOps;
    });
    counts.push_back(builder.getI32IntegerAttr(numTestOps));
  }
  results.setParams(llvm::cast<OpResult>(getResult()), counts);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Maps the result to a single param: this op's `attr` attribute.
DiagnosedSilenceableFailure
mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,
                                      transform::TransformState &state) {
  OpResult out = llvm::cast<OpResult>(getResult());
  results.setParams(out, getAttr());
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `in` and produces all result handles.
void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::onlyReadsHandle(getInMutable(), effects);
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Pushes two results per target, each chosen by the op's flags.
///  - First: an i64 param 0, a null entry, or the first payload op of `in`.
///  - Second: the first payload op of `in`, or an i64 param 42.
DiagnosedSilenceableFailure
mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Builder builder(getContext());

  if (getFirstResultIsParam())
    results.push_back(builder.getI64IntegerAttr(0));
  else if (getFirstResultIsNull())
    results.push_back(nullptr);
  else
    results.push_back(*state.getPayloadOps(getIn()).begin());

  if (!getSecondResultIsHandle())
    results.push_back(builder.getI64IntegerAttr(42));
  else
    results.push_back(*state.getPayloadOps(getIn()).begin());

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Produces all result handles.
void mlir::test::TestProduceNullPayloadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Maps `out` to a list containing a single null payload op.
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  SmallVector<Operation *, 1> nullOps(1, nullptr);
  OpResult out = llvm::cast<OpResult>(getOut());
  results.set(out, nullOps);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Maps `out` to an empty list of payload ops.
DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult out = cast<OpResult>(getOut());
  results.set(out, {});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Produces all result handles.
void mlir::test::TestProduceNullParamOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Maps `out` to a single null parameter.
DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Attribute nullParam;
  results.setParams(llvm::cast<OpResult>(getOut()), nullParam);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Produces all result handles.
void mlir::test::TestProduceNullValueOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Maps `out` to a single null payload value.
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Value nullValue;
  results.setValues(llvm::cast<OpResult>(getOut()), {nullValue});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares a configurable set of effects; each one is toggled by an op flag
/// (presumably so tests can omit individual required effects):
///  - `has_operand_effect`: consumes `in`.
///  - `has_result_effect`: produces the results. When unset, only a Read of
///    the mapping resource is recorded on `out` instead.
///  - `modifies_payload`: modifies the payload.
void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getHasOperandEffect())
    transform::consumesHandle(getInMutable(), effects);

  if (!getHasResultEffect()) {
    effects.emplace_back(MemoryEffects::Read::get(),
                         llvm::cast<OpResult>(getOut()),
                         transform::TransformMappingResource::get());
  } else {
    transform::producesHandle(getOperation()->getOpResults(), effects);
  }

  if (getModifiesPayload())
    transform::modifiesPayload(effects);
}
|
|
|
|
/// Forwards the payload ops of `in` to `out`.
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult out = llvm::cast<OpResult>(getOut());
  results.set(out, state.getPayloadOps(getIn()));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads `in` and modifies the payload.
void mlir::test::TestTrackedRewriteOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto &&inOperand = getInMutable();
  transform::onlyReadsHandle(inOperand, effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
/// Produces all result handles.
void mlir::test::TestDummyPayloadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::producesHandle(self->getOpResults(), effects);
}
|
|
|
|
/// Fails verification on demand, when `fail_to_verify` is set.
LogicalResult mlir::test::TestDummyPayloadOp::verify() {
  if (!getFailToVerify())
    return success();
  return emitOpError() << "fail_to_verify is set";
}
|
|
|
|
/// Exercises handle tracking under rewrites. For each payload op of `in`, it
/// walks the handle again and:
///  - erases ops tagged "erase_me";
///  - replaces ops tagged "replace_me" with a new op that has the same name and
///    result types, no operands, and a "new_op" unit attribute.
/// It then reports how many iterations the outer loop performed.
DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  int64_t numIterations = 0;

  // `getPayloadOps` returns an iterator that skips ops erased in the loop
  // body; replacement ops are not enumerated.
  for (Operation *outer : state.getPayloadOps(getIn())) {
    ++numIterations;
    (void)outer;

    for (Operation *payload : state.getPayloadOps(getIn())) {
      rewriter.setInsertionPoint(payload);
      if (payload->hasAttr("erase_me")) {
        rewriter.eraseOp(payload);
        continue;
      }
      if (!payload->hasAttr("replace_me"))
        continue;

      SmallVector<NamedAttribute> newAttrs;
      newAttrs.emplace_back(rewriter.getStringAttr("new_op"),
                            rewriter.getUnitAttr());
      OperationState newState(payload->getLoc(),
                              payload->getName().getIdentifier(),
                              /*operands=*/ValueRange(),
                              /*types=*/payload->getResultTypes(), newAttrs);
      Operation *replacement = rewriter.create(newState);
      rewriter.replaceOp(payload, replacement->getResults());
    }
  }

  emitRemark() << numIterations << " iterations";
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
namespace {
|
|
// Test pattern to replace an operation with a new op.
|
|
class ReplaceWithNewOp : public RewritePattern {
|
|
public:
|
|
ReplaceWithNewOp(MLIRContext *context)
|
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
|
|
if (!newName)
|
|
return failure();
|
|
Operation *newOp = rewriter.create(
|
|
op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
|
|
op->getOperands(), op->getResultTypes());
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Test pattern to erase an operation.
|
|
class EraseOp : public RewritePattern {
|
|
public:
|
|
EraseOp(MLIRContext *context)
|
|
: RewritePattern("test.erase_op", /*benefit=*/1, context) {}
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the ReplaceWithNewOp and EraseOp test patterns, in that order.
void mlir::test::ApplyTestPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<ReplaceWithNewOp>(ctx);
  patterns.insert<EraseOp>(ctx);
}
|
|
|
|
/// Consumes all operand handles and modifies the payload.
void mlir::test::TestReEnterRegionOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  transform::consumesHandle(self->getOpOperands(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
/// Interprets the body region four times. Each time it opens a fresh region
/// scope and maps block argument #i to the payload ops of operand #i, which
/// are captured once before the first iteration.
DiagnosedSilenceableFailure
mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  Block &body = getBody().front();

  SmallVector<SmallVector<transform::MappedValue>> mappings;
  for (BlockArgument arg : body.getArguments()) {
    mappings.emplace_back();
    for (Operation *op : state.getPayloadOps(getOperand(arg.getArgNumber())))
      mappings.back().push_back(op);
  }

  constexpr int kNumReentries = 4;
  for (int iteration = 0; iteration < kNumReentries; ++iteration) {
    auto scope = state.make_region_scope(getBody());
    for (BlockArgument arg : body.getArguments()) {
      if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
        return DiagnosedSilenceableFailure::definiteFailure();
    }
    for (Operation &nested : body.without_terminator()) {
      DiagnosedSilenceableFailure diag =
          state.applyTransform(cast<transform::TransformOpInterface>(nested));
      if (!diag.succeeded())
        return diag;
    }
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Requires the operand count to match the body's entry-block argument count.
LogicalResult mlir::test::TestReEnterRegionOp::verify() {
  unsigned numArgs = getBody().front().getNumArguments();
  if (getNumOperands() == numArgs)
    return success();
  return emitOpError() << "expects as many operands as block arguments";
}
|
|
|
|
/// Pairs the payload ops of `original` and `replacement` positionally and
/// notifies the rewriter of each replacement. Stops with a silenceable error
/// at the first pair the rewriter rejects, or when the counts differ.
DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto originalOps = state.getPayloadOps(getOriginal());
  auto replacementOps = state.getPayloadOps(getReplacement());
  if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
    return emitSilenceableError() << "expected same number of original and "
                                     "replacement payload operations";

  for (const auto &[original, replacement] :
       llvm::zip(originalOps, replacementOps)) {
    if (succeeded(
            rewriter.notifyPayloadOperationReplaced(original, replacement)))
      continue;
    auto diag = emitSilenceableError()
                << "unable to replace payload op in transform mapping";
    diag.attachNote(original->getLoc()) << "original payload op";
    diag.attachNote(replacement->getLoc()) << "replacement payload op";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Only reads the `original` and `replacement` handles.
void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto &&originalOperand = getOriginalMutable();
  auto &&replacementOperand = getReplacementMutable();
  transform::onlyReadsHandle(originalOperand, effects);
  transform::onlyReadsHandle(replacementOperand, effects);
}
|
|
|
|
/// Inserts, at the start of the first block of the target's first region, a
/// TestDummyPayloadOp created with `failToVerify` set, so that the IR no
/// longer verifies.
DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Block &entry = target->getRegion(0).front();
  rewriter.setInsertionPointToStart(&entry);
  rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
                                      ValueRange(), /*failToVerify=*/true);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
void mlir::test::TestProduceInvalidIR::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getTargetMutable(), effects);
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
namespace {
|
|
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
|
|
/// attribute with "test.new_op".
|
|
class ReplaceWithNewOpConversion : public ConversionPattern {
|
|
public:
|
|
ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
|
|
: ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
|
|
/*benefit=*/1, context) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!op->hasAttr("replace_with_new_op"))
|
|
return failure();
|
|
SmallVector<Type> newResultTypes;
|
|
if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
|
|
newResultTypes)))
|
|
return failure();
|
|
Operation *newOp = rewriter.create(
|
|
op->getLoc(),
|
|
OperationName("test.new_op", op->getContext()).getIdentifier(),
|
|
operands, newResultTypes);
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns) {
|
|
patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
|
|
patterns.getContext());
|
|
}
|
|
|
|
namespace {
|
|
/// Test type converter that converts tensor types to memref types.
|
|
class TestTypeConverter : public TypeConverter {
|
|
public:
|
|
TestTypeConverter() {
|
|
addConversion([](Type t) { return t; });
|
|
addConversion([](RankedTensorType type) -> Type {
|
|
return MemRefType::get(type.getShape(), type.getElementType());
|
|
});
|
|
auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
|
|
ValueRange inputs,
|
|
Location loc) -> std::optional<Value> {
|
|
if (inputs.size() != 1)
|
|
return std::nullopt;
|
|
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
|
|
.getResult(0);
|
|
};
|
|
addSourceMaterialization(unrealizedCastConverter);
|
|
addTargetMaterialization(unrealizedCastConverter);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<::mlir::TypeConverter>
|
|
mlir::test::TestTypeConverterOp::getTypeConverter() {
|
|
return std::make_unique<TestTypeConverter>();
|
|
}
|
|
|
|
namespace {
|
|
/// Test extension of the Transform dialect. Registers additional ops and
|
|
/// declares PDL as dependent dialect since the additional ops are using PDL
|
|
/// types for operands and results.
|
|
class TestTransformDialectExtension
|
|
: public transform::TransformDialectExtension<
|
|
TestTransformDialectExtension> {
|
|
public:
|
|
using Base::Base;
|
|
|
|
void init() {
|
|
declareDependentDialect<pdl::PDLDialect>();
|
|
registerTransformOps<TestTransformOp,
|
|
TestTransformUnrestrictedOpNoInterface,
|
|
#define GET_OP_LIST
|
|
#include "TestTransformDialectExtension.cpp.inc"
|
|
>();
|
|
registerTypes<
|
|
#define GET_TYPEDEF_LIST
|
|
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
|
>();
|
|
|
|
auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
|
|
ArrayRef<PDLValue> pdlValues) {
|
|
for (const PDLValue &pdlValue : pdlValues) {
|
|
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
|
|
op->emitWarning() << "from PDL constraint";
|
|
}
|
|
}
|
|
return success();
|
|
};
|
|
|
|
addDialectDataInitializer<transform::PDLMatchHooks>(
|
|
[&](transform::PDLMatchHooks &hooks) {
|
|
llvm::StringMap<PDLConstraintFunction> constraints;
|
|
constraints.try_emplace("verbose_constraint", verboseConstraint);
|
|
hooks.mergeInPDLMatchHooks(std::move(constraints));
|
|
});
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// These are automatically generated by ODS but are not used as the Transform
|
|
// dialect uses a different dispatch mechanism to support dialect extensions.
|
|
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
|
|
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
|
|
LLVM_ATTRIBUTE_UNUSED static LogicalResult
|
|
generatedTypePrinter(Type def, AsmPrinter &printer);
|
|
|
|
#define GET_TYPEDEF_CLASSES
|
|
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformDialectExtension.cpp.inc"
|
|
|
|
void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformDialectExtension>();
|
|
}
|