[lldb] Make MCP server instance global (#145616)

Rather than having one MCP server per debugger, make the MCP server
global and pass a debugger id along with tool invocations that require
one. This PR also adds a second tool to list the available debuggers
with their targets so the model can decide which debugger instance to
use.
This commit is contained in:
Jonas Devlieghere
2025-06-25 15:46:33 -05:00
committed by GitHub
parent 2db0289abe
commit e8abdfc88f
13 changed files with 181 additions and 137 deletions

View File

@@ -23,20 +23,6 @@ using namespace lldb_private;
#define LLDB_OPTIONS_mcp
#include "CommandOptions.inc"
static std::vector<llvm::StringRef> GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;
for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}
return supported_protocols;
}
class CommandObjectProtocolServerStart : public CommandObjectParsed {
public:
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
@@ -57,12 +43,11 @@ protected:
}
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}
@@ -72,10 +57,6 @@ protected:
}
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);
ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
if (!server_sp)
server_sp = ProtocolServer::Create(protocol, GetDebugger());
const char *connection_error =
"unsupported connection specifier, expected 'accept:///path' or "
"'listen://[host]:port', got '{0}'.";
@@ -98,14 +79,12 @@ protected:
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
uri->port.value_or(0));
if (llvm::Error error = server_sp->Start(connection)) {
if (llvm::Error error = server->Start(connection)) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}
GetDebugger().AddProtocolServer(server_sp);
if (Socket *socket = server_sp->GetSocket()) {
if (Socket *socket = server->GetSocket()) {
std::string address =
llvm::join(socket->GetListeningConnectionURI(), ", ");
result.AppendMessageWithFormatv(
@@ -134,30 +113,18 @@ protected:
}
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}
Debugger &debugger = GetDebugger();
ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
if (!server_sp) {
result.AppendError(
llvm::formatv("no {0} protocol server running", protocol).str());
return;
}
if (llvm::Error error = server_sp->Stop()) {
if (llvm::Error error = server->Stop()) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}
debugger.RemoveProtocolServer(server_sp);
}
};

View File

@@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
"Debugger::GetThreadPool called before Debugger::Initialize");
return *g_thread_pool;
}
void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
assert(protocol_server_sp &&
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
m_protocol_servers.push_back(protocol_server_sp);
}
void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
if (it != m_protocol_servers.end())
m_protocol_servers.erase(it);
}
lldb::ProtocolServerSP
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
if (!protocol_server_sp)
continue;
if (protocol_server_sp->GetPluginName() == protocol)
return protocol_server_sp;
}
return nullptr;
}

View File

@@ -12,10 +12,36 @@
using namespace lldb_private;
using namespace lldb;
ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
Debugger &debugger) {
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
static std::mutex g_mutex;
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;
std::lock_guard<std::mutex> guard(g_mutex);
auto it = g_protocol_server_instances.find(name);
if (it != g_protocol_server_instances.end())
return it->second.get();
if (ProtocolServerCreateInstance create_callback =
PluginManager::GetProtocolCreateCallbackForPluginName(name))
return create_callback(debugger);
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
auto pair =
g_protocol_server_instances.try_emplace(name, create_callback());
return pair.first->second.get();
}
return nullptr;
}
std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;
for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}
return supported_protocols;
}

View File

@@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>;
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
llvm::json::Value toJSON(const Message &);
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
} // namespace lldb_private::mcp::protocol
#endif

View File

