Transform interfaces are implemented, directly or via extensions, in libraries belonging to multiple other dialects. Those dialects don't need to depend on the non-interface part of the transform dialect, which includes a growing number of ops and a large transitive dependency footprint. Split out the interfaces into a separate library. This in turn requires flipping the dependency of the interfaces on the dialect, which crept in because both co-existed in one library. The interfaces shouldn't depend on the transform dialect either. As a consequence of splitting, the capability of the interpreter to automatically walk the payload IR to identify payload ops of a certain kind, based on the type used for the entry point symbol argument, is disabled. This is a good move by itself as it simplifies the interpreter logic. This functionality can be trivially replaced by a `transform.structured.match` operation.
458 lines
19 KiB
C++
458 lines
19 KiB
C++
//===- TransformInterpreterPassBase.cpp -----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Base class with shared implementation for transform dialect interpreter
|
|
// passes.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
#define DEBUG_TYPE "transform-dialect-interpreter"
|
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
|
|
#define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro"
|
|
#define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro"
|
|
|
|
/// Name of the attribute used for targeting the transform dialect interpreter
|
|
/// at specific operations.
|
|
constexpr static llvm::StringLiteral kTransformDialectTagAttrName =
|
|
"transform.target_tag";
|
|
/// Value of the attribute indicating the root payload operation.
|
|
constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
|
|
"payload_root";
|
|
/// Value of the attribute indicating the container of transform operations
|
|
/// (containing the top-level transform operation).
|
|
constexpr static llvm::StringLiteral
|
|
kTransformDialectTagTransformContainerValue = "transform_container";
|
|
|
|
/// Finds the single top-level transform operation with `root` as ancestor.
|
|
/// Reports an error if there is more than one such operation and returns the
|
|
/// first one found. Reports an error returns nullptr if no such operation
|
|
/// found.
|
|
static Operation *
|
|
findTopLevelTransform(Operation *root, StringRef filenameOption,
|
|
mlir::transform::TransformOptions options) {
|
|
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
|
|
root->walk<WalkOrder::PreOrder>(
|
|
[&](::mlir::transform::TransformOpInterface transformOp) {
|
|
if (!transformOp
|
|
->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
|
|
return WalkResult::skip();
|
|
if (!topLevelTransform) {
|
|
topLevelTransform = transformOp;
|
|
return WalkResult::skip();
|
|
}
|
|
if (options.getEnforceSingleToplevelTransformOp()) {
|
|
auto diag = transformOp.emitError()
|
|
<< "more than one top-level transform op";
|
|
diag.attachNote(topLevelTransform.getLoc())
|
|
<< "previous top-level transform op";
|
|
return WalkResult::interrupt();
|
|
}
|
|
return WalkResult::skip();
|
|
});
|
|
if (!topLevelTransform) {
|
|
auto diag = root->emitError()
|
|
<< "could not find a nested top-level transform op";
|
|
diag.attachNote() << "use the '" << filenameOption
|
|
<< "' option to provide transform as external file";
|
|
return nullptr;
|
|
}
|
|
return topLevelTransform;
|
|
}
|
|
|
|
/// Finds an operation nested in `root` that has the transform dialect tag
|
|
/// attribute with the value specified as `tag`. Assumes only one operation
|
|
/// may have the tag. Returns nullptr if there is no such operation.
|
|
static Operation *findOpWithTag(Operation *root, StringRef tagKey,
|
|
StringRef tagValue) {
|
|
Operation *found = nullptr;
|
|
WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
|
|
[tagKey, tagValue, &found, root](Operation *op) {
|
|
auto attr = op->getAttrOfType<StringAttr>(tagKey);
|
|
if (!attr || attr.getValue() != tagValue)
|
|
return WalkResult::advance();
|
|
|
|
if (found) {
|
|
InFlightDiagnostic diag = root->emitError()
|
|
<< "more than one operation with " << tagKey
|
|
<< "=\"" << tagValue << "\" attribute";
|
|
diag.attachNote(found->getLoc()) << "first operation";
|
|
diag.attachNote(op->getLoc()) << "other operation";
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
found = op;
|
|
return WalkResult::advance();
|
|
});
|
|
if (walkResult.wasInterrupted())
|
|
return nullptr;
|
|
|
|
if (!found) {
|
|
root->emitError() << "could not find the operation with " << tagKey << "=\""
|
|
<< tagValue << "\" attribute";
|
|
}
|
|
return found;
|
|
}
|
|
|
|
/// Returns the ancestor of `target` that doesn't have a parent.
|
|
static Operation *getRootOperation(Operation *target) {
|
|
Operation *root = target;
|
|
while (root->getParentOp())
|
|
root = root->getParentOp();
|
|
return root;
|
|
}
|
|
|
|
/// Prints the CLI command running the repro with the current path.
|
|
// TODO: make binary name optional by querying LLVM command line API for the
|
|
// name of the current binary.
|
|
static llvm::raw_ostream &
|
|
printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
|
|
const Pass::Option<std::string> &debugPayloadRootTag,
|
|
const Pass::Option<std::string> &debugTransformRootTag,
|
|
StringRef binaryName) {
|
|
os << llvm::formatv(
|
|
"{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName,
|
|
passName, debugPayloadRootTag.getArgStr(),
|
|
debugPayloadRootTag.empty()
|
|
? StringRef(kTransformDialectTagPayloadRootValue)
|
|
: debugPayloadRootTag,
|
|
debugTransformRootTag.getArgStr(),
|
|
debugTransformRootTag.empty()
|
|
? StringRef(kTransformDialectTagTransformContainerValue)
|
|
: debugTransformRootTag,
|
|
binaryName);
|
|
return os;
|
|
}
|
|
|
|
/// Prints the module rooted at `root` to `os` and appends
|
|
/// `transformContainer` if it is not nested in `root`.
|
|
static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os,
|
|
Operation *root,
|
|
Operation *transform) {
|
|
root->print(os);
|
|
if (!root->isAncestor(transform))
|
|
transform->print(os);
|
|
return os;
|
|
}
|
|
|
|
/// Saves the payload and the transform IR into a temporary file and reports
|
|
/// the file name to `os`.
|
|
[[maybe_unused]] static void
|
|
saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
|
|
Operation *transform, StringRef passName,
|
|
const Pass::Option<std::string> &debugPayloadRootTag,
|
|
const Pass::Option<std::string> &debugTransformRootTag,
|
|
const Pass::ListOption<std::string> &transformLibraryPaths,
|
|
StringRef binaryName) {
|
|
using llvm::sys::fs::TempFile;
|
|
Operation *root = getRootOperation(target);
|
|
|
|
SmallVector<char, 128> tmpPath;
|
|
llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath);
|
|
llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir");
|
|
llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath);
|
|
if (!tempFile) {
|
|
os << "could not open temporary file to save the repro\n";
|
|
return;
|
|
}
|
|
|
|
llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false);
|
|
printModuleForRepro(fout, root, transform);
|
|
fout.flush();
|
|
std::string filename = tempFile->TmpName;
|
|
|
|
if (tempFile->keep()) {
|
|
os << "could not preserve the temporary file with the repro\n";
|
|
return;
|
|
}
|
|
|
|
os << "=== Transform Interpreter Repro ===\n";
|
|
printReproCall(os, root->getName().getStringRef(), passName,
|
|
debugPayloadRootTag, debugTransformRootTag, binaryName)
|
|
<< " " << filename << "\n";
|
|
os << "===================================\n";
|
|
}
|
|
|
|
// Optionally perform debug actions requested by the user to dump IR and a
|
|
// repro to stderr and/or a file.
|
|
static void performOptionalDebugActions(
|
|
Operation *target, Operation *transform, StringRef passName,
|
|
const Pass::Option<std::string> &debugPayloadRootTag,
|
|
const Pass::Option<std::string> &debugTransformRootTag,
|
|
const Pass::ListOption<std::string> &transformLibraryPaths,
|
|
StringRef binaryName) {
|
|
MLIRContext *context = target->getContext();
|
|
|
|
// If we are not planning to print, bail early.
|
|
bool hasDebugFlags = false;
|
|
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; });
|
|
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; });
|
|
if (!hasDebugFlags)
|
|
return;
|
|
|
|
// We will be mutating the IR to set attributes. If this is running
|
|
// concurrently on several parts of a container or using a shared transform
|
|
// script, this would create a race. Bail in multithreaded mode and require
|
|
// the user to disable threading to dump repros.
|
|
static llvm::sys::SmartMutex<true> dbgStreamMutex;
|
|
if (target->getContext()->isMultithreadingEnabled()) {
|
|
llvm::sys::SmartScopedLock<true> lock(dbgStreamMutex);
|
|
llvm::dbgs() << "=======================================================\n";
|
|
llvm::dbgs() << "| Transform reproducers cannot be produced |\n";
|
|
llvm::dbgs() << "| in multi-threaded mode! |\n";
|
|
llvm::dbgs() << "=======================================================\n";
|
|
return;
|
|
}
|
|
|
|
Operation *root = getRootOperation(target);
|
|
|
|
// Add temporary debug / repro attributes, these must never leak out.
|
|
if (debugPayloadRootTag.empty()) {
|
|
target->setAttr(
|
|
kTransformDialectTagAttrName,
|
|
StringAttr::get(context, kTransformDialectTagPayloadRootValue));
|
|
}
|
|
if (debugTransformRootTag.empty()) {
|
|
transform->setAttr(
|
|
kTransformDialectTagAttrName,
|
|
StringAttr::get(context, kTransformDialectTagTransformContainerValue));
|
|
}
|
|
|
|
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
|
|
llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
|
|
printReproCall(llvm::dbgs() << "cat <<EOF | ",
|
|
root->getName().getStringRef(), passName,
|
|
debugPayloadRootTag, debugTransformRootTag, binaryName)
|
|
<< "\n";
|
|
printModuleForRepro(llvm::dbgs(), root, transform);
|
|
llvm::dbgs() << "\nEOF\n";
|
|
llvm::dbgs() << "===================================\n";
|
|
});
|
|
(void)root;
|
|
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
|
|
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
|
|
debugPayloadRootTag, debugTransformRootTag,
|
|
transformLibraryPaths, binaryName);
|
|
});
|
|
|
|
// Remove temporary attributes if they were set.
|
|
if (debugPayloadRootTag.empty())
|
|
target->removeAttr(kTransformDialectTagAttrName);
|
|
if (debugTransformRootTag.empty())
|
|
transform->removeAttr(kTransformDialectTagAttrName);
|
|
}
|
|
|
|
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
    Operation *target, StringRef passName,
    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
    const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
    const RaggedArray<MappedValue> &extraMappings,
    const TransformOptions &options,
    const Pass::Option<std::string> &transformFileName,
    const Pass::ListOption<std::string> &transformLibraryPaths,
    const Pass::Option<std::string> &debugPayloadRootTag,
    const Pass::Option<std::string> &debugTransformRootTag,
    StringRef binaryName) {
  const bool useSharedModule = sharedTransformModule && *sharedTransformModule;
  const bool useLibraryModule =
      transformLibraryModule && *transformLibraryModule;
  assert(!(useSharedModule && useLibraryModule) &&
         "at most one of shared or library transform module can be set");

  // 1. Payload root: `target` itself, unless the user picked an op by tag
  //    (REPL debug mode).
  Operation *payloadRoot = target;
  if (!debugPayloadRootTag.empty()) {
    payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName,
                                debugPayloadRootTag);
    if (!payloadRoot)
      return failure();
  }

  // 2. Transform root: searched in the shared transform module when there is
  //    one, otherwise in the payload IR. It is either the single top-level
  //    transform op or, in REPL debug mode, the op selected by tag.
  Operation *transformContainer = target;
  if (useSharedModule)
    transformContainer = sharedTransformModule->get();

  Operation *transformRoot = nullptr;
  if (debugTransformRootTag.empty())
    transformRoot = findTopLevelTransform(
        transformContainer, transformFileName.getArgStr(), options);
  else
    transformRoot = findOpWithTag(transformContainer,
                                  kTransformDialectTagAttrName,
                                  debugTransformRootTag);
  if (!transformRoot)
    return failure();

  if (!transformRoot->hasTrait<PossibleTopLevelTransformOpTrait>())
    return emitError(transformRoot->getLoc())
           << "expected the transform entry point to be a top-level transform "
              "op";

  // 3. Inject external definitions for symbols, if a library was provided.
  //    Normally this cannot fail under concurrent execution, unless the
  //    transform IR modifies itself in a pass, which is forbidden elsewhere.
  if (useLibraryModule) {
    if (!target->isProperAncestor(transformRoot)) {
      InFlightDiagnostic anchorDiag =
          transformRoot->emitError()
          << "cannot inject transform definitions next to pass anchor op";
      anchorDiag.attachNote(target->getLoc()) << "pass anchor op";
      return anchorDiag;
    }
    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(transformRoot);
    InFlightDiagnostic mergeDiag = detail::mergeSymbolsInto(
        symbolTableOp, transformLibraryModule->get()->clone());
    if (failed(mergeDiag)) {
      mergeDiag.attachNote(transformRoot->getLoc())
          << "failed to merge library symbols into transform root";
      return mergeDiag;
    }
  }

  // 4. Dump repros to stderr and/or a file if the user asked for it.
  performOptionalDebugActions(target, transformRoot, passName,
                              debugPayloadRootTag, debugTransformRootTag,
                              transformLibraryPaths, binaryName);

  // 5. Run the interpreter.
  auto entryPoint = cast<TransformOpInterface>(transformRoot);
  return applyTransforms(payloadRoot, entryPoint, extraMappings, options);
}
|
|
|
|
/// Loads the transform IR used by the interpreter pass.
///
/// - Parses (via `parseTransformModuleFromFile`) and verifies the module in
///   `transformFileName`; when one is obtained, it becomes
///   `sharedTransformModule`. Otherwise, if `moduleBuilder` is provided and
///   returns a value, a fresh `__transform` module is populated by it and used
///   as `sharedTransformModule` (unless the builder reports failure).
/// - Expands `transformLibraryPaths` with `expandPathsToMLIRFiles`, parses and
///   verifies each resulting file, and merges them all into one
///   `transform.with_named_sequence` module.
/// - Merges those libraries into the shared module when one exists; otherwise
///   returns them separately through `transformLibraryModule`.
///
/// Returns failure if any parsing, verification, building, or merging step
/// fails, emitting a diagnostic at each failure point handled here.
LogicalResult transform::detail::interpreterBaseInitializeImpl(
    MLIRContext *context, StringRef transformFileName,
    ArrayRef<std::string> transformLibraryPaths,
    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
        moduleBuilder) {
  auto unknownLoc = UnknownLoc::get(context);

  // Parse module from file.
  OwningOpRef<ModuleOp> moduleFromFile;
  {
    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
    if (failed(detail::parseTransformModuleFromFile(context, transformFileName,
                                                    moduleFromFile)))
      return emitError(loc) << "failed to parse transform module";
    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
      return emitError(loc) << "failed to verify transform module";
  }

  // Assemble list of library files.
  SmallVector<std::string> libraryFileNames;
  if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
                                    libraryFileNames)))
    return failure();

  // Parse modules from library files.
  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
  for (const std::string &libraryFileName : libraryFileNames) {
    OwningOpRef<ModuleOp> parsedLibrary;
    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
    if (failed(detail::parseTransformModuleFromFile(context, libraryFileName,
                                                    parsedLibrary)))
      return emitError(loc) << "failed to parse transform library module";
    if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
      return emitError(loc) << "failed to verify transform library module";
    parsedLibraries.push_back(std::move(parsedLibrary));
  }

  // Build shared transform module: prefer the one from file, fall back to the
  // optional builder callback.
  if (moduleFromFile) {
    sharedTransformModule =
        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
  } else if (moduleBuilder) {
    auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
    auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
        ModuleOp::create(unknownLoc, "__transform"));

    OpBuilder b(context);
    b.setInsertionPointToEnd(localModule->get().getBody());
    // std::nullopt means the builder chose not to provide a shared module.
    if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
      if (failed(*result))
        return (*localModule)->emitError()
               << "failed to create shared transform module";
      sharedTransformModule = std::move(localModule);
    }
  }

  if (parsedLibraries.empty())
    return success();

  // Merge parsed libraries into one module.
  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
  OwningOpRef<ModuleOp> mergedParsedLibraries =
      ModuleOp::create(loc, "__transform");
  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
                                       UnitAttr::get(context));
  // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
    if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(),
                                        std::move(parsedLibrary)))) {
      // NOTE(review): this path is a merge failure; the "verify" wording is
      // kept as-is because existing tests may match on it.
      return mergedParsedLibraries->emitError()
             << "failed to verify merged transform module";
    }
  }

  // Use parsed libraries to resolve symbols in shared transform module or
  // return them as a separate library module.
  if (sharedTransformModule && *sharedTransformModule) {
    if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(),
                                        std::move(mergedParsedLibraries))))
      return (*sharedTransformModule)->emitError()
             << "failed to merge symbols from library files "
                "into shared transform module";
  } else {
    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
        std::move(mergedParsedLibraries));
  }
  return success();
}
|