//===- TransformInterpreterPassBase.cpp -----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Base class with shared implementation for transform dialect interpreter // passes. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/Path.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; #define DEBUG_TYPE "transform-dialect-interpreter" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") #define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro" #define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro" /// Name of the attribute used for targeting the transform dialect interpreter /// at specific operations. constexpr static llvm::StringLiteral kTransformDialectTagAttrName = "transform.target_tag"; /// Value of the attribute indicating the root payload operation. constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = "payload_root"; /// Value of the attribute indicating the container of transform operations /// (containing the top-level transform operation). constexpr static llvm::StringLiteral kTransformDialectTagTransformContainerValue = "transform_container"; /// Finds the single top-level transform operation with `root` as ancestor. /// Reports an error if there is more than one such operation and returns the /// first one found. Reports an error returns nullptr if no such operation /// found. static Operation * findTopLevelTransform(Operation *root, StringRef filenameOption, mlir::transform::TransformOptions options) { ::mlir::transform::TransformOpInterface topLevelTransform = nullptr; root->walk( [&](::mlir::transform::TransformOpInterface transformOp) { if (!transformOp ->hasTrait()) return WalkResult::skip(); if (!topLevelTransform) { topLevelTransform = transformOp; return WalkResult::skip(); } if (options.getEnforceSingleToplevelTransformOp()) { auto diag = transformOp.emitError() << "more than one top-level transform op"; diag.attachNote(topLevelTransform.getLoc()) << "previous top-level transform op"; return WalkResult::interrupt(); } return WalkResult::skip(); }); if (!topLevelTransform) { auto diag = root->emitError() << "could not find a nested top-level transform op"; diag.attachNote() << "use the '" << filenameOption << "' option to provide transform as external file"; return nullptr; } return topLevelTransform; } /// Finds an operation nested in `root` that has the transform dialect tag /// attribute with the value specified as `tag`. Assumes only one operation /// may have the tag. Returns nullptr if there is no such operation. static Operation *findOpWithTag(Operation *root, StringRef tagKey, StringRef tagValue) { Operation *found = nullptr; WalkResult walkResult = root->walk( [tagKey, tagValue, &found, root](Operation *op) { auto attr = op->getAttrOfType(tagKey); if (!attr || attr.getValue() != tagValue) return WalkResult::advance(); if (found) { InFlightDiagnostic diag = root->emitError() << "more than one operation with " << tagKey << "=\"" << tagValue << "\" attribute"; diag.attachNote(found->getLoc()) << "first operation"; diag.attachNote(op->getLoc()) << "other operation"; return WalkResult::interrupt(); } found = op; return WalkResult::advance(); }); if (walkResult.wasInterrupted()) return nullptr; if (!found) { root->emitError() << "could not find the operation with " << tagKey << "=\"" << tagValue << "\" attribute"; } return found; } /// Returns the ancestor of `target` that doesn't have a parent. static Operation *getRootOperation(Operation *target) { Operation *root = target; while (root->getParentOp()) root = root->getParentOp(); return root; } /// Prints the CLI command running the repro with the current path. // TODO: make binary name optional by querying LLVM command line API for the // name of the current binary. static llvm::raw_ostream & printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName) { os << llvm::formatv( "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName, passName, debugPayloadRootTag.getArgStr(), debugPayloadRootTag.empty() ? StringRef(kTransformDialectTagPayloadRootValue) : debugPayloadRootTag, debugTransformRootTag.getArgStr(), debugTransformRootTag.empty() ? StringRef(kTransformDialectTagTransformContainerValue) : debugTransformRootTag, binaryName); return os; } /// Prints the module rooted at `root` to `os` and appends /// `transformContainer` if it is not nested in `root`. static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root, Operation *transform) { root->print(os); if (!root->isAncestor(transform)) transform->print(os); return os; } /// Saves the payload and the transform IR into a temporary file and reports /// the file name to `os`. [[maybe_unused]] static void saveReproToTempFile(llvm::raw_ostream &os, Operation *target, Operation *transform, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, const Pass::ListOption &transformLibraryPaths, StringRef binaryName) { using llvm::sys::fs::TempFile; Operation *root = getRootOperation(target); SmallVector tmpPath; llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath); llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir"); llvm::Expected tempFile = TempFile::create(tmpPath); if (!tempFile) { os << "could not open temporary file to save the repro\n"; return; } llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false); printModuleForRepro(fout, root, transform); fout.flush(); std::string filename = tempFile->TmpName; if (tempFile->keep()) { os << "could not preserve the temporary file with the repro\n"; return; } os << "=== Transform Interpreter Repro ===\n"; printReproCall(os, root->getName().getStringRef(), passName, debugPayloadRootTag, debugTransformRootTag, binaryName) << " " << filename << "\n"; os << "===================================\n"; } // Optionally perform debug actions requested by the user to dump IR and a // repro to stderr and/or a file. static void performOptionalDebugActions( Operation *target, Operation *transform, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, const Pass::ListOption &transformLibraryPaths, StringRef binaryName) { MLIRContext *context = target->getContext(); // If we are not planning to print, bail early. bool hasDebugFlags = false; DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; }); DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; }); if (!hasDebugFlags) return; // We will be mutating the IR to set attributes. If this is running // concurrently on several parts of a container or using a shared transform // script, this would create a race. Bail in multithreaded mode and require // the user to disable threading to dump repros. static llvm::sys::SmartMutex dbgStreamMutex; if (target->getContext()->isMultithreadingEnabled()) { llvm::sys::SmartScopedLock lock(dbgStreamMutex); llvm::dbgs() << "=======================================================\n"; llvm::dbgs() << "| Transform reproducers cannot be produced |\n"; llvm::dbgs() << "| in multi-threaded mode! |\n"; llvm::dbgs() << "=======================================================\n"; return; } Operation *root = getRootOperation(target); // Add temporary debug / repro attributes, these must never leak out. if (debugPayloadRootTag.empty()) { target->setAttr( kTransformDialectTagAttrName, StringAttr::get(context, kTransformDialectTagPayloadRootValue)); } if (debugTransformRootTag.empty()) { transform->setAttr( kTransformDialectTagAttrName, StringAttr::get(context, kTransformDialectTagTransformContainerValue)); } DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; printReproCall(llvm::dbgs() << "cat <getName().getStringRef(), passName, debugPayloadRootTag, debugTransformRootTag, binaryName) << "\n"; printModuleForRepro(llvm::dbgs(), root, transform); llvm::dbgs() << "\nEOF\n"; llvm::dbgs() << "===================================\n"; }); (void)root; DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { saveReproToTempFile(llvm::dbgs(), target, transform, passName, debugPayloadRootTag, debugTransformRootTag, transformLibraryPaths, binaryName); }); // Remove temporary attributes if they were set. if (debugPayloadRootTag.empty()) target->removeAttr(kTransformDialectTagAttrName); if (debugTransformRootTag.empty()) transform->removeAttr(kTransformDialectTagAttrName); } LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, const std::shared_ptr> &transformLibraryModule, const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, const Pass::ListOption &transformLibraryPaths, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName) { bool hasSharedTransformModule = sharedTransformModule && *sharedTransformModule; bool hasTransformLibraryModule = transformLibraryModule && *transformLibraryModule; assert((!hasSharedTransformModule || !hasTransformLibraryModule) && "at most one of shared or library transform module can be set"); // Step 1 // ------ // If debugPayloadRootTag was passed, then we are in user-specified selection // of the transformed IR. This corresponds to REPL debug mode. Otherwise, just // apply to `target`. Operation *payloadRoot = target; if (!debugPayloadRootTag.empty()) { payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName, debugPayloadRootTag); if (!payloadRoot) return failure(); } // Step 2 // ------ // If a shared transform was specified separately, use it. Otherwise, the // transform is embedded in the payload IR. If debugTransformRootTag was // passed, then we are in user-specified selection of the transforming IR. // This corresponds to REPL debug mode. Operation *transformContainer = hasSharedTransformModule ? sharedTransformModule->get() : target; Operation *transformRoot = debugTransformRootTag.empty() ? findTopLevelTransform(transformContainer, transformFileName.getArgStr(), options) : findOpWithTag(transformContainer, kTransformDialectTagAttrName, debugTransformRootTag); if (!transformRoot) return failure(); if (!transformRoot->hasTrait()) { return emitError(transformRoot->getLoc()) << "expected the transform entry point to be a top-level transform " "op"; } // Step 3 // ------ // Copy external defintions for symbols if provided. Be aware of potential // concurrent execution (normally, the error shouldn't be triggered unless the // transform IR modifies itself in a pass, which is also forbidden elsewhere). if (hasTransformLibraryModule) { if (!target->isProperAncestor(transformRoot)) { InFlightDiagnostic diag = transformRoot->emitError() << "cannot inject transform definitions next to pass anchor op"; diag.attachNote(target->getLoc()) << "pass anchor op"; return diag; } InFlightDiagnostic diag = detail::mergeSymbolsInto( SymbolTable::getNearestSymbolTable(transformRoot), transformLibraryModule->get()->clone()); if (failed(diag)) { diag.attachNote(transformRoot->getLoc()) << "failed to merge library symbols into transform root"; return diag; } } // Step 4 // ------ // Optionally perform debug actions requested by the user to dump IR and a // repro to stderr and/or a file. performOptionalDebugActions(target, transformRoot, passName, debugPayloadRootTag, debugTransformRootTag, transformLibraryPaths, binaryName); // Step 5 // ------ // Apply the transform to the IR return applyTransforms(payloadRoot, cast(transformRoot), extraMappings, options); } LogicalResult transform::detail::interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, ArrayRef transformLibraryPaths, std::shared_ptr> &sharedTransformModule, std::shared_ptr> &transformLibraryModule, function_ref(OpBuilder &, Location)> moduleBuilder) { auto unknownLoc = UnknownLoc::get(context); // Parse module from file. OwningOpRef moduleFromFile; { auto loc = FileLineColLoc::get(context, transformFileName, 0, 0); if (failed(detail::parseTransformModuleFromFile(context, transformFileName, moduleFromFile))) return emitError(loc) << "failed to parse transform module"; if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) return emitError(loc) << "failed to verify transform module"; } // Assemble list of library files. SmallVector libraryFileNames; if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context, libraryFileNames))) return failure(); // Parse modules from library files. SmallVector> parsedLibraries; for (const std::string &libraryFileName : libraryFileNames) { OwningOpRef parsedLibrary; auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0); if (failed(detail::parseTransformModuleFromFile(context, libraryFileName, parsedLibrary))) return emitError(loc) << "failed to parse transform library module"; if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) return emitError(loc) << "failed to verify transform library module"; parsedLibraries.push_back(std::move(parsedLibrary)); } // Build shared transform module. if (moduleFromFile) { sharedTransformModule = std::make_shared>(std::move(moduleFromFile)); } else if (moduleBuilder) { auto loc = FileLineColLoc::get(context, "", 0, 0); auto localModule = std::make_shared>( ModuleOp::create(unknownLoc, "__transform")); OpBuilder b(context); b.setInsertionPointToEnd(localModule->get().getBody()); if (std::optional result = moduleBuilder(b, loc)) { if (failed(*result)) return (*localModule)->emitError() << "failed to create shared transform module"; sharedTransformModule = std::move(localModule); } } if (parsedLibraries.empty()) return success(); // Merge parsed libraries into one module. auto loc = FileLineColLoc::get(context, "", 0, 0); OwningOpRef mergedParsedLibraries = ModuleOp::create(loc, "__transform"); { mergedParsedLibraries.get()->setAttr("transform.with_named_sequence", UnitAttr::get(context)); IRRewriter rewriter(context); // TODO: extend `mergeSymbolsInto` to support multiple `other` modules. for (OwningOpRef &parsedLibrary : parsedLibraries) { if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(), std::move(parsedLibrary)))) return mergedParsedLibraries->emitError() << "failed to verify merged transform module"; } } // Use parsed libaries to resolve symbols in shared transform module or return // as separate library module. if (sharedTransformModule && *sharedTransformModule) { if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(), std::move(mergedParsedLibraries)))) return (*sharedTransformModule)->emitError() << "failed to merge symbols from library files " "into shared transform module"; } else { transformLibraryModule = std::make_shared>( std::move(mergedParsedLibraries)); } return success(); }