//===- TestTransformDialectExtension.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines an extension of the MLIR Transform dialect for testing // purposes. // //===----------------------------------------------------------------------===// #include "TestTransformDialectExtension.h" #include "TestTransformStateExtension.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; namespace { /// Simple transform op defined outside of the dialect. Just emits a remark when /// applied. This op is defined in C++ to test that C++ definitions also work /// for op injection into the Transform dialect. class TestTransformOp : public Op { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) using Op::Op; static ArrayRef getAttributeNames() { return {}; } static constexpr llvm::StringLiteral getOperationName() { return llvm::StringLiteral("transform.test_transform_op"); } DiagnosedSilenceableFailure apply(transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) remark << " " << message; return DiagnosedSilenceableFailure::success(); } Attribute getMessage() { return getOperation()->getAttr("message"); } static ParseResult parse(OpAsmParser &parser, OperationState &state) { StringAttr message; OptionalParseResult result = parser.parseOptionalAttribute(message); if (!result.has_value()) return success(); if (result.value().succeeded()) state.addAttribute("message", message); return result.value(); } void print(OpAsmPrinter &printer) { if (getMessage()) printer << " " << getMessage(); } // No side effects. void getEffects(SmallVectorImpl &effects) {} }; /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait /// in cases where it is attached to ops that do not comply with the trait /// requirements. This op cannot be defined in ODS because ODS generates strict /// verifiers that overalp with those in the trait and run earlier. class TestTransformUnrestrictedOpNoInterface : public Op { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformUnrestrictedOpNoInterface) using Op::Op; static ArrayRef getAttributeNames() { return {}; } static constexpr llvm::StringLiteral getOperationName() { return llvm::StringLiteral( "transform.test_transform_unrestricted_op_no_interface"); } DiagnosedSilenceableFailure apply(transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } // No side effects. void getEffects(SmallVectorImpl &effects) {} }; } // namespace DiagnosedSilenceableFailure mlir::test::TestProduceParamOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(getResult().cast(), getOperation()->getOperand(0).getDefiningOp()); } else { results.set(getResult().cast(), reinterpret_cast(*getParameter())); } return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { if (getParameter().has_value() ^ (getNumOperands() != 1)) return emitOpError() << "expects either a parameter or an operand"; return success(); } DiagnosedSilenceableFailure mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); assert(payload.size() == 1 && "expected a single target op"); auto value = reinterpret_cast(payload[0]); if (static_cast(value) != getParameter()) { return emitSilenceableError() << "op expected the operand to be associated with " << getParameter() << " got " << value; } emitRemark() << "succeeded"; return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) op->emitRemark() << getMessage(); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) { emitRemark() << "extension absent"; return DiagnosedSilenceableFailure::success(); } InFlightDiagnostic diag = emitRemark() << "extension present, " << extension->getMessage(); for (Operation *payload : state.getPayloadOps(getOperand())) { diag.attachNote(payload->getLoc()) << "associated payload op"; #ifndef NDEBUG SmallVector handles; assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); assert(llvm::is_contained(handles, getOperand()) && "inconsistent mapping between transform IR handles and payload IR " "operations"); #endif // NDEBUG } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) return emitDefiniteFailure("TestTransformStateExtension missing"); if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, transform::TransformState &state) { ArrayRef payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(getResult().cast(), reversedOps); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } void mlir::test::TestTransformOpWithRegions::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } void mlir::test::TestBranchingTransformOpTerminator::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) op->erase(); if (getFailAfterErase()) return emitSilenceableError() << "silenceable error"; return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); results.push_back(OpBuilder(target).create(opState)); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); return emitDefaultSilenceableFailure(target); } DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; emitRemark() << state.getPayloadOps(getHandle()).size(); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getHandle(), effects); } DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { results.set(getCopy().cast(), state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( Location loc, ArrayRef payload) const { if (payload.empty()) return DiagnosedSilenceableFailure::success(); for (Operation *op : payload) { if (op->getName().getDialectNamespace() != "test") { return emitSilenceableError(loc) << "expected the payload operation to " "belong to the 'test' dialect"; } } return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( Location loc, ArrayRef payload) const { for (Attribute attr : payload) { auto integerAttr = attr.dyn_cast(); if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; return emitSilenceableError(loc) << "expected the parameter to be a i32 integer attribute"; } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTarget(), effects); } DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { op->walk([&](Operation *nested) { SmallVector handles; (void)state.getHandlesForPayloadOp(nested, handles); count += handles.size(); }); } emitRemark() << count << " handles nested under"; return DiagnosedSilenceableFailure::success(); } void mlir::test::TestPrintParamOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getParam(), effects); } DiagnosedSilenceableFailure mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); llvm::interleaveComma(state.getParams(getParam()), os); auto diag = emitRemark() << os.str(); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, /*Value=*/0); if (Value param = getParam()) { values = llvm::to_vector( llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { return attr.cast().getValue().getLimitedValue( UINT32_MAX); })); } Builder builder(getContext()); SmallVector result = llvm::to_vector( llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { return builder.getI32IntegerAttr(value + getAddendum()); })); results.setParams(getResult().cast(), result); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( llvm::map_range(state.getPayloadOps(getHandle()), [&builder](Operation *payload) -> Attribute { int32_t count = 0; payload->walk([&count](Operation *op) { if (op->getName().getDialectNamespace() == "test") ++count; }); return builder.getI32IntegerAttr(count); })); results.setParams(getResult().cast(), result); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); results.setParams(getResult().cast(), zero); return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() { if (!getType().isa()) { return emitOpError() << "expects an integer type"; } return success(); } void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getIn(), effects); transform::producesHandle(getOut(), effects); transform::producesHandle(getParam(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( Operation *target, ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { results.push_back(builder.getI64IntegerAttr(0)); } else if (getFirstResultIsNull()) { results.push_back(nullptr); } else { results.push_back(state.getPayloadOps(getIn()).front()); } if (getSecondResultIsHandle()) { results.push_back(state.getPayloadOps(getIn()).front()); } else { results.push_back(builder.getI64IntegerAttr(42)); } return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceNullPayloadOp::getEffects( SmallVectorImpl &effects) { transform::producesHandle(getOut(), effects); } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(getOut().cast(), null); return DiagnosedSilenceableFailure::success(); } void mlir::test::TestProduceNullParamOp::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { results.setParams(getOut().cast(), Attribute()); return DiagnosedSilenceableFailure::success(); } namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL /// types for operands and results. class TestTransformDialectExtension : public transform::TransformDialectExtension< TestTransformDialectExtension> { public: using Base::Base; void init() { declareDependentDialect(); registerTransformOps(); registerTypes< #define GET_TYPEDEF_LIST #include "TestTransformDialectExtensionTypes.cpp.inc" >(); } }; } // namespace // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. LLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); LLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "TestTransformDialectExtensionTypes.cpp.inc" #define GET_OP_CLASSES #include "TestTransformDialectExtension.cpp.inc" void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { registry.addExtensions(); }