Summary: With the move towards dialect registration that does not depend only on static initialization, we are running into more cases where the dialects are registered by different methods. For example, TensorFlow still uses static initialization to register all MLIR core dialects, which prevents explicit registration of any of them when linking it in. We ran into this issue in https://github.com/google/iree/pull/982. To address potential issues with conflicts from non-standard allocators passed to registerDialectAllocator, this change makes that method private. Now all dialects can only be registered with their constructor. It similarly deduplicates DialectHooks for consistency and makes their registration follow the same pattern. Differential Revision: https://reviews.llvm.org/D76329
159 lines
5.8 KiB
C++
159 lines
5.8 KiB
C++
//===- Dialect.cpp - Dialect implementation -------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/DialectHooks.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/DialectInterface.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/ManagedStatic.h"
|
|
#include "llvm/Support/Regex.h"
|
|
|
|
using namespace mlir;
|
|
using namespace detail;
|
|
|
|
DialectAsmParser::~DialectAsmParser() {}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registry for all dialect allocation functions.
|
|
static llvm::ManagedStatic<
|
|
llvm::MapVector<const ClassID *, DialectAllocatorFunction>>
|
|
dialectRegistry;
|
|
|
|
/// Registry for functions that set dialect hooks.
|
|
static llvm::ManagedStatic<llvm::MapVector<const ClassID *, DialectHooksSetter>>
|
|
dialectHooksRegistry;
|
|
|
|
void Dialect::registerDialectAllocator(
|
|
const ClassID *classId, const DialectAllocatorFunction &function) {
|
|
assert(function &&
|
|
"Attempting to register an empty dialect initialize function");
|
|
dialectRegistry->insert({classId, function});
|
|
}
|
|
|
|
/// Registers a function to set specific hooks for a specific dialect, typically
|
|
/// used through the DialectHooksRegistration template.
|
|
void DialectHooks::registerDialectHooksSetter(
|
|
const ClassID *classId, const DialectHooksSetter &function) {
|
|
assert(
|
|
function &&
|
|
"Attempting to register an empty dialect hooks initialization function");
|
|
|
|
dialectHooksRegistry->insert({classId, function});
|
|
}
|
|
|
|
/// Registers all dialects and hooks from the global registries with the
|
|
/// specified MLIRContext.
|
|
void mlir::registerAllDialects(MLIRContext *context) {
|
|
for (const auto &it : *dialectRegistry)
|
|
it.second(context);
|
|
for (const auto &it : *dialectHooksRegistry) {
|
|
it.second(context);
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
Dialect::Dialect(StringRef name, MLIRContext *context)
|
|
: name(name), context(context) {
|
|
assert(isValidNamespace(name) && "invalid dialect namespace");
|
|
registerDialect(context);
|
|
}
|
|
|
|
Dialect::~Dialect() {}
|
|
|
|
/// Verify an attribute from this dialect on the argument at 'argIndex' for
|
|
/// the region at 'regionIndex' on the given operation. Returns failure if
|
|
/// the verification failed, success otherwise. This hook may optionally be
|
|
/// invoked from any operation containing a region.
|
|
LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
|
|
NamedAttribute) {
|
|
return success();
|
|
}
|
|
|
|
/// Verify an attribute from this dialect on the result at 'resultIndex' for
|
|
/// the region at 'regionIndex' on the given operation. Returns failure if
|
|
/// the verification failed, success otherwise. This hook may optionally be
|
|
/// invoked from any operation containing a region.
|
|
LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
|
|
unsigned, NamedAttribute) {
|
|
return success();
|
|
}
|
|
|
|
/// Parse an attribute registered to this dialect.
|
|
Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
|
|
parser.emitError(parser.getNameLoc())
|
|
<< "dialect '" << getNamespace()
|
|
<< "' provides no attribute parsing hook";
|
|
return Attribute();
|
|
}
|
|
|
|
/// Parse a type registered to this dialect.
|
|
Type Dialect::parseType(DialectAsmParser &parser) const {
|
|
// If this dialect allows unknown types, then represent this with OpaqueType.
|
|
if (allowsUnknownTypes()) {
|
|
auto ns = Identifier::get(getNamespace(), getContext());
|
|
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
|
|
}
|
|
|
|
parser.emitError(parser.getNameLoc())
|
|
<< "dialect '" << getNamespace() << "' provides no type parsing hook";
|
|
return Type();
|
|
}
|
|
|
|
/// Utility function that returns if the given string is a valid dialect
|
|
/// namespace.
|
|
bool Dialect::isValidNamespace(StringRef str) {
|
|
if (str.empty())
|
|
return true;
|
|
llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
|
|
return dialectNameRegex.match(str);
|
|
}
|
|
|
|
/// Register a set of dialect interfaces with this dialect instance.
|
|
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
|
|
auto it = registeredInterfaces.try_emplace(interface->getID(),
|
|
std::move(interface));
|
|
(void)it;
|
|
assert(it.second && "interface kind has already been registered");
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Interface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
DialectInterface::~DialectInterface() {}
|
|
|
|
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
|
MLIRContext *ctx, ClassID *interfaceKind) {
|
|
for (auto *dialect : ctx->getRegisteredDialects()) {
|
|
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
|
|
interfaces.insert(interface);
|
|
orderedInterfaces.push_back(interface);
|
|
}
|
|
}
|
|
}
|
|
|
|
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
|
|
|
|
/// Get the interface for the dialect of given operation, or null if one
|
|
/// is not registered.
|
|
const DialectInterface *
|
|
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
|
|
return getInterfaceFor(op->getDialect());
|
|
}
|