Transform interfaces are implemented, directly or via extensions, in libraries belonging to multiple other dialects. Those dialects don't need to depend on the non-interface part of the transform dialect, which includes the growing number of ops and their transitive dependency footprint. Split out the interfaces into a separate library. This in turn requires flipping the dependency of the interfaces on the dialect, which crept in because both co-existed in one library; the interfaces shouldn't depend on the transform dialect either. As a consequence of splitting, the interpreter can no longer automatically walk the payload IR to identify payload ops of a certain kind based on the type used for the entry point symbol argument. This is a good move by itself as it simplifies the interpreter logic. This functionality can be trivially replaced by a `transform.structured.match` operation.
192 lines
7.6 KiB
C++
192 lines
7.6 KiB
C++
//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Analysis/CallGraph.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
|
#include "mlir/Dialect/Transform/IR/Utils.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "llvm/ADT/SCCIterator.h"
|
|
|
|
using namespace mlir;
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
|
|
|
|
#ifndef NDEBUG
|
|
/// Debug-only check that the operation `name`, being injected into the
/// Transform dialect, implements one of the interfaces the dialect expects.
/// Terminators are accepted without any interface. Ops that are not
/// pattern/conversion-pattern descriptors or type-converter builders must
/// additionally implement MemoryEffectOpInterface.
void transform::detail::checkImplementsTransformOpInterface(
    StringRef name, MLIRContext *context) {
  // Since the operation is being inserted into the Transform dialect and the
  // dialect does not implement the interface fallback, only check for the op
  // itself having the interface implementation.
  auto maybeOpName = RegisteredOperationName::lookup(name, context);
  assert(maybeOpName &&
         "expected the operation to be registered before checking its "
         "interfaces");
  RegisteredOperationName opName = *maybeOpName;

  // Descriptor and type-converter-builder ops are exempt from the
  // memory-effects requirement checked below.
  const bool isDescriptorOp =
      opName.hasInterface<PatternDescriptorOpInterface>() ||
      opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
      opName.hasInterface<TypeConverterBuilderOpInterface>();

  assert((isDescriptorOp || opName.hasInterface<TransformOpInterface>() ||
          opName.hasTrait<OpTrait::IsTerminator>()) &&
         "non-terminator ops injected into the transform dialect must "
         "implement TransformOpInterface, PatternDescriptorOpInterface, "
         "ConversionPatternDescriptorOpInterface or "
         "TypeConverterBuilderOpInterface");

  if (!isDescriptorOp) {
    assert(opName.hasInterface<MemoryEffectOpInterface>() &&
           "ops injected into the transform dialect must implement "
           "MemoryEffectsOpInterface");
  }
}
|
|
|
|
/// Debug-only check that the type identified by `typeID` implements at least
/// one of the handle, parameter, or value-handle Transform type interfaces.
void transform::detail::checkImplementsTransformHandleTypeInterface(
    TypeID typeID, MLIRContext *context) {
  const AbstractType &registeredType = AbstractType::lookup(typeID, context);
  const bool implementsAnyHandleInterface =
      registeredType.hasInterface(
          TransformHandleTypeInterface::getInterfaceID()) ||
      registeredType.hasInterface(
          TransformParamTypeInterface::getInterfaceID()) ||
      registeredType.hasInterface(
          TransformValueHandleTypeInterface::getInterfaceID());
  assert(implementsAnyHandleInterface &&
         "expected Transform dialect type to implement one of the three "
         "interfaces");
}
|
|
#endif // NDEBUG
|
|
|
|
/// Registers the Transform dialect's own operations and types, and creates the
/// dialect's library module.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  // The op list is produced by expanding the generated TransformOps.cpp.inc
  // with GET_OP_LIST defined, yielding the template argument list.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  initializeTypes();
  // Creates `libraryModule`, tagged with kWithNamedSequenceAttrName.
  initializeLibraryModule();
}
|
|
|
|
/// Parses a Transform dialect type by reading its leading keyword (mnemonic)
/// and dispatching to the parsing hook registered for that mnemonic.
/// Returns a null type if no keyword can be parsed or if the mnemonic is
/// unknown, in which case an error is emitted at the keyword's location.
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
  // Capture the location before consuming the keyword so the diagnostic
  // points at the mnemonic itself.
  const SMLoc mnemonicLoc = parser.getCurrentLocation();
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic)))
    return nullptr;

  const auto hookIt = typeParsingHooks.find(mnemonic);
  if (hookIt != typeParsingHooks.end())
    return hookIt->getValue()(parser);

  parser.emitError(mnemonicLoc) << "unknown type mnemonic: " << mnemonic;
  return nullptr;
}
|
|
|
|
/// Prints `type` using the printing hook registered for its TypeID. Asserts
/// that such a hook exists.
void transform::TransformDialect::printType(Type type,
                                            DialectAsmPrinter &printer) const {
  const auto hookIt = typePrintingHooks.find(type.getTypeID());
  assert(hookIt != typePrintingHooks.end() && "printing unknown type");
  const auto &printHook = hookIt->getSecond();
  printHook(type, printer);
}
|
|
|
|
/// Merges the symbols of `library` into the dialect's library module and
/// returns the result of the merge. `library` is moved into the merge helper.
LogicalResult transform::TransformDialect::loadIntoLibraryModule(
    ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
  auto targetModule = getLibraryModule();
  return detail::mergeSymbolsInto(targetModule, std::move(library));
}
|
|
|
|
void transform::TransformDialect::initializeLibraryModule() {
|
|
MLIRContext *context = getContext();
|
|
auto loc =
|
|
FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
|
|
libraryModule = ModuleOp::create(loc, "__transform_library");
|
|
libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
|
|
UnitAttr::get(context));
|
|
}
|
|
|
|
/// Aborts via llvm::report_fatal_error, reporting that a type with
/// `mnemonic` is already registered with a different implementation.
void transform::TransformDialect::reportDuplicateTypeRegistration(
    StringRef mnemonic) {
  const std::string message = "extensible dialect type '" + mnemonic.str() +
                              "' is already registered with a different "
                              "implementation";
  llvm::report_fatal_error(StringRef(message));
}
|
|
|
|
/// Aborts via llvm::report_fatal_error, reporting that an operation named
/// `opName` is already registered with a mismatching TypeID.
void transform::TransformDialect::reportDuplicateOpRegistration(
    StringRef opName) {
  const std::string message = "extensible dialect operation '" +
                              opName.str() +
                              "' is already registered with a mismatching "
                              "TypeID";
  llvm::report_fatal_error(StringRef(message));
}
|
|
|
|
/// Verifies an op carrying the "with named sequence" `attribute`.
/// - `op` must have the SymbolTable trait.
/// - No strongly connected component with a cycle in the call graph rooted
///   at `op` may contain an external node.
/// - Any other cycle is reported as forbidden recursion. The error is
///   attached to the first node's op, with a note for each other node.
static LogicalResult verifyWithNamedSequenceAttribute(Operation *op,
                                                      NamedAttribute attribute) {
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    return emitError(op->getLoc()) << attribute.getName()
                                   << " attribute can only be attached to "
                                      "operations with symbol tables";
  }

  const mlir::CallGraph callgraph(op);
  for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
    if (!scc.hasCycle())
      continue;

    // Need to check this here additionally because this verification may run
    // before we check the nested operations. Check every node of the cycle
    // *before* starting the recursion diagnostic: an in-flight diagnostic that
    // is abandoned by an early return is still reported when it is destroyed,
    // which would emit a spurious, partially-populated recursion error.
    for (const auto *node : *scc) {
      if (node->isExternal())
        return op->emitOpError() << "contains a call to an external operation, "
                                    "which is not allowed";
    }

    Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
    InFlightDiagnostic diag = emitError(first->getLoc())
                              << "recursion not allowed in named sequences";
    for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
      Operation *current = (*it)->getCallableRegion()->getParentOp();
      diag.attachNote(current->getLoc()) << "operation on recursion stack";
    }
    return diag;
  }
  return success();
}

