//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Transforms/Passes.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" using namespace mlir; namespace mlir { namespace transform { #define GEN_PASS_DEF_INTERPRETERPASS #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" } // namespace transform } // namespace mlir /// Returns the payload operation to be used as payload root: /// - the operation nested under `passRoot` that has the given tag attribute, /// must be unique; /// - the `passRoot` itself if the tag is empty. static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) { // Fast return. if (tag.empty()) return passRoot; // Walk to do a lookup. Operation *target = nullptr; auto tagAttrName = StringAttr::get( passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName); WalkResult walkResult = passRoot->walk([&](Operation *op) { auto attr = op->getAttrOfType(tagAttrName); if (!attr || attr.getValue() != tag) return WalkResult::advance(); if (!target) { target = op; return WalkResult::advance(); } InFlightDiagnostic diag = op->emitError() << "repeated operation with the target tag '" << tag << "'"; diag.attachNote(target->getLoc()) << "previously seen operation"; return WalkResult::interrupt(); }); return walkResult.wasInterrupted() ? nullptr : target; } namespace { class InterpreterPass : public transform::impl::InterpreterPassBase { public: using Base::Base; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp transformModule = transform::detail::getPreloadedTransformModule(context); Operation *payloadRoot = findPayloadRoot(getOperation(), debugPayloadRootTag); Operation *transformEntryPoint = transform::detail::findTransformEntryPoint( getOperation(), transformModule, entryPoint); if (!transformEntryPoint) { getOperation()->emitError() << "could not find transform entry point: " << entryPoint << " in either payload or transform module"; return signalPassFailure(); } if (failed(transform::applyTransformNamedSequence( payloadRoot, transformEntryPoint, transformModule, options.enableExpensiveChecks(!disableExpensiveChecks)))) { return signalPassFailure(); } } private: /// Transform interpreter options. transform::TransformOptions options; }; } // namespace