[mlir][capi] Add external pass creation to MLIR C-API

Adds the ability to create external passes using the C-API. This allows passes
to be written in C or languages that use the C-bindings.

Differential Revision: https://reviews.llvm.org/D121866
This commit is contained in:
Daniel Resnick
2022-03-16 16:31:08 -06:00
parent c69307e5ee
commit 2387fadea3
10 changed files with 557 additions and 30 deletions

View File

@@ -62,7 +62,6 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
DEFINE_C_API_STRUCT(MlirType, const void);
DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirValue, const void);
#undef DEFINE_C_API_STRUCT
@@ -757,19 +756,6 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
/// Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
/// Checks whether a type id is null.
static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
/// Checks if two type ids are equal.
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
/// Returns the hash value of the type id.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//

View File

@@ -15,6 +15,7 @@
#define MLIR_C_PASS_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -41,11 +42,16 @@ extern "C" {
typedef struct name name
DEFINE_C_API_STRUCT(MlirPass, void);
DEFINE_C_API_STRUCT(MlirExternalPass, void);
DEFINE_C_API_STRUCT(MlirPassManager, void);
DEFINE_C_API_STRUCT(MlirOpPassManager, void);
#undef DEFINE_C_API_STRUCT
//===----------------------------------------------------------------------===//
// PassManager/OpPassManager APIs.
//===----------------------------------------------------------------------===//
/// Create a new top-level PassManager.
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
@@ -112,6 +118,55 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager,
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
//===----------------------------------------------------------------------===//
// External Pass API.
//
// This API allows to define passes outside of MLIR, not necessarily in
// C++, and register them with the MLIR pass management infrastructure.
//
//===----------------------------------------------------------------------===//
/// Structure of external `MlirPass` callbacks.
/// All callbacks are required to be set unless otherwise specified.
struct MlirExternalPassCallbacks {
/// This callback is called from the pass is created.
/// This is analogous to a C++ pass constructor.
void (*construct)(void *userData);
/// This callback is called when the pass is destroyed
/// This is analogous to a C++ pass destructor.
void (*destruct)(void *userData);
/// This callback is optional.
/// The callback is called before the pass is run, allowing a chance to
/// initialize any complex state necessary for running the pass.
/// See Pass::initialize(MLIRContext *).
MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
/// This callback is called when the pass is cloned.
/// See Pass::clonePass().
void *(*clone)(void *userData);
/// This callback is called when the pass is run.
/// See Pass::runOnOperation().
void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
};
typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
/// supplied `userData`. If `opName` is empty, the pass is a generic operation
/// pass. Otherwise it is an operation pass specific to the specified pass name.
MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData);
/// This signals that the pass has failed. This is only valid to call during
/// the `run` callback of `MlirExternalPassCallbacks`.
/// See Pass::signalPassFailure().
MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass);
#ifdef __cplusplus
}
#endif

View File

@@ -50,6 +50,17 @@
extern "C" {
#endif
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
#undef DEFINE_C_API_STRUCT
//===----------------------------------------------------------------------===//
// MlirStringRef.
//===----------------------------------------------------------------------===//
@@ -127,6 +138,38 @@ inline static MlirLogicalResult mlirLogicalResultFailure() {
return res;
}
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
/// `ptr` must be 8 byte aligned and unique to a type valid for the duration of
/// the returned type id's usage
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
/// Checks whether a type id is null.
static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
/// Checks if two type ids are equal.
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
/// Returns the hash value of the type id.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
//===----------------------------------------------------------------------===//
// TypeIDAllocator API.
//===----------------------------------------------------------------------===//
/// Creates a type id allocator for dynamic type id creation
MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate();
/// Deallocates the allocator and all allocated type ids
MLIR_CAPI_EXPORTED void
mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator);
/// Allocates a type id that is valid for the lifetime of the allocator
MLIR_CAPI_EXPORTED MlirTypeID
mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator);
#ifdef __cplusplus
}
#endif

View File

@@ -34,7 +34,6 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_METHODS(MlirValue, mlir::Value)
#endif // MLIR_CAPI_IR_H

View File

@@ -16,7 +16,9 @@
#define MLIR_CAPI_SUPPORT_H
#include "mlir-c/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringRef.h"
/// Converts a StringRef into its MLIR C API equivalent.
@@ -39,4 +41,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
return mlir::success(mlirLogicalResultIsSuccess(res));
}
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
#endif // MLIR_CAPI_SUPPORT_H

View File

@@ -787,18 +787,6 @@ MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
return wrap(unwrap(ident).strref());
}
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
return unwrap(typeID1) == unwrap(typeID2);
}
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
return hash_value(unwrap(typeID));
}
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//

View File

