Refactor the mlir-opt command line options related to debugging in a helper

This makes it reusable across various tooling and reduces the amount of
boilerplate needed.

Differential Revision: https://reviews.llvm.org/D144818
This commit is contained in:
Mehdi Amini
2023-02-26 01:01:18 -05:00
parent b0528a53ea
commit cca510640b
5 changed files with 231 additions and 122 deletions

View File

@@ -0,0 +1,93 @@
//===- CLOptionsSetup.h - Helpers to setup debug CL options -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DEBUG_CLOPTIONSSETUP_H
#define MLIR_DEBUG_CLOPTIONSSETUP_H
#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
namespace mlir {
class MLIRContext;
namespace tracing {
class BreakpointManager;
class DebugConfig {
public:
/// Register the options as global LLVM command line options.
static void registerCLOptions();
/// Create a new config with the default set from the CL options.
static DebugConfig createFromCLOptions();
///
/// Options.
///
/// Enable the Debugger action hook: it makes a debugger (like gdb or lldb)
/// able to intercept MLIR Actions.
void enableDebuggerActionHook(bool enabled = true) {
enableDebuggerActionHookFlag = enabled;
}
/// Return true if the debugger action hook is enabled.
bool isDebuggerActionHookEnabled() const {
return enableDebuggerActionHookFlag;
}
/// Set the filename to use for logging actions, use "-" for stdout.
DebugConfig &logActionsTo(StringRef filename) {
logActionsToFlag = filename;
return *this;
}
/// Get the filename to use for logging actions.
StringRef getLogActionsTo() const { return logActionsToFlag; }
/// Set a location breakpoint manager to filter out action logging based on
/// the attached IR location in the Action context. Ownership stays with the
/// caller.
void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
logActionLocationFilter.push_back(breakpointManager);
}
/// Get the location breakpoint managers to use to filter out action logging.
ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
return logActionLocationFilter;
}
protected:
/// Enable the Debugger action hook: a debugger (like gdb or lldb) can
/// intercept MLIR Actions.
bool enableDebuggerActionHookFlag = false;
/// Log action execution to the given file (or "-" for stdout)
std::string logActionsToFlag;
/// Location Breakpoints to filter the action logging.
std::vector<tracing::BreakpointManager *> logActionLocationFilter;
};
/// This is a RAII class that installs the debug handlers on the context
/// based on the provided configuration.
class InstallDebugHandler {
public:
InstallDebugHandler(MLIRContext &context, const DebugConfig &config);
~InstallDebugHandler();
private:
class Impl;
std::unique_ptr<Impl> impl;
};
} // namespace tracing
} // namespace mlir
#endif // MLIR_DEBUG_CLOPTIONSSETUP_H

View File

@@ -13,7 +13,7 @@
#ifndef MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
#define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
@@ -30,9 +30,6 @@ namespace mlir {
class DialectRegistry;
class PassPipelineCLParser;
class PassManager;
namespace tracing {
class FileLineColLocBreakpointManager;
}
/// Configuration options for the mlir-opt tool.
/// This is intended to help building tools like mlir-opt by collecting the
@@ -64,6 +61,14 @@ public:
return allowUnregisteredDialectsFlag;
}
/// Set the debug configuration to use.
MlirOptMainConfig &setDebugConfig(tracing::DebugConfig config) {
debugConfig = std::move(config);
return *this;
}
tracing::DebugConfig &getDebugConfig() { return debugConfig; }
const tracing::DebugConfig &getDebugConfig() const { return debugConfig; }
/// Print the pass-pipeline as text before executing.
MlirOptMainConfig &dumpPassPipeline(bool dump) {
dumpPassPipelineFlag = dump;
@@ -78,17 +83,6 @@ public:
}
bool shouldEmitBytecode() const { return emitBytecodeFlag; }
/// Enable the debugger action hook: it makes the debugger able to intercept
/// MLIR Actions.
void enableDebuggerActionHook(bool enabled = true) {
enableDebuggerActionHookFlag = enabled;
}
/// Return true if the Debugger action hook is enabled.
bool isDebuggerActionHookEnabled() const {
return enableDebuggerActionHookFlag;
}
/// Set the IRDL file to load before processing the input.
MlirOptMainConfig &setIrdlFile(StringRef file) {
irdlFileFlag = file;
@@ -96,26 +90,6 @@ public:
}
StringRef getIrdlFile() const { return irdlFileFlag; }
/// Set the filename to use for logging actions, use "-" for stdout.
MlirOptMainConfig &logActionsTo(StringRef filename) {
logActionsToFlag = filename;
return *this;
}
/// Get the filename to use for logging actions.
StringRef getLogActionsTo() const { return logActionsToFlag; }
/// Set a location breakpoint manager to filter out action logging based on
/// the attached IR location in the Action context. Ownership stays with the
/// caller.
void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
logActionLocationFilter.push_back(breakpointManager);
}
/// Get the location breakpoint managers to use to filter out action logging.
ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
return logActionLocationFilter;
}
/// Set the callback to populate the pass manager.
MlirOptMainConfig &
setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -185,6 +159,9 @@ protected:
/// general.
bool allowUnregisteredDialectsFlag = false;
/// Configuration for the debugging hooks.
tracing::DebugConfig debugConfig;
/// Print the pipeline that will be run.
bool dumpPassPipelineFlag = false;
@@ -197,9 +174,6 @@ protected:
/// IRDL file to register before processing the input.
std::string irdlFileFlag = "";
/// Log action execution to the given file (or "-" for stdout)
std::string logActionsToFlag;
/// Location Breakpoints to filter the action logging.
std::vector<tracing::BreakpointManager *> logActionLocationFilter;