@@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
static constexpr size_t kChunkSize = 1024;
ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
: ProtocolServer(), m_debugger(debugger) {
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
AddRequestHandler("initialize",
std::bind(&ProtocolServerMCP::InitializeHandler, this,
std::placeholders::_1));
@@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
"notifications/initialized", [](const protocol::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
});
AddTool(std::make_unique<LLDBCommandTool>(
"lldb_command", "Run an lldb command.", m_debugger));
AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
AddTool(std::make_unique<DebuggerListTool>(
"lldb_debugger_list", "List debugger instances with their debugger_id."));
}
ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
@@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
PluginManager::UnregisterPlugin(CreateInstance);
}
lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
return std::make_shared<ProtocolServerMCP>(debugger);
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
return std::make_unique<ProtocolServerMCP>();
}
llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
@@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
std::lock_guard<std::mutex> guard(m_server_mutex);
if (m_running)
return llvm::createStringError("server already running");
return llvm::createStringError("the MCP server is already running");
Status status;
m_listener = Socket::Create(connection.protocol, status);
@@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
if (llvm::Error error = handles.takeError())
return error;
m_running = true;
m_listen_handlers = std::move(*handles);
m_loop_thread = std::thread([=] {
llvm::set_thread_name(
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
});
@@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::Error ProtocolServerMCP::Stop() {
{
std::lock_guard<std::mutex> guard(m_server_mutex);
if (!m_running)
return createStringError("the MCP sever is not running");
m_running = false;
}
@@ -311,11 +314,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
if (it == m_tools.end())
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
const json::Value *args = param_obj->get("arguments");
if (!args)
return llvm::createStringError("no tool arguments");
protocol::ToolArguments tool_args;
if (const json::Value *args = param_obj->get("arguments"))
tool_args = *args;
llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
llvm::Expected<protocol::TextResult> text_result =
it->second->Call(tool_args);
if (!text_result)
return text_result.takeError();

View File

@@ -21,7 +21,7 @@ namespace lldb_private::mcp {
class ProtocolServerMCP : public ProtocolServer {
public:
ProtocolServerMCP(Debugger &debugger);
ProtocolServerMCP();
virtual ~ProtocolServerMCP() override;
virtual llvm::Error Start(ProtocolServer::Connection connection) override;
@@ -33,7 +33,7 @@ public:
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
static llvm::StringRef GetPluginDescriptionStatic();
static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
static lldb::ProtocolServerUP CreateInstance();
llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }
@@ -71,8 +71,6 @@ private:
llvm::StringLiteral kName = "lldb-mcp";
llvm::StringLiteral kVersion = "0.1.0";
Debugger &m_debugger;
bool m_running = false;
MainLoop m_loop;

View File

@@ -7,22 +7,38 @@
//===----------------------------------------------------------------------===//
#include "Tool.h"
#include "lldb/Core/Module.h"
#include "lldb/Interpreter/CommandInterpreter.h"
#include "lldb/Interpreter/CommandReturnObject.h"
using namespace lldb_private::mcp;
using namespace llvm;
struct LLDBCommandToolArguments {
namespace {
struct CommandToolArguments {
uint64_t debugger_id;
std::string arguments;
};
bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A,
bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A,
llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
return O && O.map("arguments", A.arguments);
return O && O.map("debugger_id", A.debugger_id) &&
O.mapOptional("arguments", A.arguments);
}
/// Helper function to create a TextResult from a string output.
static lldb_private::mcp::protocol::TextResult
createTextResult(std::string output, bool is_error = false) {
lldb_private::mcp::protocol::TextResult text_result;
text_result.content.emplace_back(
lldb_private::mcp::protocol::TextContent{{std::move(output)}});
text_result.isError = is_error;
return text_result;
}
} // namespace
Tool::Tool(std::string name, std::string description)
: m_name(std::move(name)), m_description(std::move(description)) {}
@@ -37,22 +53,27 @@ protocol::ToolDefinition Tool::GetDefinition() const {
return definition;
}
LLDBCommandTool::LLDBCommandTool(std::string name, std::string description,
Debugger &debugger)
: Tool(std::move(name), std::move(description)), m_debugger(debugger) {}
llvm::Expected<protocol::TextResult>
LLDBCommandTool::Call(const llvm::json::Value &args) {
llvm::json::Path::Root root;
CommandTool::Call(const protocol::ToolArguments &args) {
if (!std::holds_alternative<json::Value>(args))
return createStringError("CommandTool requires arguments");
LLDBCommandToolArguments arguments;
if (!fromJSON(args, arguments, root))
json::Path::Root root;
CommandToolArguments arguments;
if (!fromJSON(std::get<json::Value>(args), arguments, root))
return root.getError();
lldb::DebuggerSP debugger_sp =
Debugger::GetDebuggerAtIndex(arguments.debugger_id);
if (!debugger_sp)
return createStringError(
llvm::formatv("no debugger with id {0}", arguments.debugger_id));
// FIXME: Disallow certain commands and their aliases.
CommandReturnObject result(/*colors=*/false);
m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(),
eLazyBoolYes, result);
debugger_sp->GetCommandInterpreter().HandleCommand(
arguments.arguments.c_str(), eLazyBoolYes, result);
std::string output;
llvm::StringRef output_str = result.GetOutputString();
@@ -66,16 +87,64 @@ LLDBCommandTool::Call(const llvm::json::Value &args) {
output += err_str;
}
mcp::protocol::TextResult text_result;
text_result.content.emplace_back(mcp::protocol::TextContent{{output}});
text_result.isError = !result.Succeeded();
return text_result;
return createTextResult(output, !result.Succeeded());
}
std::optional<llvm::json::Value> LLDBCommandTool::GetSchema() const {
std::optional<llvm::json::Value> CommandTool::GetSchema() const {
llvm::json::Object id_type{{"type", "number"}};
llvm::json::Object str_type{{"type", "string"}};
llvm::json::Object properties{{"arguments", std::move(str_type)}};
llvm::json::Object properties{{"debugger_id", std::move(id_type)},
{"arguments", std::move(str_type)}};
llvm::json::Array required{"debugger_id"};
llvm::json::Object schema{{"type", "object"},
{"properties", std::move(properties)}};
{"properties", std::move(properties)},
{"required", std::move(required)}};
return schema;
}
llvm::Expected<protocol::TextResult>
DebuggerListTool::Call(const protocol::ToolArguments &args) {
if (!std::holds_alternative<std::monostate>(args))
return createStringError("DebuggerListTool takes no arguments");
llvm::json::Path::Root root;
// Return a nested Markdown list with debuggers and target.
// Example output:
//
// - debugger 0
// - target 0 /path/to/foo
// - target 1
// - debugger 1
// - target 0 /path/to/bar
//
// FIXME: Use Structured Content when we adopt protocol version 2025-06-18.
std::string output;
llvm::raw_string_ostream os(output);
const size_t num_debuggers = Debugger::GetNumDebuggers();
for (size_t i = 0; i < num_debuggers; ++i) {
lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i);
if (!debugger_sp)
continue;
os << "- debugger " << i << '\n';
TargetList &target_list = debugger_sp->GetTargetList();
const size_t num_targets = target_list.GetNumTargets();
for (size_t j = 0; j < num_targets; ++j) {
lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j);
if (!target_sp)
continue;
os << " - target " << j;
if (target_sp == target_list.GetSelectedTarget())
os << " (selected)";
// Append the module path if we have one.
if (Module *exe_module = target_sp->GetExecutableModulePointer())
os << " " << exe_module->GetFileSpec().GetPath();
os << '\n';
}
}
return createTextResult(output);
}

View File

@@ -22,10 +22,10 @@ public:
virtual ~Tool() = default;
virtual llvm::Expected<protocol::TextResult>
Call(const llvm::json::Value &args) = 0;
Call(const protocol::ToolArguments &args) = 0;
virtual std::optional<llvm::json::Value> GetSchema() const {
return std::nullopt;
return llvm::json::Object{{"type", "object"}};
}
protocol::ToolDefinition GetDefinition() const;
@@ -37,20 +37,26 @@ private:
std::string m_description;
};
class LLDBCommandTool : public mcp::Tool {
class CommandTool : public mcp::Tool {
public:
LLDBCommandTool(std::string name, std::string description,
Debugger &debugger);
~LLDBCommandTool() = default;
using mcp::Tool::Tool;
~CommandTool() = default;
virtual llvm::Expected<protocol::TextResult>
Call(const llvm::json::Value &args) override;
Call(const protocol::ToolArguments &args) override;
virtual std::optional<llvm::json::Value> GetSchema() const override;
private:
Debugger &m_debugger;
};
class DebuggerListTool : public mcp::Tool {
public:
using mcp::Tool::Tool;
~DebuggerListTool() = default;
virtual llvm::Expected<protocol::TextResult>
Call(const protocol::ToolArguments &args) override;
};
} // namespace lldb_private::mcp
#endif