[mlir] Add support for "promised" interfaces

Promised interfaces allow for a dialect to "promise" the implementation of an interface, i.e.
declare that it supports an interface, but have the interface defined in an extension in a library
separate from the dialect itself. A promised interface is powerful in that it alerts the user when
the interface is attempted to be used (e.g. via cast/dyn_cast/etc.) and the implementation has
not yet been provided. This makes the system much more robust against misconfiguration,
and ensures that we do not lose the benefit we currently have of defining the interface in
the dialect library.

Differential Revision: https://reviews.llvm.org/D120368
This commit is contained in:
River Riddle
2022-02-22 14:51:37 -08:00
parent fb19fa2f3d
commit a5ef51d786
40 changed files with 411 additions and 75 deletions

View File

@@ -18,6 +18,7 @@
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@@ -34,20 +35,28 @@ namespace fir::support {
mlir::vector::VectorDialect, mlir::math::MathDialect, \
mlir::complex::ComplexDialect, mlir::DLTIDialect
#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect
// The definitive list of dialects used by flang.
#define FLANG_DIALECT_LIST \
FLANG_NONCODEGEN_DIALECT_LIST, FIRCodeGenDialect, mlir::LLVM::LLVMDialect
FLANG_NONCODEGEN_DIALECT_LIST, FLANG_CODEGEN_DIALECT_LIST
inline void registerNonCodegenDialects(mlir::DialectRegistry &registry) {
registry.insert<FLANG_NONCODEGEN_DIALECT_LIST>();
mlir::func::registerInlinerExtension(registry);
}
/// Register all the dialects used by flang.
inline void registerDialects(mlir::DialectRegistry &registry) {
registry.insert<FLANG_DIALECT_LIST>();
registerNonCodegenDialects(registry);
registry.insert<FLANG_CODEGEN_DIALECT_LIST>();
}
inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
mlir::DialectRegistry registry;
registerNonCodegenDialects(registry);
context.appendDialectRegistry(registry);
context.loadDialect<FLANG_NONCODEGEN_DIALECT_LIST>();
}
@@ -55,6 +64,10 @@ inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
/// pass, but a producer of FIR and MLIR. It is therefore a requirement that the
/// dialects be preloaded to be able to build the IR.
inline void loadDialects(mlir::MLIRContext &context) {
mlir::DialectRegistry registry;
registerDialects(registry);
context.appendDialectRegistry(registry);
context.loadDialect<FLANG_DIALECT_LIST>();
}

View File