/// Verifies the Transform dialect's discardable attributes attached to `op`.
/// - The named-sequence marker is delegated to
///   verifyWithNamedSequenceAttribute.
/// - The target-tag attribute must be a StringAttr.
/// - The arg-consumed, arg-readonly and silence-tracking-failures attributes
///   must be UnitAttr.
/// - Any other attribute name is rejected as unknown.
LogicalResult transform::TransformDialect::verifyOperationAttribute(
    Operation *op, NamedAttribute attribute) {
  StringRef attrName = attribute.getName().getValue();

  if (attrName == kWithNamedSequenceAttrName)
    return verifyWithNamedSequenceAttribute(op, attribute);

  if (attrName == kTargetTagAttrName) {
    if (!llvm::isa<StringAttr>(attribute.getValue())) {
      return op->emitError()
             << attribute.getName() << " attribute must be a string";
    }
    return success();
  }

  // These attributes are markers: only their presence matters, so they must
  // be unit attributes.
  if (attrName == kArgConsumedAttrName || attrName == kArgReadOnlyAttrName ||
      attrName ==
          FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
    if (!llvm::isa<UnitAttr>(attribute.getValue())) {
      return op->emitError()
             << attribute.getName() << " must be a unit attribute";
    }
    return success();
  }

  return emitError(op->getLoc())
         << "unknown attribute: " << attribute.getName();
}
|