[lldb] Make MCP server instance global (#145616)
Rather than having one MCP server per debugger, make the MCP server global and pass a debugger id along with tool invocations that require one. This PR also adds a second tool to list the available debuggers with their targets so the model can decide which debugger instance to use.
This commit is contained in:
committed by
GitHub
parent
2db0289abe
commit
e8abdfc88f
@@ -23,20 +23,6 @@ using namespace lldb_private;
|
||||
#define LLDB_OPTIONS_mcp
|
||||
#include "CommandOptions.inc"
|
||||
|
||||
static std::vector<llvm::StringRef> GetSupportedProtocols() {
|
||||
std::vector<llvm::StringRef> supported_protocols;
|
||||
size_t i = 0;
|
||||
|
||||
for (llvm::StringRef protocol_name =
|
||||
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
|
||||
!protocol_name.empty();
|
||||
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
|
||||
supported_protocols.push_back(protocol_name);
|
||||
}
|
||||
|
||||
return supported_protocols;
|
||||
}
|
||||
|
||||
class CommandObjectProtocolServerStart : public CommandObjectParsed {
|
||||
public:
|
||||
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
|
||||
@@ -57,12 +43,11 @@ protected:
|
||||
}
|
||||
|
||||
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
|
||||
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
|
||||
if (llvm::find(supported_protocols, protocol) ==
|
||||
supported_protocols.end()) {
|
||||
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
|
||||
if (!server) {
|
||||
result.AppendErrorWithFormatv(
|
||||
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
|
||||
llvm::join(GetSupportedProtocols(), ", "));
|
||||
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -72,10 +57,6 @@ protected:
|
||||
}
|
||||
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);
|
||||
|
||||
ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
|
||||
if (!server_sp)
|
||||
server_sp = ProtocolServer::Create(protocol, GetDebugger());
|
||||
|
||||
const char *connection_error =
|
||||
"unsupported connection specifier, expected 'accept:///path' or "
|
||||
"'listen://[host]:port', got '{0}'.";
|
||||
@@ -98,14 +79,12 @@ protected:
|
||||
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
|
||||
uri->port.value_or(0));
|
||||
|
||||
if (llvm::Error error = server_sp->Start(connection)) {
|
||||
if (llvm::Error error = server->Start(connection)) {
|
||||
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
|
||||
return;
|
||||
}
|
||||
|
||||
GetDebugger().AddProtocolServer(server_sp);
|
||||
|
||||
if (Socket *socket = server_sp->GetSocket()) {
|
||||
if (Socket *socket = server->GetSocket()) {
|
||||
std::string address =
|
||||
llvm::join(socket->GetListeningConnectionURI(), ", ");
|
||||
result.AppendMessageWithFormatv(
|
||||
@@ -134,30 +113,18 @@ protected:
|
||||
}
|
||||
|
||||
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
|
||||
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
|
||||
if (llvm::find(supported_protocols, protocol) ==
|
||||
supported_protocols.end()) {
|
||||
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
|
||||
if (!server) {
|
||||
result.AppendErrorWithFormatv(
|
||||
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
|
||||
llvm::join(GetSupportedProtocols(), ", "));
|
||||
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
|
||||
return;
|
||||
}
|
||||
|
||||
Debugger &debugger = GetDebugger();
|
||||
|
||||
ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
|
||||
if (!server_sp) {
|
||||
result.AppendError(
|
||||
llvm::formatv("no {0} protocol server running", protocol).str());
|
||||
return;
|
||||
}
|
||||
|
||||
if (llvm::Error error = server_sp->Stop()) {
|
||||
if (llvm::Error error = server->Stop()) {
|
||||
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
|
||||
return;
|
||||
}
|
||||
|
||||
debugger.RemoveProtocolServer(server_sp);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
|
||||
"Debugger::GetThreadPool called before Debugger::Initialize");
|
||||
return *g_thread_pool;
|
||||
}
|
||||
|
||||
void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
|
||||
assert(protocol_server_sp &&
|
||||
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
|
||||
m_protocol_servers.push_back(protocol_server_sp);
|
||||
}
|
||||
|
||||
void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
|
||||
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
|
||||
if (it != m_protocol_servers.end())
|
||||
m_protocol_servers.erase(it);
|
||||
}
|
||||
|
||||
lldb::ProtocolServerSP
|
||||
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
|
||||
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
|
||||
if (!protocol_server_sp)
|
||||
continue;
|
||||
if (protocol_server_sp->GetPluginName() == protocol)
|
||||
return protocol_server_sp;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -12,10 +12,36 @@
|
||||
using namespace lldb_private;
|
||||
using namespace lldb;
|
||||
|
||||
ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
|
||||
Debugger &debugger) {
|
||||
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
|
||||
static std::mutex g_mutex;
|
||||
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;
|
||||
|
||||
std::lock_guard<std::mutex> guard(g_mutex);
|
||||
|
||||
auto it = g_protocol_server_instances.find(name);
|
||||
if (it != g_protocol_server_instances.end())
|
||||
return it->second.get();
|
||||
|
||||
if (ProtocolServerCreateInstance create_callback =
|
||||
PluginManager::GetProtocolCreateCallbackForPluginName(name))
|
||||
return create_callback(debugger);
|
||||
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
|
||||
auto pair =
|
||||
g_protocol_server_instances.try_emplace(name, create_callback());
|
||||
return pair.first->second.get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
|
||||
std::vector<llvm::StringRef> supported_protocols;
|
||||
size_t i = 0;
|
||||
|
||||
for (llvm::StringRef protocol_name =
|
||||
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
|
||||
!protocol_name.empty();
|
||||
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
|
||||
supported_protocols.push_back(protocol_name);
|
||||
}
|
||||
|
||||
return supported_protocols;
|
||||
}
|
||||
|
||||
@@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>;
|
||||
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const Message &);
|
||||
|
||||
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
|
||||
|
||||
} // namespace lldb_private::mcp::protocol
|
||||
|
||||
#endif
|
||||
|
||||
@@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
|
||||
|
||||
static constexpr size_t kChunkSize = 1024;
|
||||
|
||||
ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
|
||||
: ProtocolServer(), m_debugger(debugger) {
|
||||
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
|
||||
AddRequestHandler("initialize",
|
||||
std::bind(&ProtocolServerMCP::InitializeHandler, this,
|
||||
std::placeholders::_1));
|
||||
@@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
|
||||
"notifications/initialized", [](const protocol::Notification &) {
|
||||
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
|
||||
});
|
||||
AddTool(std::make_unique<LLDBCommandTool>(
|
||||
"lldb_command", "Run an lldb command.", m_debugger));
|
||||
AddTool(
|
||||
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
|
||||
AddTool(std::make_unique<DebuggerListTool>(
|
||||
"lldb_debugger_list", "List debugger instances with their debugger_id."));
|
||||
}
|
||||
|
||||
ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
|
||||
@@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
|
||||
PluginManager::UnregisterPlugin(CreateInstance);
|
||||
}
|
||||
|
||||
lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
|
||||
return std::make_shared<ProtocolServerMCP>(debugger);
|
||||
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
|
||||
return std::make_unique<ProtocolServerMCP>();
|
||||
}
|
||||
|
||||
llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
|
||||
@@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
|
||||
std::lock_guard<std::mutex> guard(m_server_mutex);
|
||||
|
||||
if (m_running)
|
||||
return llvm::createStringError("server already running");
|
||||
return llvm::createStringError("the MCP server is already running");
|
||||
|
||||
Status status;
|
||||
m_listener = Socket::Create(connection.protocol, status);
|
||||
@@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
|
||||
if (llvm::Error error = handles.takeError())
|
||||
return error;
|
||||
|
||||
m_running = true;
|
||||
m_listen_handlers = std::move(*handles);
|
||||
m_loop_thread = std::thread([=] {
|
||||
llvm::set_thread_name(
|
||||
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
|
||||
llvm::set_thread_name("protocol-server.mcp");
|
||||
m_loop.Run();
|
||||
});
|
||||
|
||||
@@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
|
||||
llvm::Error ProtocolServerMCP::Stop() {
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(m_server_mutex);
|
||||
if (!m_running)
|
||||
return createStringError("the MCP sever is not running");
|
||||
m_running = false;
|
||||
}
|
||||
|
||||
@@ -311,11 +314,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
|
||||
if (it == m_tools.end())
|
||||
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
|
||||
|
||||
const json::Value *args = param_obj->get("arguments");
|
||||
if (!args)
|
||||
return llvm::createStringError("no tool arguments");
|
||||
protocol::ToolArguments tool_args;
|
||||
if (const json::Value *args = param_obj->get("arguments"))
|
||||
tool_args = *args;
|
||||
|
||||
llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
|
||||
llvm::Expected<protocol::TextResult> text_result =
|
||||
it->second->Call(tool_args);
|
||||
if (!text_result)
|
||||
return text_result.takeError();
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace lldb_private::mcp {
|
||||
|
||||
class ProtocolServerMCP : public ProtocolServer {
|
||||
public:
|
||||
ProtocolServerMCP(Debugger &debugger);
|
||||
ProtocolServerMCP();
|
||||
virtual ~ProtocolServerMCP() override;
|
||||
|
||||
virtual llvm::Error Start(ProtocolServer::Connection connection) override;
|
||||
@@ -33,7 +33,7 @@ public:
|
||||
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
|
||||
static llvm::StringRef GetPluginDescriptionStatic();
|
||||
|
||||
static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
|
||||
static lldb::ProtocolServerUP CreateInstance();
|
||||
|
||||
llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }
|
||||
|
||||
@@ -71,8 +71,6 @@ private:
|
||||
llvm::StringLiteral kName = "lldb-mcp";
|
||||
llvm::StringLiteral kVersion = "0.1.0";
|
||||
|
||||
Debugger &m_debugger;
|
||||
|
||||
bool m_running = false;
|
||||
|
||||
MainLoop m_loop;
|
||||
|
||||
@@ -7,22 +7,38 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "Tool.h"
|
||||
#include "lldb/Core/Module.h"
|
||||
#include "lldb/Interpreter/CommandInterpreter.h"
|
||||
#include "lldb/Interpreter/CommandReturnObject.h"
|
||||
|
||||
using namespace lldb_private::mcp;
|
||||
using namespace llvm;
|
||||
|
||||
struct LLDBCommandToolArguments {
|
||||
namespace {
|
||||
struct CommandToolArguments {
|
||||
uint64_t debugger_id;
|
||||
std::string arguments;
|
||||
};
|
||||
|
||||
bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A,
|
||||
bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A,
|
||||
llvm::json::Path P) {
|
||||
llvm::json::ObjectMapper O(V, P);
|
||||
return O && O.map("arguments", A.arguments);
|
||||
return O && O.map("debugger_id", A.debugger_id) &&
|
||||
O.mapOptional("arguments", A.arguments);
|
||||
}
|
||||
|
||||
/// Helper function to create a TextResult from a string output.
|
||||
static lldb_private::mcp::protocol::TextResult
|
||||
createTextResult(std::string output, bool is_error = false) {
|
||||
lldb_private::mcp::protocol::TextResult text_result;
|
||||
text_result.content.emplace_back(
|
||||
lldb_private::mcp::protocol::TextContent{{std::move(output)}});
|
||||
text_result.isError = is_error;
|
||||
return text_result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tool::Tool(std::string name, std::string description)
|
||||
: m_name(std::move(name)), m_description(std::move(description)) {}
|
||||
|
||||
@@ -37,22 +53,27 @@ protocol::ToolDefinition Tool::GetDefinition() const {
|
||||
return definition;
|
||||
}
|
||||
|
||||
LLDBCommandTool::LLDBCommandTool(std::string name, std::string description,
|
||||
Debugger &debugger)
|
||||
: Tool(std::move(name), std::move(description)), m_debugger(debugger) {}
|
||||
|
||||
llvm::Expected<protocol::TextResult>
|
||||
LLDBCommandTool::Call(const llvm::json::Value &args) {
|
||||
llvm::json::Path::Root root;
|
||||
CommandTool::Call(const protocol::ToolArguments &args) {
|
||||
if (!std::holds_alternative<json::Value>(args))
|
||||
return createStringError("CommandTool requires arguments");
|
||||
|
||||
LLDBCommandToolArguments arguments;
|
||||
if (!fromJSON(args, arguments, root))
|
||||
json::Path::Root root;
|
||||
|
||||
CommandToolArguments arguments;
|
||||
if (!fromJSON(std::get<json::Value>(args), arguments, root))
|
||||
return root.getError();
|
||||
|
||||
lldb::DebuggerSP debugger_sp =
|
||||
Debugger::GetDebuggerAtIndex(arguments.debugger_id);
|
||||
if (!debugger_sp)
|
||||
return createStringError(
|
||||
llvm::formatv("no debugger with id {0}", arguments.debugger_id));
|
||||
|
||||
// FIXME: Disallow certain commands and their aliases.
|
||||
CommandReturnObject result(/*colors=*/false);
|
||||
m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(),
|
||||
eLazyBoolYes, result);
|
||||
debugger_sp->GetCommandInterpreter().HandleCommand(
|
||||
arguments.arguments.c_str(), eLazyBoolYes, result);
|
||||
|
||||
std::string output;
|
||||
llvm::StringRef output_str = result.GetOutputString();
|
||||
@@ -66,16 +87,64 @@ LLDBCommandTool::Call(const llvm::json::Value &args) {
|
||||
output += err_str;
|
||||
}
|
||||
|
||||
mcp::protocol::TextResult text_result;
|
||||
text_result.content.emplace_back(mcp::protocol::TextContent{{output}});
|
||||
text_result.isError = !result.Succeeded();
|
||||
return text_result;
|
||||
return createTextResult(output, !result.Succeeded());
|
||||
}
|
||||
|
||||
std::optional<llvm::json::Value> LLDBCommandTool::GetSchema() const {
|
||||
std::optional<llvm::json::Value> CommandTool::GetSchema() const {
|
||||
llvm::json::Object id_type{{"type", "number"}};
|
||||
llvm::json::Object str_type{{"type", "string"}};
|
||||
llvm::json::Object properties{{"arguments", std::move(str_type)}};
|
||||
llvm::json::Object properties{{"debugger_id", std::move(id_type)},
|
||||
{"arguments", std::move(str_type)}};
|
||||
llvm::json::Array required{"debugger_id"};
|
||||
llvm::json::Object schema{{"type", "object"},
|
||||
{"properties", std::move(properties)}};
|
||||
{"properties", std::move(properties)},
|
||||
{"required", std::move(required)}};
|
||||
return schema;
|
||||
}
|
||||
|
||||
llvm::Expected<protocol::TextResult>
|
||||
DebuggerListTool::Call(const protocol::ToolArguments &args) {
|
||||
if (!std::holds_alternative<std::monostate>(args))
|
||||
return createStringError("DebuggerListTool takes no arguments");
|
||||
|
||||
llvm::json::Path::Root root;
|
||||
|
||||
// Return a nested Markdown list with debuggers and target.
|
||||
// Example output:
|
||||
//
|
||||
// - debugger 0
|
||||
// - target 0 /path/to/foo
|
||||
// - target 1
|
||||
// - debugger 1
|
||||
// - target 0 /path/to/bar
|
||||
//
|
||||
// FIXME: Use Structured Content when we adopt protocol version 2025-06-18.
|
||||
std::string output;
|
||||
llvm::raw_string_ostream os(output);
|
||||
|
||||
const size_t num_debuggers = Debugger::GetNumDebuggers();
|
||||
for (size_t i = 0; i < num_debuggers; ++i) {
|
||||
lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i);
|
||||
if (!debugger_sp)
|
||||
continue;
|
||||
|
||||
os << "- debugger " << i << '\n';
|
||||
|
||||
TargetList &target_list = debugger_sp->GetTargetList();
|
||||
const size_t num_targets = target_list.GetNumTargets();
|
||||
for (size_t j = 0; j < num_targets; ++j) {
|
||||
lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j);
|
||||
if (!target_sp)
|
||||
continue;
|
||||
os << " - target " << j;
|
||||
if (target_sp == target_list.GetSelectedTarget())
|
||||
os << " (selected)";
|
||||
// Append the module path if we have one.
|
||||
if (Module *exe_module = target_sp->GetExecutableModulePointer())
|
||||
os << " " << exe_module->GetFileSpec().GetPath();
|
||||
os << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
return createTextResult(output);
|
||||
}
|
||||
|
||||
@@ -22,10 +22,10 @@ public:
|
||||
virtual ~Tool() = default;
|
||||
|
||||
virtual llvm::Expected<protocol::TextResult>
|
||||
Call(const llvm::json::Value &args) = 0;
|
||||
Call(const protocol::ToolArguments &args) = 0;
|
||||
|
||||
virtual std::optional<llvm::json::Value> GetSchema() const {
|
||||
return std::nullopt;
|
||||
return llvm::json::Object{{"type", "object"}};
|
||||
}
|
||||
|
||||
protocol::ToolDefinition GetDefinition() const;
|
||||
@@ -37,20 +37,26 @@ private:
|
||||
std::string m_description;
|
||||
};
|
||||
|
||||
class LLDBCommandTool : public mcp::Tool {
|
||||
class CommandTool : public mcp::Tool {
|
||||
public:
|
||||
LLDBCommandTool(std::string name, std::string description,
|
||||
Debugger &debugger);
|
||||
~LLDBCommandTool() = default;
|
||||
using mcp::Tool::Tool;
|
||||
~CommandTool() = default;
|
||||
|
||||
virtual llvm::Expected<protocol::TextResult>
|
||||
Call(const llvm::json::Value &args) override;
|
||||
Call(const protocol::ToolArguments &args) override;
|
||||
|
||||
virtual std::optional<llvm::json::Value> GetSchema() const override;
|
||||
|
||||
private:
|
||||
Debugger &m_debugger;
|
||||
};
|
||||
|
||||
class DebuggerListTool : public mcp::Tool {
|
||||
public:
|
||||
using mcp::Tool::Tool;
|
||||
~DebuggerListTool() = default;
|
||||
|
||||
virtual llvm::Expected<protocol::TextResult>
|
||||
Call(const protocol::ToolArguments &args) override;
|
||||
};
|
||||
|
||||
} // namespace lldb_private::mcp
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user