Files
clang-p2996/mlir/lib/IR/Dialect.cpp
Geoffrey Martin-Noble b72e13c242 [MLIR] Deduplicate dialect registration by ClassID
Summary:
With the move toward dialect registration that does not depend only on
static initialization, we are running into more cases where the same dialects
are registered by different methods. For example, TensorFlow still uses
static initialization to register all MLIR core dialects, which prevents
explicit registration of any of them when linking it in. We ran into this
issue in https://github.com/google/iree/pull/982.

To avoid conflicts from non-standard allocators passed to
registerDialectAllocator, this method has been made private. Dialects can
now only be registered through their constructors.

This change also deduplicates DialectHooks for consistency and makes their
registration follow the same pattern.

Differential Revision: https://reviews.llvm.org/D76329
2020-03-18 19:52:27 -07:00

159 lines
5.8 KiB
C++

//===- Dialect.cpp - Dialect implementation -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
using namespace mlir;
using namespace detail;
// Out-of-line (defaulted) destructor; presumably kept out of line so the
// class's vtable is emitted in this translation unit.
DialectAsmParser::~DialectAsmParser() = default;
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//
/// Both registries are keyed by the dialect's ClassID. Registration uses
/// MapVector::insert, which leaves an existing entry untouched, so each dialect
/// contributes at most one entry per registry. Iteration follows insertion
/// order.
using DialectAllocatorRegistry =
    llvm::MapVector<const ClassID *, DialectAllocatorFunction>;
using DialectHooksSetterRegistry =
    llvm::MapVector<const ClassID *, DialectHooksSetter>;

/// Registry for all dialect allocation functions.
static llvm::ManagedStatic<DialectAllocatorRegistry> dialectRegistry;