View File

@@ -0,0 +1,120 @@
//===- CLOptionsSetup.cpp - Helpers to setup debug CL options ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Debug/Counter.h"
#include "mlir/Debug/DebuggerDebugExecutionContextHook.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Debug/Observers/ActionLogging.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using namespace mlir::tracing;
using namespace llvm;
namespace {
struct DebugConfigCLOptions : public DebugConfig {
DebugConfigCLOptions() {
static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
"log-actions-to",
cl::desc("Log action execution to a file, or stderr if "
" '-' is passed"),
cl::location(logActionsToFlag)};
static cl::list<std::string> logActionLocationFilter(
"log-mlir-actions-filter",
cl::desc(
"Comma separated list of locations to filter actions from logging"),
cl::CommaSeparated,
cl::cb<void, std::string>([&](const std::string &location) {
static bool register_once = [&] {
addLogActionLocFilter(&locBreakpointManager);
return true;
}();
(void)register_once;
static std::vector<std::string> locations;
locations.push_back(location);
StringRef locStr = locations.back();
// Parse the individual location filters and set the breakpoints.
auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
auto locBreakpoint =
tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
if (failed(locBreakpoint)) {
llvm::errs() << "Invalid location filter: " << locStr << "\n";
exit(1);
}
auto [file, line, col] = *locBreakpoint;
locBreakpointManager.addBreakpoint(file, line, col);
}));
}
tracing::FileLineColLocBreakpointManager locBreakpointManager;
};
} // namespace
static ManagedStatic<DebugConfigCLOptions> clOptionsConfig;
void DebugConfig::registerCLOptions() { *clOptionsConfig; }
DebugConfig DebugConfig::createFromCLOptions() { return *clOptionsConfig; }
class InstallDebugHandler::Impl {
public:
Impl(MLIRContext &context, const DebugConfig &config) {
if (config.getLogActionsTo().empty() &&
!config.isDebuggerActionHookEnabled()) {
if (tracing::DebugCounter::isActivated())
context.registerActionHandler(tracing::DebugCounter());
return;
}
errs() << "ExecutionContext registered on the context";
if (tracing::DebugCounter::isActivated())
emitError(UnknownLoc::get(&context),
"Debug counters are incompatible with --log-actions-to and "
"--mlir-enable-debugger-hook options and are disabled");
if (!config.getLogActionsTo().empty()) {
std::string errorMessage;
logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
if (!logActionsFile) {
emitError(UnknownLoc::get(&context),
"Opening file for --log-actions-to failed: ")
<< errorMessage << "\n";
return;
}
logActionsFile->keep();
raw_fd_ostream &logActionsStream = logActionsFile->os();
actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
actionLogger->addBreakpointManager(locationBreakpoint);
executionContext.registerObserver(actionLogger.get());
}
if (config.isDebuggerActionHookEnabled()) {
errs() << " (with Debugger hook)";
setupDebuggerDebugExecutionContextHook(executionContext);
}
errs() << "\n";
context.registerActionHandler(executionContext);
}
private:
std::unique_ptr<ToolOutputFile> logActionsFile;
tracing::ExecutionContext executionContext;
std::unique_ptr<tracing::ActionLogger> actionLogger;
std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
locationBreakpoints;
};
InstallDebugHandler::InstallDebugHandler(MLIRContext &context,
const DebugConfig &config)
: impl(std::make_unique<Impl>(context, config)) {}
InstallDebugHandler::~InstallDebugHandler() = default;