@@ -1,4 +1,5 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_flang_library(flangFrontend
CompilerInstance.cpp
@@ -18,6 +19,7 @@ add_flang_library(flangFrontend
HLFIRDialect
MLIRIR
${dialect_libs}
${extension_libs}
LINK_LIBS
FortranParser
@@ -39,6 +41,7 @@ add_flang_library(flangFrontend
MLIRSCFToControlFlow
MLIRTargetLLVMIRImport
${dialect_libs}
${extension_libs}
LINK_COMPONENTS
Passes

View File

@@ -1,4 +1,5 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_flang_library(FortranLower
Allocatable.cpp
@@ -33,6 +34,7 @@ add_flang_library(FortranLower
FIRTransforms
HLFIRDialect
${dialect_libs}
${extension_libs}
LINK_LIBS
FIRDialect
@@ -42,6 +44,7 @@ add_flang_library(FortranLower
FIRTransforms
HLFIRDialect
${dialect_libs}
${extension_libs}
FortranCommon
FortranParser
FortranEvaluate

View File

@@ -1,4 +1,5 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_flang_library(FIRBuilder
BoxValue.cpp
@@ -31,6 +32,7 @@ add_flang_library(FIRBuilder
FIRDialect
HLFIRDialect
${dialect_libs}
${extension_libs}
LINK_LIBS
FIRDialect
@@ -38,4 +40,5 @@ add_flang_library(FIRBuilder
FIRSupport
HLFIRDialect
${dialect_libs}
${extension_libs}
)

View File

@@ -1,4 +1,5 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_flang_library(FIRSupport
InitFIR.cpp
@@ -9,9 +10,11 @@ add_flang_library(FIRSupport
HLFIROpsIncGen
MLIRIR
${dialect_libs}
${extension_libs}
LINK_LIBS
${dialect_libs}
${extension_libs}
MLIRBuiltinToLLVMIRTranslation
MLIROpenACCToLLVMIRTranslation
MLIROpenMPToLLVMIRTranslation

View File

@@ -10,6 +10,7 @@ FIROptCodeGenPassIncGen
llvm_update_compile_flags(bbc)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(bbc PRIVATE
FIRDialect
FIRDialectSupport
@@ -19,6 +20,7 @@ FIRBuilder
HLFIRDialect
HLFIRTransforms
${dialect_libs}
${extension_libs}
MLIRAffineToStandard
MLIRSCFToControlFlow
FortranCommon

View File

@@ -1,6 +1,7 @@
add_flang_tool(fir-opt fir-opt.cpp)
llvm_update_compile_flags(fir-opt)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
if(FLANG_INCLUDE_TESTS)
set(test_libs
@@ -18,6 +19,7 @@ target_link_libraries(fir-opt PRIVATE
FIRAnalysis
${test_libs}
${dialect_libs}
${extension_libs}
# TODO: these should be transitive dependencies from a target providing
# "registerFIRPasses()"

View File

@@ -5,6 +5,7 @@ set(LLVM_LINK_COMPONENTS
add_flang_tool(tco tco.cpp)
llvm_update_compile_flags(tco)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(tco PRIVATE
FIRCodeGen
FIRDialect
@@ -15,13 +16,13 @@ target_link_libraries(tco PRIVATE
HLFIRDialect
HLFIRTransforms
${dialect_libs}
${extension_libs}
MLIRIR
MLIRLLVMDialect
MLIRBuiltinToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
MLIRTargetLLVMIRExport
MLIRPass
MLIRFuncToLLVM
MLIRTransforms
MLIRAffineToStandard
MLIRAnalysis

View File

@@ -1,4 +1,5 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
set(LIBS
FIRBuilder
@@ -8,6 +9,7 @@ set(LIBS
FIRSupport
HLFIRDialect
${dialect_libs}
${extension_libs}
LLVMTargetParser
)

View File

@@ -588,6 +588,12 @@ function(add_mlir_conversion_library name)
add_mlir_library(${ARGV} DEPENDS mlir-headers)
endfunction(add_mlir_conversion_library)
# Declare the library associated with an extension.
function(add_mlir_extension_library name)
set_property(GLOBAL APPEND PROPERTY MLIR_EXTENSION_LIBS ${name})
add_mlir_library(${ARGV} DEPENDS mlir-headers)
endfunction(add_mlir_extension_library)
# Declare the library associated with a translation.
function(add_mlir_translation_library name)
set_property(GLOBAL APPEND PROPERTY MLIR_TRANSLATION_LIBS ${name})

View File

@@ -24,6 +24,7 @@ export(TARGETS ${MLIR_EXPORTS} FILE ${mlir_cmake_builddir}/MLIRTargets.cmake)
get_property(MLIR_ALL_LIBS GLOBAL PROPERTY MLIR_ALL_LIBS)
get_property(MLIR_DIALECT_LIBS GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(MLIR_CONVERSION_LIBS GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(MLIR_EXTENSION_LIBS GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
get_property(MLIR_TRANSLATION_LIBS GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
# Generate MlirConfig.cmake for the build tree.

View File

@@ -21,6 +21,7 @@ set(MLIR_MAIN_SRC_DIR "@MLIR_MAIN_SRC_DIR@")
set_property(GLOBAL PROPERTY MLIR_ALL_LIBS "@MLIR_ALL_LIBS@")
set_property(GLOBAL PROPERTY MLIR_DIALECT_LIBS "@MLIR_DIALECT_LIBS@")
set_property(GLOBAL PROPERTY MLIR_CONVERSION_LIBS "@MLIR_CONVERSION_LIBS@")
set_property(GLOBAL PROPERTY MLIR_EXTENSION_LIBS "@MLIR_EXTENSION_LIBS@")
set_property(GLOBAL PROPERTY MLIR_TRANSLATION_LIBS "@MLIR_TRANSLATION_LIBS@")
# Provide all our library targets to users.

View File

@@ -28,9 +28,11 @@ add_toy_chapter(toyc-ch5
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch5
PRIVATE
${dialect_libs}
${extension_libs}
MLIRAnalysis
MLIRCallInterfaces
MLIRCastInterfaces

View File

@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "toy/Dialect.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
@@ -107,7 +108,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
mlir::MLIRContext context;
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context(registry);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();

View File

@@ -39,10 +39,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch6
PRIVATE
${dialect_libs}
${conversion_libs}
${extension_libs}
MLIRAnalysis
MLIRBuiltinToLLVMIRTranslation
MLIRCallInterfaces

View File

@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "toy/Dialect.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
@@ -289,8 +290,10 @@ int main(int argc, char **argv) {
return dumpAST();
// If we aren't dumping the AST, then we are compiling with/to MLIR.
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context;
mlir::MLIRContext context(registry);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();

View File

@@ -38,10 +38,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch7
PRIVATE
${dialect_libs}
${conversion_libs}
${extension_libs}
MLIRAnalysis
MLIRBuiltinToLLVMIRTranslation
MLIRCallInterfaces

View File

@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "toy/Dialect.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
@@ -290,8 +291,10 @@ int main(int argc, char **argv) {
return dumpAST();
// If we aren't dumping the AST, then we are compiling with/to MLIR.
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context;
mlir::MLIRContext context(registry);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();

View File

@@ -0,0 +1,30 @@
//===- AllExtensions.h - All Func Extensions --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a common entry point for registering all extensions to the
// func dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H
#define MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H
namespace mlir {
class DialectRegistry;
namespace func {
/// Register all extensions of the func dialect. This should generally only be
/// used by tools, or other use cases that really do want *all* extensions of
/// the dialect. All other cases should prefer to instead register the specific
/// extensions they intend to take advantage of.
void registerAllExtensions(DialectRegistry &registry);
} // namespace func
} // namespace mlir
#endif // MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H

View File

@@ -0,0 +1,27 @@
//===- InlinerExtension.h - Func Inliner Extension 0000----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an extension for the func dialect that implements the
// interfaces necessary to support inlining.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H
#define MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H
namespace mlir {
class DialectRegistry;
namespace func {
/// Register the extension used to support inlining the func dialect.
void registerInlinerExtension(DialectRegistry &registry);
} // namespace func
} // namespace mlir
#endif // MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H

View File

@@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Func_Dialect : Dialect {
let name = "func";
let cppNamespace = "::mlir::func";
let dependentDialects = ["cf::ControlFlowDialect"];
let hasConstantMaterializer = 1;
let usePropertiesForAttributes = 1;
}

View File

@@ -285,6 +285,14 @@ public:
private:
/// Returns the impl interface instance for the given type.
static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
#ifndef NDEBUG
// Check that the current interface isn't an unresolved promise for the
// given attribute.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
attr.getDialect(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif
return attr.getAbstractAttribute().getInterface<ConcreteType>();
}

View File

@@ -159,11 +159,20 @@ public:
/// Lookup an interface for the given ID if one is registered, otherwise
/// nullptr.
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
#ifndef NDEBUG
handleUseOfUndefinedPromisedInterface(interfaceID);
#endif
auto it = registeredInterfaces.find(interfaceID);
return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
}
template <typename InterfaceT>
InterfaceT *getRegisteredInterface() {
#ifndef NDEBUG
handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
llvm::getTypeName<InterfaceT>());
#endif
return static_cast<InterfaceT *>(
getRegisteredInterface(InterfaceT::getInterfaceID()));
}
@@ -196,6 +205,37 @@ public:
return *interface;
}
/// Declare that the given interface will be implemented, but has a delayed
/// registration. The promised interface type can be an interface of any type
/// not just a dialect interface, i.e. it may also be an
/// AttributeInterface/OpInterface/TypeInterface/etc.
template <typename InterfaceT>
void declarePromisedInterface() {
unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
}
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
/// more user readable name for the interface (such as the class name).
void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
StringRef interfaceName = "") {
if (unresolvedPromisedInterfaces.count(interfaceID)) {
llvm::report_fatal_error(
"checking for an interface (`" + interfaceName +
"`) that was promised by dialect '" + getNamespace() +
"' but never implemented. This is generally an indication "
"that the dialect extension implementing the interface was never "
"registered.");
}
}
/// Checks if the given interface, which is attempting to be attached to a
/// construct owned by this dialect, is a promised interface of this dialect
/// that has yet to be implemented. If so, it resolves the interface promise.
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
unresolvedPromisedInterfaces.erase(interfaceID);
}
protected:
/// The constructor takes a unique namespace for this dialect as well as the
/// context to bind to.
@@ -289,6 +329,11 @@ private:
/// A collection of registered dialect interfaces.
DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
/// A set of interfaces that the dialect (or its constructs, i.e.
/// Attributes/Operations/Types/etc.) has promised to implement, but has yet
/// to provide an implementation for.
DenseSet<TypeID> unresolvedPromisedInterfaces;
friend class DialectRegistry;
friend void registerDialect();
friend class MLIRContext;

View File

@@ -97,6 +97,22 @@ protected:
}
};
namespace dialect_extension_detail {
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error.
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
StringRef interfaceName);
/// Checks if the given interface, which is attempting to be attached, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// the promised interface is marked as resolved.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
TypeID interfaceID);
} // namespace dialect_extension_detail
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//

View File

@@ -2083,6 +2083,16 @@ protected:
static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
OperationName name = op->getName();
#ifndef NDEBUG
// Check that the current interface isn't an unresolved promise for the
// given operation.
if (Dialect *dialect = name.getDialect()) {
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
*dialect, ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
}
#endif
// Access the raw interface from the operation info.
if (std::optional<RegisteredOperationName> rInfo =
name.getRegisteredInfo()) {

View File

@@ -18,6 +18,7 @@
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
@@ -348,6 +349,11 @@ public:
/// interfaces for the concrete operation.
template <typename... Models>
void attachInterface() {
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
*getDialect(), Models::Interface::getInterfaceID()),
...);
getImpl()->getInterfaceMap().insertModels<Models...>();
}

View File

@@ -14,6 +14,7 @@
#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/InterfaceSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StorageUniquer.h"
@@ -160,6 +161,11 @@ public:
llvm::report_fatal_error("Registering an interface for an attribute/type "
"that is not itself registered.");
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
...);
(checkInterfaceTarget<IfaceModels>(), ...);
abstract->interfaceMap.template insertModels<IfaceModels...>();
}

View File

@@ -269,6 +269,14 @@ public:
private:
/// Returns the impl interface instance for the given type.
static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
#ifndef NDEBUG
// Check that the current interface isn't an unresolved promise for the
// given type.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
type.getDialect(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif
return type.getAbstractType().getInterface<ConcreteType>();
}

View File

@@ -0,0 +1,34 @@
//===- InitAllExtensions.h - MLIR Extension Registration --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a helper to trigger the registration of all dialect
// extensions to the system.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include <cstdlib>
namespace mlir {
/// This function may be called to register all MLIR dialect extensions with the
/// provided registry.
/// If you're building a compiler, you generally shouldn't use this: you would
/// individually register the specific extensions that are useful for the
/// pipelines and transformations you are using.
inline void registerAllExtensions(DialectRegistry &registry) {
func::registerAllExtensions(registry);
}
} // namespace mlir
#endif // MLIR_INITALLEXTENSIONS_H_

View File

@@ -1,2 +1,3 @@
add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,16 @@
//===- AllExtensions.cpp - All Func Dialect Extensions --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
using namespace mlir;
void mlir::func::registerAllExtensions(DialectRegistry &registry) {
registerInlinerExtension(registry);
}

View File

@@ -0,0 +1,27 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
InlinerExtension.cpp
)
add_mlir_extension_library(MLIRFuncInlinerExtension
InlinerExtension.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
LINK_LIBS PUBLIC
MLIRControlFlowDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRFuncDialect
)
add_mlir_extension_library(MLIRFuncAllExtensions
AllExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
LINK_LIBS PUBLIC
MLIRFuncInlinerExtension
)

View File

@@ -0,0 +1,90 @@
//===- InlinerExtension.cpp - Func Inliner Extension ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::func;
//===----------------------------------------------------------------------===//
// FuncDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with func operations.
struct FuncInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
/// All call operations can be inlined.
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}
/// All operations can be inlined.
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
return true;
}
/// All functions can be inlined.
bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op, Block *newDest) const final {
// Only return needs to be handled here.
auto returnOp = dyn_cast<ReturnOp>(op);
if (!returnOp)
return;
// Replace the return with a branch to the dest.
OpBuilder builder(op);
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
op->erase();
}
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
// Only return needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
void mlir::func::registerInlinerExtension(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
dialect->addInterfaces<FuncInlinerInterface>();
// The inliner extension relies on the ControlFlow dialect.
ctx->getOrLoadDialect<cf::ControlFlowDialect>();
});
}

View File

@@ -9,7 +9,6 @@ add_mlir_dialect_library(MLIRFuncDialect
LINK_LIBS PUBLIC
MLIRCallInterfaces
MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRIR

View File

@@ -8,8 +8,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
@@ -33,67 +31,6 @@
using namespace mlir;
using namespace mlir::func;
//===----------------------------------------------------------------------===//
// FuncDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with func operations.
struct FuncInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
/// All call operations can be inlined.
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}
/// All operations can be inlined.
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
return true;
}
/// All functions can be inlined.
bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op, Block *newDest) const final {
// Only return needs to be handled here.
auto returnOp = dyn_cast<ReturnOp>(op);
if (!returnOp)
return;
// Replace the return with a branch to the dest.
OpBuilder builder(op);
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
op->erase();
}
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
// Only return needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
};
} // namespace
//===----------------------------------------------------------------------===//
// FuncDialect
//===----------------------------------------------------------------------===//
@@ -103,7 +40,7 @@ void FuncDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
>();
addInterfaces<FuncInlinerInterface>();
declarePromisedInterface<DialectInlinerInterface>();
}
/// Materialize a single constant operation from a given attribute value with

View File

@@ -96,6 +96,9 @@ bool Dialect::isValidNamespace(StringRef str) {
/// Register a set of dialect interfaces with this dialect instance.
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
// Handle the case where the models resolve a promised interface.
handleAdditionOfUndefinedPromisedInterface(interface->getID());
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
(void)it;
@@ -143,6 +146,16 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
DialectExtensionBase::~DialectExtensionBase() = default;
void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
}
void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
Dialect &dialect, TypeID interfaceID) {
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
}
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//