/// Registry for functions that set dialect hooks.
static llvm::ManagedStatic<DialectHooksSetterRegistry> dialectHooksRegistry;
/// Record `function` as the allocator for the dialect identified by `classId`.
/// If an allocator is already recorded for that ClassID, the existing entry is
/// kept and this call has no effect (MapVector::insert does not overwrite).
void Dialect::registerDialectAllocator(
    const ClassID *classId, const DialectAllocatorFunction &function) {
  assert(function &&
         "Attempting to register an empty dialect initialize function");
  auto &registry = *dialectRegistry;
  registry.insert({classId, function});
}
/// Record `function` as the hooks setter for the dialect identified by
/// `classId`; typically reached through the DialectHooksRegistration template.
/// A second registration for the same ClassID is ignored, because
/// MapVector::insert keeps the first entry.
void DialectHooks::registerDialectHooksSetter(
    const ClassID *classId, const DialectHooksSetter &function) {
  assert(
      function &&
      "Attempting to register an empty dialect hooks initialization function");
  auto &registry = *dialectHooksRegistry;
  registry.insert({classId, function});
}
/// Replay both global registries into `context`: every registered dialect
/// allocator runs first, then every registered hooks setter, each in
/// registration order.
void mlir::registerAllDialects(MLIRContext *context) {
  // Allocators run before hook setters; presumably the setters expect their
  // dialects to already exist in the context.
  for (const auto &entry : *dialectRegistry) {
    const DialectAllocatorFunction &allocate = entry.second;
    allocate(context);
  }
  for (const auto &entry : *dialectHooksRegistry) {
    const DialectHooksSetter &setHooks = entry.second;
    setHooks(context);
  }
}
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
/// Construct a dialect with namespace `dialectName` attached to `ctx`, then
/// hand it to registerDialect(ctx).
Dialect::Dialect(StringRef dialectName, MLIRContext *ctx)
    : name(dialectName), context(ctx) {
  // The namespace must satisfy isValidNamespace (checked in assert builds).
  assert(isValidNamespace(dialectName) && "invalid dialect namespace");
  registerDialect(ctx);
}
// Defaulted out-of-line destructor; member cleanup is left to the members.
Dialect::~Dialect() = default;
/// Base hook for verifying an attribute of this dialect that is attached to a
/// region argument of the given operation. It may optionally be invoked from
/// any operation containing a region. The default accepts every attribute and
/// always reports success.
LogicalResult Dialect::verifyRegionArgAttribute(Operation * /*op*/, unsigned,
                                                unsigned,
                                                NamedAttribute /*attribute*/) {
  // Nothing to check by default.
  return success();
}
/// Base hook for verifying an attribute of this dialect that is attached to a
/// region result of the given operation. It may optionally be invoked from any
/// operation containing a region. The default accepts every attribute and
/// always reports success.
LogicalResult Dialect::verifyRegionResultAttribute(
    Operation * /*op*/, unsigned, unsigned, NamedAttribute /*attribute*/) {
  // Nothing to check by default.
  return success();
}
/// Fallback attribute parser for dialects without custom attribute syntax. It
/// emits an error at the parser's current name location and returns a
/// default-constructed (null) Attribute. The `type` argument is ignored.
Attribute Dialect::parseAttribute(DialectAsmParser &parser,
                                  Type /*type*/) const {
  parser.emitError(parser.getNameLoc())
      << "dialect '" << getNamespace() << "' provides no attribute parsing hook";
  return Attribute();
}
/// Fallback type parser for dialects without custom type syntax.
///
/// If the dialect allows unknown types, the full symbol spec is wrapped in an
/// OpaqueType tagged with this dialect's namespace. Otherwise an error is
/// emitted at the parser's current name location and a null Type is returned.
Type Dialect::parseType(DialectAsmParser &parser) const {
  if (!allowsUnknownTypes()) {
    parser.emitError(parser.getNameLoc())
        << "dialect '" << getNamespace() << "' provides no type parsing hook";
    return Type();
  }

  auto *ctx = getContext();
  Identifier dialectNamespace = Identifier::get(getNamespace(), ctx);
  return OpaqueType::get(dialectNamespace, parser.getFullSymbolSpec(), ctx);
}
/// Utility function that returns if the given string is a valid dialect
/// namespace.
///
/// The empty string is accepted. Otherwise the namespace must start with an
/// ASCII letter or '_', followed by any number of ASCII letters, digits, '_',
/// or '$'.
///
/// This is checked with plain character comparisons instead of the former
/// regex "^[a-zA-Z_][a-zA-Z_0-9\\$]*$". That pattern meant to escape '$' with
/// a backslash, but inside a POSIX bracket expression the backslash is a
/// literal member of the set. As a result it also (unintentionally) accepted
/// '\' characters. It also recompiled the regex on every call.
bool Dialect::isValidNamespace(StringRef str) {
  if (str.empty())
    return true;

  auto isLeadingChar = [](char c) {
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_';
  };
  auto isTrailingChar = [&isLeadingChar](char c) {
    return isLeadingChar(c) || (c >= '0' && c <= '9') || c == '$';
  };

  if (!isLeadingChar(str.front()))
    return false;
  for (char c : str.drop_front())
    if (!isTrailingChar(c))
      return false;
  return true;
}
/// Attach `interface` to this dialect instance, keyed by its getID(). At most
/// one interface per ID may be registered; a duplicate trips the assertion.
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
  auto interfaceID = interface->getID();
  bool inserted =
      registeredInterfaces.try_emplace(interfaceID, std::move(interface))
          .second;
  // Keep `inserted` "used" when assertions are compiled out.
  (void)inserted;
  assert(inserted && "interface kind has already been registered");
}
//===----------------------------------------------------------------------===//
// Dialect Interface
//===----------------------------------------------------------------------===//
// Defaulted out-of-line destructor for the dialect interface base.
DialectInterface::~DialectInterface() = default;
/// Gather, from every dialect registered in `ctx`, the interface registered
/// under `interfaceKind`. Dialects without one are skipped. Each found
/// interface is added to `interfaces` and appended to `orderedInterfaces` in
/// the order the context reports its dialects.
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
    MLIRContext *ctx, ClassID *interfaceKind) {
  for (auto *registeredDialect : ctx->getRegisteredDialects()) {
    auto *found = registeredDialect->getRegisteredInterface(interfaceKind);
    if (!found)
      continue;
    interfaces.insert(found);
    orderedInterfaces.push_back(found);
  }
}
// Defaulted out-of-line destructor for the interface collection base.
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
/// Return the collected interface belonging to the dialect that owns `op`, or
/// null if that dialect has none registered. Forwards to the Dialect-based
/// overload.
const DialectInterface *
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
  auto *owningDialect = op->getDialect();
  return getInterfaceFor(owningDialect);
}