View File

@@ -1,6 +1,7 @@
add_subdirectory(Observers)
add_mlir_library(MLIRDebug
CLOptionsSetup.cpp
DebugCounter.cpp
ExecutionContext.cpp
BreakpointManagers/FileLineColLocBreakpointManager.cpp

View File

@@ -13,6 +13,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Debug/Counter.h"
#include "mlir/Debug/DebuggerExecutionContextHook.h"
#include "mlir/Debug/ExecutionContext.h"
@@ -89,39 +90,6 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"parsing"),
cl::location(useExplicitModuleFlag), cl::init(false));
static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
"log-actions-to",
cl::desc("Log action execution to a file, or stderr if "
" '-' is passed"),
cl::location(logActionsToFlag)};
static cl::list<std::string> logActionLocationFilter(
"log-mlir-actions-filter",
cl::desc(
"Comma separated list of locations to filter actions from logging"),
cl::CommaSeparated,
cl::cb<void, std::string>([&](const std::string &location) {
static bool register_once = [&] {
addLogActionLocFilter(&locBreakpointManager);
return true;
}();
(void)register_once;
static std::vector<std::string> locations;
locations.push_back(location);
StringRef locStr = locations.back();
// Parse the individual location filters and set the breakpoints.
auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
auto locBreakpoint =
tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
if (failed(locBreakpoint)) {
llvm::errs() << "Invalid location filter: " << locStr << "\n";
exit(1);
}
auto [file, line, col] = *locBreakpoint;
locBreakpointManager.addBreakpoint(file, line, col);
}));
static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
"show-dialects",
cl::desc("Print the list of registered dialects and exit"),
@@ -171,9 +139,6 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
/// Pointer to static dialectPlugins variable in constructor, needed by
/// setDialectPluginsCallback(DialectRegistry&).
cl::list<std::string> *dialectPlugins = nullptr;
/// The breakpoint manager for the log action location filter.
tracing::FileLineColLocBreakpointManager locBreakpointManager;
};
} // namespace
@@ -181,9 +146,11 @@ ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
clOptionsConfig->setDialectPluginsCallback(registry);
tracing::DebugConfig::registerCLOptions();
}
MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions());
return *clOptionsConfig;
}
@@ -219,53 +186,6 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
});
}
/// Set the ExecutionContext on the context and handle the observers.
class InstallDebugHandler {
public:
InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
if (config.getLogActionsTo().empty() &&
!config.isDebuggerActionHookEnabled()) {
if (tracing::DebugCounter::isActivated())
context.registerActionHandler(tracing::DebugCounter());
return;
}
llvm::errs() << "ExecutionContext registered on the context";
if (tracing::DebugCounter::isActivated())
emitError(UnknownLoc::get(&context),
"Debug counters are incompatible with --log-actions-to and "
"--mlir-enable-debugger-hook options and are disabled");
if (!config.getLogActionsTo().empty()) {
std::string errorMessage;
logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
if (!logActionsFile) {
emitError(UnknownLoc::get(&context),
"Opening file for --log-actions-to failed: ")
<< errorMessage << "\n";
return;
}
logActionsFile->keep();
raw_fd_ostream &logActionsStream = logActionsFile->os();
actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
actionLogger->addBreakpointManager(locationBreakpoint);
executionContext.registerObserver(actionLogger.get());
}
if (config.isDebuggerActionHookEnabled()) {
llvm::errs() << " (with Debugger hook)";
setupDebuggerExecutionContextHook(executionContext);
}
llvm::errs() << "\n";
context.registerActionHandler(executionContext);
}
private:
std::unique_ptr<llvm::ToolOutputFile> logActionsFile;
std::unique_ptr<tracing::ActionLogger> actionLogger;
std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
locationBreakpoints;
tracing::ExecutionContext executionContext;
};
/// Perform the actions on the input file indicated by the command line flags
/// within the specified context.
///
@@ -386,7 +306,8 @@ static LogicalResult processBuffer(raw_ostream &os,
if (config.shouldVerifyDiagnostics())
context.printOpOnDiagnostic(false);
InstallDebugHandler installDebugHandler(context, config);
tracing::InstallDebugHandler installDebugHandler(context,
config.getDebugConfig());
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.