View File

@@ -1725,10 +1725,10 @@ int registerOnlyStd(void) {
fprintf(stderr, "@registration\n");
// CHECK-LABEL: @registration
// CHECK: cf.cond_br is_registered: 1
fprintf(stderr, "cf.cond_br is_registered: %d\n",
// CHECK: func.call is_registered: 1
fprintf(stderr, "func.call is_registered: %d\n",
mlirContextIsRegisteredOperation(
ctx, mlirStringRefCreateFromCString("cf.cond_br")));
ctx, mlirStringRefCreateFromCString("func.call")));
// CHECK: func.not_existing_op is_registered: 0
fprintf(stderr, "func.not_existing_op is_registered: %d\n",
@@ -1942,6 +1942,7 @@ int testClone(void) {
registerAllUpstreamDialects(ctx);
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

View File

@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
set(LLVM_LINK_COMPONENTS
Core
Support
@@ -50,7 +51,9 @@ endif()
set(LIBS
${dialect_libs}
${conversion_libs}
${extension_libs}
${test_libs}
MLIRAffineAnalysis
MLIRAnalysis
MLIRCastInterfaces

View File

@@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -259,6 +260,8 @@ int main(int argc, char **argv) {
#endif
DialectRegistry registry;
registerAllDialects(registry);
registerAllExtensions(registry);
#ifdef MLIR_INCLUDE_TESTS
::test::registerTestDialect(registry);
::test::registerTestTransformDialectExtension(registry);

View File

@@ -7,6 +7,7 @@ add_mlir_unittest(MLIRInterfacesTests
target_link_libraries(MLIRInterfacesTests
PRIVATE
MLIRArithDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRDLTIDialect