[mlir][capi] Add external pass creation to MLIR C-API
Adds the ability to create external passes using the C-API. This allows passes to be written in C or languages that use the C-bindings. Differential Revision: https://reviews.llvm.org/D121866
This commit is contained in:
@@ -62,7 +62,6 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
|
||||
DEFINE_C_API_STRUCT(MlirLocation, const void);
|
||||
DEFINE_C_API_STRUCT(MlirModule, const void);
|
||||
DEFINE_C_API_STRUCT(MlirType, const void);
|
||||
DEFINE_C_API_STRUCT(MlirTypeID, const void);
|
||||
DEFINE_C_API_STRUCT(MlirValue, const void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
@@ -757,19 +756,6 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
|
||||
/// Gets the string value of the identifier.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeID API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether a type id is null.
|
||||
static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
|
||||
|
||||
/// Checks if two type ids are equal.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
|
||||
|
||||
/// Returns the hash value of the type id.
|
||||
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbol and SymbolTable API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#define MLIR_C_PASS_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -41,11 +42,16 @@ extern "C" {
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(MlirPass, void);
|
||||
DEFINE_C_API_STRUCT(MlirExternalPass, void);
|
||||
DEFINE_C_API_STRUCT(MlirPassManager, void);
|
||||
DEFINE_C_API_STRUCT(MlirOpPassManager, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassManager/OpPassManager APIs.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create a new top-level PassManager.
|
||||
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
|
||||
|
||||
@@ -112,6 +118,55 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager,
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult
|
||||
mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// External Pass API.
|
||||
//
|
||||
// This API allows to define passes outside of MLIR, not necessarily in
|
||||
// C++, and register them with the MLIR pass management infrastructure.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Structure of external `MlirPass` callbacks.
|
||||
/// All callbacks are required to be set unless otherwise specified.
|
||||
struct MlirExternalPassCallbacks {
|
||||
/// This callback is called from the pass is created.
|
||||
/// This is analogous to a C++ pass constructor.
|
||||
void (*construct)(void *userData);
|
||||
|
||||
/// This callback is called when the pass is destroyed
|
||||
/// This is analogous to a C++ pass destructor.
|
||||
void (*destruct)(void *userData);
|
||||
|
||||
/// This callback is optional.
|
||||
/// The callback is called before the pass is run, allowing a chance to
|
||||
/// initialize any complex state necessary for running the pass.
|
||||
/// See Pass::initialize(MLIRContext *).
|
||||
MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
|
||||
|
||||
/// This callback is called when the pass is cloned.
|
||||
/// See Pass::clonePass().
|
||||
void *(*clone)(void *userData);
|
||||
|
||||
/// This callback is called when the pass is run.
|
||||
/// See Pass::runOnOperation().
|
||||
void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
|
||||
};
|
||||
typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
|
||||
|
||||
/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
|
||||
/// supplied `userData`. If `opName` is empty, the pass is a generic operation
|
||||
/// pass. Otherwise it is an operation pass specific to the specified pass name.
|
||||
MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
|
||||
MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
|
||||
MlirStringRef description, MlirStringRef opName,
|
||||
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
|
||||
MlirExternalPassCallbacks callbacks, void *userData);
|
||||
|
||||
/// This signals that the pass has failed. This is only valid to call during
|
||||
/// the `run` callback of `MlirExternalPassCallbacks`.
|
||||
/// See Pass::signalPassFailure().
|
||||
MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -50,6 +50,17 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(MlirTypeID, const void);
|
||||
DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MlirStringRef.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -127,6 +138,38 @@ inline static MlirLogicalResult mlirLogicalResultFailure() {
|
||||
return res;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeID API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// `ptr` must be 8 byte aligned and unique to a type valid for the duration of
|
||||
/// the returned type id's usage
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
|
||||
|
||||
/// Checks whether a type id is null.
|
||||
static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
|
||||
|
||||
/// Checks if two type ids are equal.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
|
||||
|
||||
/// Returns the hash value of the type id.
|
||||
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeIDAllocator API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates a type id allocator for dynamic type id creation
|
||||
MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate();
|
||||
|
||||
/// Deallocates the allocator and all allocated type ids
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator);
|
||||
|
||||
/// Allocates a type id that is valid for the lifetime of the allocator
|
||||
MLIR_CAPI_EXPORTED MlirTypeID
|
||||
mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -34,7 +34,6 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
|
||||
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
|
||||
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
|
||||
DEFINE_C_API_METHODS(MlirType, mlir::Type)
|
||||
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
|
||||
DEFINE_C_API_METHODS(MlirValue, mlir::Value)
|
||||
|
||||
#endif // MLIR_CAPI_IR_H
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
#define MLIR_CAPI_SUPPORT_H
|
||||
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
/// Converts a StringRef into its MLIR C API equivalent.
|
||||
@@ -39,4 +41,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
|
||||
return mlir::success(mlirLogicalResultIsSuccess(res));
|
||||
}
|
||||
|
||||
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
|
||||
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
|
||||
|
||||
#endif // MLIR_CAPI_SUPPORT_H
|
||||
|
||||
@@ -787,18 +787,6 @@ MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
|
||||
return wrap(unwrap(ident).strref());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeID API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
|
||||
return unwrap(typeID1) == unwrap(typeID2);
|
||||
}
|
||||
|
||||
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
|
||||
return hash_value(unwrap(typeID));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbol and SymbolTable API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -77,3 +77,94 @@ MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
|
||||
// stream and redirect to a diagnostic.
|
||||
return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// External Pass API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
class ExternalPass;
|
||||
} // namespace mlir
|
||||
DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
|
||||
|
||||
namespace mlir {
|
||||
/// This pass class wraps external passes defined in other languages using the
|
||||
/// MLIR C-interface
|
||||
class ExternalPass : public Pass {
|
||||
public:
|
||||
ExternalPass(TypeID passID, StringRef name, StringRef argument,
|
||||
StringRef description, Optional<StringRef> opName,
|
||||
ArrayRef<MlirDialectHandle> dependentDialects,
|
||||
MlirExternalPassCallbacks callbacks, void *userData)
|
||||
: Pass(passID, opName), id(passID), name(name), argument(argument),
|
||||
description(description), dependentDialects(dependentDialects),
|
||||
callbacks(callbacks), userData(userData) {
|
||||
callbacks.construct(userData);
|
||||
}
|
||||
|
||||
~ExternalPass() override { callbacks.destruct(userData); }
|
||||
|
||||
StringRef getName() const override { return name; }
|
||||
StringRef getArgument() const override { return argument; }
|
||||
StringRef getDescription() const override { return description; }
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
MlirDialectRegistry cRegistry = wrap(®istry);
|
||||
for (MlirDialectHandle dialect : dependentDialects)
|
||||
mlirDialectHandleInsertDialect(dialect, cRegistry);
|
||||
}
|
||||
|
||||
void signalPassFailure() { Pass::signalPassFailure(); }
|
||||
|
||||
protected:
|
||||
LogicalResult initialize(MLIRContext *ctx) override {
|
||||
if (callbacks.initialize)
|
||||
return unwrap(callbacks.initialize(wrap(ctx), userData));
|
||||
return success();
|
||||
}
|
||||
|
||||
bool canScheduleOn(RegisteredOperationName opName) const override {
|
||||
if (Optional<StringRef> specifiedOpName = getOpName())
|
||||
return opName.getStringRef() == specifiedOpName;
|
||||
return true;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
callbacks.run(wrap(getOperation()), wrap(this), userData);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> clonePass() const override {
|
||||
void *clonedUserData = callbacks.clone(userData);
|
||||
return std::make_unique<ExternalPass>(id, name, argument, description,
|
||||
getOpName(), dependentDialects,
|
||||
callbacks, clonedUserData);
|
||||
}
|
||||
|
||||
private:
|
||||
TypeID id;
|
||||
std::string name;
|
||||
std::string argument;
|
||||
std::string description;
|
||||
std::vector<MlirDialectHandle> dependentDialects;
|
||||
MlirExternalPassCallbacks callbacks;
|
||||
void *userData;
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
|
||||
MlirStringRef argument,
|
||||
MlirStringRef description, MlirStringRef opName,
|
||||
intptr_t nDependentDialects,
|
||||
MlirDialectHandle *dependentDialects,
|
||||
MlirExternalPassCallbacks callbacks,
|
||||
void *userData) {
|
||||
return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
|
||||
unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
|
||||
opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
|
||||
{dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
|
||||
userData)));
|
||||
}
|
||||
|
||||
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
|
||||
unwrap(pass)->signalPassFailure();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstring>
|
||||
@@ -19,3 +19,40 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
|
||||
return llvm::StringRef(string.data, string.length) ==
|
||||
llvm::StringRef(other.data, other.length);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeID API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirTypeIDCreate(const void *ptr) {
|
||||
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
|
||||
"ptr must be 8 byte aligned");
|
||||
// This is essentially a no-op that returns back `ptr`, but by going through
|
||||
// the `TypeID` functions we can get compiler errors in case the `TypeID`
|
||||
// api/representation changes
|
||||
return wrap(mlir::TypeID::getFromOpaquePointer(ptr));
|
||||
}
|
||||
|
||||
bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
|
||||
return unwrap(typeID1) == unwrap(typeID2);
|
||||
}
|
||||
|
||||
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
|
||||
return hash_value(unwrap(typeID));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeIDAllocator API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeIDAllocator mlirTypeIDAllocatorCreate() {
|
||||
return wrap(new mlir::TypeIDAllocator());
|
||||
}
|
||||
|
||||
void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) {
|
||||
delete unwrap(allocator);
|
||||
}
|
||||
|
||||
MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) {
|
||||
return wrap(unwrap(allocator)->allocate());
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ _add_capi_test_executable(mlir-capi-llvm-test
|
||||
_add_capi_test_executable(mlir-capi-pass-test
|
||||
pass.c
|
||||
LINK_LIBS PRIVATE
|
||||
MLIRCAPIFunc
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIRegistration
|
||||
MLIRCAPITransforms
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
*/
|
||||
|
||||
#include "mlir-c/Pass.h"
|
||||
#include "mlir-c/Dialect/Func.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir-c/Transforms.h"
|
||||
@@ -169,7 +170,9 @@ void testParsePassPipeline() {
|
||||
" func.func(print-op-stats))"));
|
||||
// Expect a failure, we haven't registered the print-op-stats pass yet.
|
||||
if (mlirLogicalResultIsSuccess(status)) {
|
||||
fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
|
||||
fprintf(
|
||||
stderr,
|
||||
"Unexpected success parsing pipeline without registering the pass\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
// Try again after registrating the pass.
|
||||
@@ -180,7 +183,8 @@ void testParsePassPipeline() {
|
||||
" func.func(print-op-stats))"));
|
||||
// Expect a failure, we haven't registered the print-op-stats pass yet.
|
||||
if (mlirLogicalResultIsFailure(status)) {
|
||||
fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
|
||||
fprintf(stderr,
|
||||
"Unexpected failure parsing pipeline after registering the pass\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@@ -194,10 +198,328 @@ void testParsePassPipeline() {
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
struct TestExternalPassUserData {
|
||||
int constructCallCount;
|
||||
int destructCallCount;
|
||||
int initializeCallCount;
|
||||
int cloneCallCount;
|
||||
int runCallCount;
|
||||
};
|
||||
typedef struct TestExternalPassUserData TestExternalPassUserData;
|
||||
|
||||
void testConstructExternalPass(void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->constructCallCount;
|
||||
}
|
||||
|
||||
void testDestructExternalPass(void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->destructCallCount;
|
||||
}
|
||||
|
||||
MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->initializeCallCount;
|
||||
return mlirLogicalResultSuccess();
|
||||
}
|
||||
|
||||
MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
|
||||
void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->initializeCallCount;
|
||||
return mlirLogicalResultFailure();
|
||||
}
|
||||
|
||||
void *testCloneExternalPass(void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->cloneCallCount;
|
||||
return userData;
|
||||
}
|
||||
|
||||
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
|
||||
void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->runCallCount;
|
||||
}
|
||||
|
||||
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
|
||||
void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->runCallCount;
|
||||
MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
|
||||
if (!mlirStringRefEqual(opName,
|
||||
mlirStringRefCreateFromCString("func.func"))) {
|
||||
mlirExternalPassSignalFailure(pass);
|
||||
}
|
||||
}
|
||||
|
||||
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
|
||||
void *userData) {
|
||||
++((TestExternalPassUserData *)userData)->runCallCount;
|
||||
mlirExternalPassSignalFailure(pass);
|
||||
}
|
||||
|
||||
MlirExternalPassCallbacks makeTestExternalPassCallbacks(
|
||||
MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
|
||||
void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
|
||||
return (MlirExternalPassCallbacks){testConstructExternalPass,
|
||||
testDestructExternalPass, initializePass,
|
||||
testCloneExternalPass, runPass};
|
||||
}
|
||||
|
||||
void testExternalPass() {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
mlirRegisterAllDialects(ctx);
|
||||
|
||||
MlirModule module = mlirModuleCreateParse(
|
||||
ctx,
|
||||
// clang-format off
|
||||
mlirStringRefCreateFromCString(
|
||||
"func @foo(%arg0 : i32) -> i32 { \n"
|
||||
" %res = arith.addi %arg0, %arg0 : i32 \n"
|
||||
" return %res : i32 \n"
|
||||
"}"));
|
||||
// clang-format on
|
||||
if (mlirModuleIsNull(module)) {
|
||||
fprintf(stderr, "Unexpected failure parsing module.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirStringRef description = mlirStringRefCreateFromCString("");
|
||||
MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
|
||||
|
||||
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
|
||||
|
||||
// Run a generic pass
|
||||
{
|
||||
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
|
||||
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
|
||||
MlirStringRef argument =
|
||||
mlirStringRefCreateFromCString("test-external-pass");
|
||||
TestExternalPassUserData userData = {0};
|
||||
|
||||
MlirPass externalPass = mlirCreateExternalPass(
|
||||
passID, name, argument, description, emptyOpName, 0, NULL,
|
||||
makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
|
||||
|
||||
if (userData.constructCallCount != 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.runCallCount != 1) {
|
||||
fprintf(stderr, "Expected runCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
mlirPassManagerDestroy(pm);
|
||||
|
||||
if (userData.destructCallCount != userData.constructCallCount) {
|
||||
fprintf(stderr, "Expected destructCallCount to be equal to "
|
||||
"constructCallCount\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
// Run a func operation pass
|
||||
{
|
||||
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
|
||||
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
|
||||
MlirStringRef argument =
|
||||
mlirStringRefCreateFromCString("test-external-func-pass");
|
||||
TestExternalPassUserData userData = {0};
|
||||
MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
|
||||
MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
|
||||
|
||||
MlirPass externalPass = mlirCreateExternalPass(
|
||||
passID, name, argument, description, funcOpName, 1, &funcHandle,
|
||||
makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
|
||||
&userData);
|
||||
|
||||
if (userData.constructCallCount != 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
MlirOpPassManager nestedFuncPm =
|
||||
mlirPassManagerGetNestedUnder(pm, funcOpName);
|
||||
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external operation pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Since this is a nested pass, it can be cloned and run in parallel
|
||||
if (userData.cloneCallCount != userData.constructCallCount - 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// The pass should only be run once this there is only one func op
|
||||
if (userData.runCallCount != 1) {
|
||||
fprintf(stderr, "Expected runCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
mlirPassManagerDestroy(pm);
|
||||
|
||||
if (userData.destructCallCount != userData.constructCallCount) {
|
||||
fprintf(stderr, "Expected destructCallCount to be equal to "
|
||||
"constructCallCount\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
// Run a pass with `initialize` set
|
||||
{
|
||||
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
|
||||
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
|
||||
MlirStringRef argument =
|
||||
mlirStringRefCreateFromCString("test-external-pass");
|
||||
TestExternalPassUserData userData = {0};
|
||||
|
||||
MlirPass externalPass = mlirCreateExternalPass(
|
||||
passID, name, argument, description, emptyOpName, 0, NULL,
|
||||
makeTestExternalPassCallbacks(testInitializeExternalPass,
|
||||
testRunExternalPass),
|
||||
&userData);
|
||||
|
||||
if (userData.constructCallCount != 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.initializeCallCount != 1) {
|
||||
fprintf(stderr, "Expected initializeCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.runCallCount != 1) {
|
||||
fprintf(stderr, "Expected runCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
mlirPassManagerDestroy(pm);
|
||||
|
||||
if (userData.destructCallCount != userData.constructCallCount) {
|
||||
fprintf(stderr, "Expected destructCallCount to be equal to "
|
||||
"constructCallCount\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
// Run a pass that fails during `initialize`
|
||||
{
|
||||
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
|
||||
MlirStringRef name =
|
||||
mlirStringRefCreateFromCString("TestExternalFailingPass");
|
||||
MlirStringRef argument =
|
||||
mlirStringRefCreateFromCString("test-external-failing-pass");
|
||||
TestExternalPassUserData userData = {0};
|
||||
|
||||
MlirPass externalPass = mlirCreateExternalPass(
|
||||
passID, name, argument, description, emptyOpName, 0, NULL,
|
||||
makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
|
||||
testRunExternalPass),
|
||||
&userData);
|
||||
|
||||
if (userData.constructCallCount != 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
if (mlirLogicalResultIsSuccess(success)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"Expected failure running pass manager on failing external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.initializeCallCount != 1) {
|
||||
fprintf(stderr, "Expected initializeCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.runCallCount != 0) {
|
||||
fprintf(stderr, "Expected runCallCount to be 0\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
mlirPassManagerDestroy(pm);
|
||||
|
||||
if (userData.destructCallCount != userData.constructCallCount) {
|
||||
fprintf(stderr, "Expected destructCallCount to be equal to "
|
||||
"constructCallCount\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
// Run a pass that fails during `run`
|
||||
{
|
||||
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
|
||||
MlirStringRef name =
|
||||
mlirStringRefCreateFromCString("TestExternalFailingPass");
|
||||
MlirStringRef argument =
|
||||
mlirStringRefCreateFromCString("test-external-failing-pass");
|
||||
TestExternalPassUserData userData = {0};
|
||||
|
||||
MlirPass externalPass = mlirCreateExternalPass(
|
||||
passID, name, argument, description, emptyOpName, 0, NULL,
|
||||
makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
|
||||
&userData);
|
||||
|
||||
if (userData.constructCallCount != 1) {
|
||||
fprintf(stderr, "Expected constructCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
if (mlirLogicalResultIsSuccess(success)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"Expected failure running pass manager on failing external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (userData.runCallCount != 1) {
|
||||
fprintf(stderr, "Expected runCallCount to be 1\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
mlirPassManagerDestroy(pm);
|
||||
|
||||
if (userData.destructCallCount != userData.constructCallCount) {
|
||||
fprintf(stderr, "Expected destructCallCount to be equal to "
|
||||
"constructCallCount\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
mlirTypeIDAllocatorDestroy(typeIDAllocator);
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
int main() {
|
||||
testRunPassOnModule();
|
||||
testRunPassOnNestedModule();
|
||||
testPrintPassPipeline();
|
||||
testParsePassPipeline();
|
||||
testExternalPass();
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user