Transform interfaces are implemented, direction or via extensions, in libraries belonging to multiple other dialects. Those dialects don't need to depend on the non-interface part of the transform dialect, which includes the growing number of ops and transitive dependency footprint. Split out the interfaces into a separate library. This in turn requires flipping the dependency from the interface on the dialect that has crept in because both co-existed in one library. The interface shouldn't depend on the transform dialect either. As a consequence of splitting, the capability of the interpreter to automatically walk the payload IR to identify payload ops of a certain kind based on the type used for the entry point symbol argument is disabled. This is a good move by itself as it simplifies the interpreter logic. This functionality can be trivially replaced by a `transform.structured.match` operation.
161 lines
5.5 KiB
C++
161 lines
5.5 KiB
C++
//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Transform/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace transform {
|
|
#define GEN_PASS_DEF_INTERPRETERPASS
|
|
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
|
|
} // namespace transform
|
|
} // namespace mlir
|
|
|
|
/// Returns the payload operation to be used as payload root:
|
|
/// - the operation nested under `passRoot` that has the given tag attribute,
|
|
/// must be unique;
|
|
/// - the `passRoot` itself if the tag is empty.
|
|
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
|
|
// Fast return.
|
|
if (tag.empty())
|
|
return passRoot;
|
|
|
|
// Walk to do a lookup.
|
|
Operation *target = nullptr;
|
|
auto tagAttrName = StringAttr::get(
|
|
passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
|
|
WalkResult walkResult = passRoot->walk([&](Operation *op) {
|
|
auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
|
|
if (!attr || attr.getValue() != tag)
|
|
return WalkResult::advance();
|
|
|
|
if (!target) {
|
|
target = op;
|
|
return WalkResult::advance();
|
|
}
|
|
|
|
InFlightDiagnostic diag = op->emitError()
|
|
<< "repeated operation with the target tag '"
|
|
<< tag << "'";
|
|
diag.attachNote(target->getLoc()) << "previously seen operation";
|
|
return WalkResult::interrupt();
|
|
});
|
|
|
|
if (!target) {
|
|
passRoot->emitError()
|
|
<< "could not find the operation with transform.target_tag=\"" << tag
|
|
<< "\" attribute";
|
|
return nullptr;
|
|
}
|
|
|
|
return walkResult.wasInterrupted() ? nullptr : target;
|
|
}
|
|
|
|
namespace {
|
|
class InterpreterPass
|
|
: public transform::impl::InterpreterPassBase<InterpreterPass> {
|
|
// Parses the pass arguments to bind trailing arguments of the entry point.
|
|
std::optional<RaggedArray<transform::MappedValue>>
|
|
parseArguments(Operation *payloadRoot) {
|
|
MLIRContext *context = payloadRoot->getContext();
|
|
|
|
SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
|
|
trailingBindings.resize(debugBindTrailingArgs.size());
|
|
|
|
// Construct lists of op names to match.
|
|
SmallVector<std::optional<OperationName>> debugBindNames;
|
|
debugBindNames.reserve(debugBindTrailingArgs.size());
|
|
for (auto &&[position, nameString] :
|
|
llvm::enumerate(debugBindTrailingArgs)) {
|
|
StringRef name = nameString;
|
|
|
|
// Parse the integer literals.
|
|
if (name.starts_with("#")) {
|
|
debugBindNames.push_back(std::nullopt);
|
|
StringRef lhs = "";
|
|
StringRef rhs = name.drop_front();
|
|
do {
|
|
std::tie(lhs, rhs) = rhs.split(';');
|
|
int64_t value;
|
|
if (lhs.getAsInteger(10, value)) {
|
|
emitError(UnknownLoc::get(context))
|
|
<< "couldn't parse integer pass argument " << name;
|
|
return std::nullopt;
|
|
}
|
|
trailingBindings[position].push_back(
|
|
Builder(context).getI64IntegerAttr(value));
|
|
} while (!rhs.empty());
|
|
} else if (name.starts_with("^")) {
|
|
debugBindNames.emplace_back(OperationName(name.drop_front(), context));
|
|
} else {
|
|
debugBindNames.emplace_back(OperationName(name, context));
|
|
}
|
|
}
|
|
|
|
// Collect operations or results for extra bindings.
|
|
payloadRoot->walk([&](Operation *payload) {
|
|
for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
|
|
if (!name || payload->getName() != *name)
|
|
continue;
|
|
|
|
if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
|
|
.starts_with("^")) {
|
|
llvm::append_range(trailingBindings[position], payload->getResults());
|
|
} else {
|
|
trailingBindings[position].push_back(payload);
|
|
}
|
|
}
|
|
});
|
|
|
|
RaggedArray<transform::MappedValue> bindings;
|
|
bindings.push_back(ArrayRef<Operation *>{payloadRoot});
|
|
for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
|
|
bindings.push_back(std::move(trailing));
|
|
return bindings;
|
|
}
|
|
|
|
public:
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp transformModule =
|
|
transform::detail::getPreloadedTransformModule(context);
|
|
Operation *payloadRoot =
|
|
findPayloadRoot(getOperation(), debugPayloadRootTag);
|
|
if (!payloadRoot)
|
|
return signalPassFailure();
|
|
|
|
Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
|
|
getOperation(), transformModule, entryPoint);
|
|
if (!transformEntryPoint)
|
|
return signalPassFailure();
|
|
|
|
std::optional<RaggedArray<transform::MappedValue>> bindings =
|
|
parseArguments(payloadRoot);
|
|
if (!bindings)
|
|
return signalPassFailure();
|
|
if (failed(transform::applyTransformNamedSequence(
|
|
*bindings,
|
|
cast<transform::TransformOpInterface>(transformEntryPoint),
|
|
transformModule,
|
|
options.enableExpensiveChecks(!disableExpensiveChecks)))) {
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
|
|
private:
|
|
/// Transform interpreter options.
|
|
transform::TransformOptions options;
|
|
};
|
|
} // namespace
|