When using the `enable_ir_printing` API from Python, it invokes IR printing with default args, printing the IR before each pass and printing the IR after each pass only if there have been changes. This PR attempts to align the `enable_ir_printing` API with the documentation.
204 lines
7.7 KiB
C++
204 lines
7.7 KiB
C++
//===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Pass.h"
|
|
|
|
#include "mlir/CAPI/IR.h"
|
|
#include "mlir/CAPI/Pass.h"
|
|
#include "mlir/CAPI/Support.h"
|
|
#include "mlir/CAPI/Utils.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PassManager/OpPassManager APIs.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Heap-allocates a PassManager for the context behind `ctx` and returns it as
/// a C handle. The object is created with `new`; mlirPassManagerDestroy in
/// this file is the matching `delete`.
MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  auto *manager = new PassManager(context);
  return wrap(manager);
}
|
|
|
|
/// Heap-allocates a PassManager for the context behind `ctx`, forwarding
/// `anchorOp` as the second constructor argument, and returns it as a C
/// handle. The object is created with `new`; mlirPassManagerDestroy in this
/// file is the matching `delete`.
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
                                                 MlirStringRef anchorOp) {
  MLIRContext *context = unwrap(ctx);
  StringRef anchorName = unwrap(anchorOp);
  auto *manager = new PassManager(context, anchorName);
  return wrap(manager);
}
|
|
|
|
/// Deletes the PassManager object behind `passManager`. The handle must not be
/// used after this call, since the underlying object has been freed.
void mlirPassManagerDestroy(MlirPassManager passManager) {
  PassManager *owned = unwrap(passManager);
  delete owned;
}
|
|
|
|
/// Returns an MlirOpPassManager handle for the same object as `passManager`,
/// obtained by casting the underlying pointer to `OpPassManager *`. No new
/// object is created.
MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
  PassManager *manager = unwrap(passManager);
  return wrap(static_cast<OpPassManager *>(manager));
}
|
|
|
|
/// Calls `run` on the pass manager with the operation behind `op` and returns
/// the resulting status converted to its C representation.
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
                                         MlirOperation op) {
  PassManager *manager = unwrap(passManager);
  Operation *target = unwrap(op);
  LogicalResult status = manager->run(target);
  return wrap(status);
}
|
|
|
|
/// Turns on IR printing for `passManager`.
///
/// `printBeforeAll` and `printAfterAll` are turned into filter callbacks that
/// return the same answer for every pass/operation pair. The remaining three
/// flags are forwarded unchanged to `PassManager::enableIRPrinting`.
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
                                     bool printBeforeAll, bool printAfterAll,
                                     bool printModuleScope,
                                     bool printAfterOnlyOnChange,
                                     bool printAfterOnlyOnFailure) {
  // Both filters ignore their arguments: the decision is global, not per pass.
  auto beforeFilter = [printBeforeAll](Pass *, Operation *) {
    return printBeforeAll;
  };
  auto afterFilter = [printAfterAll](Pass *, Operation *) {
    return printAfterAll;
  };

  PassManager *manager = unwrap(passManager);
  manager->enableIRPrinting(beforeFilter, afterFilter, printModuleScope,
                            printAfterOnlyOnChange, printAfterOnlyOnFailure);
}
|
|
|
|
/// Forwards `enable` to `PassManager::enableVerifier` on the pass manager
/// behind `passManager`.
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
  PassManager *manager = unwrap(passManager);
  manager->enableVerifier(enable);
}
|
|
|
|
/// Calls `nest(operationName)` on the pass manager and returns a handle to the
/// OpPassManager that call yields. The handle wraps the address of the
/// returned reference; nothing is allocated here.
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
                                                MlirStringRef operationName) {
  StringRef nestedOpName = unwrap(operationName);
  OpPassManager &nested = unwrap(passManager)->nest(nestedOpName);
  return wrap(&nested);
}
|
|
|
|
/// Calls `nest(operationName)` on the op pass manager and returns a handle to
/// the OpPassManager that call yields. The handle wraps the address of the
/// returned reference; nothing is allocated here.
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
                                                  MlirStringRef operationName) {
  StringRef nestedOpName = unwrap(operationName);
  OpPassManager &nested = unwrap(passManager)->nest(nestedOpName);
  return wrap(&nested);
}
|
|
|
|
/// Appends `pass` to `passManager`. The raw pass pointer is placed in a
/// `std::unique_ptr` and handed to `addPass`, so ownership leaves the caller.
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
  std::unique_ptr<Pass> owned(unwrap(pass));
  unwrap(passManager)->addPass(std::move(owned));
}
|
|
|
|
/// Appends `pass` to the op pass manager. The raw pass pointer is placed in a
/// `std::unique_ptr` and handed to `addPass`, so ownership leaves the caller.
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
                                   MlirPass pass) {
  std::unique_ptr<Pass> owned(unwrap(pass));
  unwrap(passManager)->addPass(std::move(owned));
}
|
|
|
|
/// Parses `pipelineElements` into the op pass manager behind `passManager`.
///
/// Text that the parser writes to its output stream is routed to `callback`
/// together with `userData`. Returns the parser's status as an
/// MlirLogicalResult.
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
                                               MlirStringRef pipelineElements,
                                               MlirStringCallback callback,
                                               void *userData) {
  detail::CallbackOstream diagnostics(callback, userData);
  OpPassManager &target = *unwrap(passManager);
  LogicalResult status =
      parsePassPipeline(unwrap(pipelineElements), target, diagnostics);
  return wrap(status);
}
|
|
|
|
/// Prints the textual pipeline form of `passManager`. Output is delivered
/// through `callback`, which receives `userData` alongside the text.
void mlirPrintPassPipeline(MlirOpPassManager passManager,
                           MlirStringCallback callback, void *userData) {
  detail::CallbackOstream out(callback, userData);
  OpPassManager *manager = unwrap(passManager);
  manager->printAsTextualPipeline(out);
}
|
|
|
|
/// Parses `pipeline` into a fresh OpPassManager and, only if parsing succeeds,
/// move-assigns it over the object behind `passManager`.
///
/// Text that the parser writes to its output stream is routed to `callback`
/// together with `userData`. On failure the existing pass manager is left
/// untouched and a failure result is returned.
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
                                        MlirStringRef pipeline,
                                        MlirStringCallback callback,
                                        void *userData) {
  detail::CallbackOstream diagnostics(callback, userData);
  FailureOr<OpPassManager> parsed =
      parsePassPipeline(unwrap(pipeline), diagnostics);

  // Bail out before touching the destination when parsing did not succeed.
  if (failed(parsed))
    return wrap(failure());

  *unwrap(passManager) = std::move(*parsed);
  return wrap(success());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// External Pass API.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Forward declaration: ExternalPass is named by the macro invocation below
// before the class itself is defined later in this file.
namespace mlir {
|
|
class ExternalPass;
|
|
} // namespace mlir
|
|
// Associates the C handle type MlirExternalPass with mlir::ExternalPass.
// NOTE(review): presumably this generates the wrap()/unwrap() conversions used
// by ExternalPass::runOnOperation and mlirExternalPassSignalFailure - confirm
// against the macro definition in mlir/CAPI/Wrap.h.
DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
|
|
|
|
namespace mlir {
|
|
/// This pass class wraps external passes defined in other languages using the
|
|
/// MLIR C-interface
|
|
class ExternalPass : public Pass {
|
|
public:
|
|
ExternalPass(TypeID passID, StringRef name, StringRef argument,
|
|
StringRef description, std::optional<StringRef> opName,
|
|
ArrayRef<MlirDialectHandle> dependentDialects,
|
|
MlirExternalPassCallbacks callbacks, void *userData)
|
|
: Pass(passID, opName), id(passID), name(name), argument(argument),
|
|
description(description), dependentDialects(dependentDialects),
|
|
callbacks(callbacks), userData(userData) {
|
|
callbacks.construct(userData);
|
|
}
|
|
|
|
~ExternalPass() override { callbacks.destruct(userData); }
|
|
|
|
StringRef getName() const override { return name; }
|
|
StringRef getArgument() const override { return argument; }
|
|
StringRef getDescription() const override { return description; }
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
MlirDialectRegistry cRegistry = wrap(®istry);
|
|
for (MlirDialectHandle dialect : dependentDialects)
|
|
mlirDialectHandleInsertDialect(dialect, cRegistry);
|
|
}
|
|
|
|
void signalPassFailure() { Pass::signalPassFailure(); }
|
|
|
|
protected:
|
|
LogicalResult initialize(MLIRContext *ctx) override {
|
|
if (callbacks.initialize)
|
|
return unwrap(callbacks.initialize(wrap(ctx), userData));
|
|
return success();
|
|
}
|
|
|
|
bool canScheduleOn(RegisteredOperationName opName) const override {
|
|
if (std::optional<StringRef> specifiedOpName = getOpName())
|
|
return opName.getStringRef() == specifiedOpName;
|
|
return true;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
callbacks.run(wrap(getOperation()), wrap(this), userData);
|
|
}
|
|
|
|
std::unique_ptr<Pass> clonePass() const override {
|
|
void *clonedUserData = callbacks.clone(userData);
|
|
return std::make_unique<ExternalPass>(id, name, argument, description,
|
|
getOpName(), dependentDialects,
|
|
callbacks, clonedUserData);
|
|
}
|
|
|
|
private:
|
|
TypeID id;
|
|
std::string name;
|
|
std::string argument;
|
|
std::string description;
|
|
std::vector<MlirDialectHandle> dependentDialects;
|
|
MlirExternalPassCallbacks callbacks;
|
|
void *userData;
|
|
};
|
|
} // namespace mlir
|
|
|
|
/// Creates a heap-allocated mlir::ExternalPass from C API arguments and
/// returns it as an MlirPass handle.
///
/// An empty `opName` (length 0) is translated to "no op name"; any other value
/// is passed on as the pass's op name. `dependentDialects` is read as an array
/// of `nDependentDialects` handles.
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
                                MlirStringRef argument,
                                MlirStringRef description, MlirStringRef opName,
                                intptr_t nDependentDialects,
                                MlirDialectHandle *dependentDialects,
                                MlirExternalPassCallbacks callbacks,
                                void *userData) {
  // Zero-length op name means the pass is not tied to a specific operation.
  std::optional<StringRef> anchorOpName;
  if (opName.length > 0)
    anchorOpName = unwrap(opName);

  ArrayRef<MlirDialectHandle> dialects(
      dependentDialects, static_cast<size_t>(nDependentDialects));

  auto *externalPass = new mlir::ExternalPass(
      unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
      anchorOpName, dialects, callbacks, userData);

  // Convert to the base class pointer so the generic MlirPass wrapper applies.
  return wrap(static_cast<mlir::Pass *>(externalPass));
}
|
|
|
|
/// Calls `signalPassFailure` on the external pass behind `pass`.
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
  mlir::ExternalPass *externalPass = unwrap(pass);
  externalPass->signalPassFailure();
}
|