@@ -77,3 +77,94 @@ MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
// stream and redirect to a diagnostic.
return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
}
//===----------------------------------------------------------------------===//
// External Pass API.
//===----------------------------------------------------------------------===//
namespace mlir {
class ExternalPass;
} // namespace mlir
DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
namespace mlir {
/// This pass class wraps external passes defined in other languages using the
/// MLIR C-interface
class ExternalPass : public Pass {
public:
ExternalPass(TypeID passID, StringRef name, StringRef argument,
StringRef description, Optional<StringRef> opName,
ArrayRef<MlirDialectHandle> dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData)
: Pass(passID, opName), id(passID), name(name), argument(argument),
description(description), dependentDialects(dependentDialects),
callbacks(callbacks), userData(userData) {
callbacks.construct(userData);
}
~ExternalPass() override { callbacks.destruct(userData); }
StringRef getName() const override { return name; }
StringRef getArgument() const override { return argument; }
StringRef getDescription() const override { return description; }
void getDependentDialects(DialectRegistry &registry) const override {
MlirDialectRegistry cRegistry = wrap(&registry);
for (MlirDialectHandle dialect : dependentDialects)
mlirDialectHandleInsertDialect(dialect, cRegistry);
}
void signalPassFailure() { Pass::signalPassFailure(); }
protected:
LogicalResult initialize(MLIRContext *ctx) override {
if (callbacks.initialize)
return unwrap(callbacks.initialize(wrap(ctx), userData));
return success();
}
bool canScheduleOn(RegisteredOperationName opName) const override {
if (Optional<StringRef> specifiedOpName = getOpName())
return opName.getStringRef() == specifiedOpName;
return true;
}
void runOnOperation() override {
callbacks.run(wrap(getOperation()), wrap(this), userData);
}
std::unique_ptr<Pass> clonePass() const override {
void *clonedUserData = callbacks.clone(userData);
return std::make_unique<ExternalPass>(id, name, argument, description,
getOpName(), dependentDialects,
callbacks, clonedUserData);
}
private:
TypeID id;
std::string name;
std::string argument;
std::string description;
std::vector<MlirDialectHandle> dependentDialects;
MlirExternalPassCallbacks callbacks;
void *userData;
};
} // namespace mlir
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects,
MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks,
void *userData) {
return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
{dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
userData)));
}
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
unwrap(pass)->signalPassFailure();
}

View File

@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Support.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/StringRef.h"
#include <cstring>
@@ -19,3 +19,40 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
return llvm::StringRef(string.data, string.length) ==
llvm::StringRef(other.data, other.length);
}
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
MlirTypeID mlirTypeIDCreate(const void *ptr) {
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
"ptr must be 8 byte aligned");
// This is essentially a no-op that returns back `ptr`, but by going through
// the `TypeID` functions we can get compiler errors in case the `TypeID`
// api/representation changes
return wrap(mlir::TypeID::getFromOpaquePointer(ptr));
}
bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
return unwrap(typeID1) == unwrap(typeID2);
}
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
return hash_value(unwrap(typeID));
}
//===----------------------------------------------------------------------===//
// TypeIDAllocator API.
//===----------------------------------------------------------------------===//
MlirTypeIDAllocator mlirTypeIDAllocatorCreate() {
return wrap(new mlir::TypeIDAllocator());
}
void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) {
delete unwrap(allocator);
}
MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) {
return wrap(unwrap(allocator)->allocate());
}

View File

@@ -49,6 +49,7 @@ _add_capi_test_executable(mlir-capi-llvm-test
_add_capi_test_executable(mlir-capi-pass-test
pass.c
LINK_LIBS PRIVATE
MLIRCAPIFunc
MLIRCAPIIR
MLIRCAPIRegistration
MLIRCAPITransforms

View File

@@ -11,6 +11,7 @@
*/
#include "mlir-c/Pass.h"
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Transforms.h"
@@ -169,7 +170,9 @@ void testParsePassPipeline() {
" func.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsSuccess(status)) {
fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
fprintf(
stderr,
"Unexpected success parsing pipeline without registering the pass\n");
exit(EXIT_FAILURE);
}
// Try again after registrating the pass.
@@ -180,7 +183,8 @@ void testParsePassPipeline() {
" func.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
fprintf(stderr,
"Unexpected failure parsing pipeline after registering the pass\n");
exit(EXIT_FAILURE);
}
@@ -194,10 +198,328 @@ void testParsePassPipeline() {
mlirContextDestroy(ctx);
}
struct TestExternalPassUserData {
int constructCallCount;
int destructCallCount;
int initializeCallCount;
int cloneCallCount;
int runCallCount;
};
typedef struct TestExternalPassUserData TestExternalPassUserData;
void testConstructExternalPass(void *userData) {
++((TestExternalPassUserData *)userData)->constructCallCount;
}
void testDestructExternalPass(void *userData) {
++((TestExternalPassUserData *)userData)->destructCallCount;
}
MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
++((TestExternalPassUserData *)userData)->initializeCallCount;
return mlirLogicalResultSuccess();
}
MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
void *userData) {
++((TestExternalPassUserData *)userData)->initializeCallCount;
return mlirLogicalResultFailure();
}
void *testCloneExternalPass(void *userData) {
++((TestExternalPassUserData *)userData)->cloneCallCount;
return userData;
}
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
void *userData) {
++((TestExternalPassUserData *)userData)->runCallCount;
}
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
void *userData) {
++((TestExternalPassUserData *)userData)->runCallCount;
MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
if (!mlirStringRefEqual(opName,
mlirStringRefCreateFromCString("func.func"))) {
mlirExternalPassSignalFailure(pass);
}
}
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
void *userData) {
++((TestExternalPassUserData *)userData)->runCallCount;
mlirExternalPassSignalFailure(pass);
}
MlirExternalPassCallbacks makeTestExternalPassCallbacks(
MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
return (MlirExternalPassCallbacks){testConstructExternalPass,
testDestructExternalPass, initializePass,
testCloneExternalPass, runPass};
}
void testExternalPass() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
"func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
}
MlirStringRef description = mlirStringRefCreateFromCString("");
MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
// Run a generic pass
{
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
MlirStringRef argument =
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};
MlirPass externalPass = mlirCreateExternalPass(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
}
if (userData.runCallCount != 1) {
fprintf(stderr, "Expected runCallCount to be 1\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
if (userData.destructCallCount != userData.constructCallCount) {
fprintf(stderr, "Expected destructCallCount to be equal to "
"constructCallCount\n");
exit(EXIT_FAILURE);
}
}
// Run a func operation pass
{
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
MlirStringRef argument =
mlirStringRefCreateFromCString("test-external-func-pass");
TestExternalPassUserData userData = {0};
MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
MlirPass externalPass = mlirCreateExternalPass(
passID, name, argument, description, funcOpName, 1, &funcHandle,
makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
&userData);
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
exit(EXIT_FAILURE);
}
// Since this is a nested pass, it can be cloned and run in parallel
if (userData.cloneCallCount != userData.constructCallCount - 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
// The pass should only be run once this there is only one func op
if (userData.runCallCount != 1) {
fprintf(stderr, "Expected runCallCount to be 1\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
if (userData.destructCallCount != userData.constructCallCount) {
fprintf(stderr, "Expected destructCallCount to be equal to "
"constructCallCount\n");
exit(EXIT_FAILURE);
}
}
// Run a pass with `initialize` set
{
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
MlirStringRef argument =
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};
MlirPass externalPass = mlirCreateExternalPass(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeExternalPass,
testRunExternalPass),
&userData);
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
}
if (userData.initializeCallCount != 1) {
fprintf(stderr, "Expected initializeCallCount to be 1\n");
exit(EXIT_FAILURE);
}
if (userData.runCallCount != 1) {
fprintf(stderr, "Expected runCallCount to be 1\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
if (userData.destructCallCount != userData.constructCallCount) {
fprintf(stderr, "Expected destructCallCount to be equal to "
"constructCallCount\n");
exit(EXIT_FAILURE);
}
}
// Run a pass that fails during `initialize`
{
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
MlirStringRef name =
mlirStringRefCreateFromCString("TestExternalFailingPass");
MlirStringRef argument =
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};
MlirPass externalPass = mlirCreateExternalPass(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
testRunExternalPass),
&userData);
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
"Expected failure running pass manager on failing external pass.\n");
exit(EXIT_FAILURE);
}
if (userData.initializeCallCount != 1) {
fprintf(stderr, "Expected initializeCallCount to be 1\n");
exit(EXIT_FAILURE);
}
if (userData.runCallCount != 0) {
fprintf(stderr, "Expected runCallCount to be 0\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
if (userData.destructCallCount != userData.constructCallCount) {
fprintf(stderr, "Expected destructCallCount to be equal to "
"constructCallCount\n");
exit(EXIT_FAILURE);
}
}
// Run a pass that fails during `run`
{
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
MlirStringRef name =
mlirStringRefCreateFromCString("TestExternalFailingPass");
MlirStringRef argument =
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};
MlirPass externalPass = mlirCreateExternalPass(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
&userData);
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
"Expected failure running pass manager on failing external pass.\n");
exit(EXIT_FAILURE);
}
if (userData.runCallCount != 1) {
fprintf(stderr, "Expected runCallCount to be 1\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
if (userData.destructCallCount != userData.constructCallCount) {
fprintf(stderr, "Expected destructCallCount to be equal to "
"constructCallCount\n");
exit(EXIT_FAILURE);
}
}
mlirTypeIDAllocatorDestroy(typeIDAllocator);
mlirContextDestroy(ctx);
}
int main() {
testRunPassOnModule();
testRunPassOnNestedModule();
testPrintPassPipeline();
testParsePassPipeline();
testExternalPass();
return 0;
}