//===- ProtocolServerMCP.cpp ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "ProtocolServerMCP.h" #include "MCPError.h" #include "lldb/Core/PluginManager.h" #include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Threading.h" #include #include using namespace lldb_private; using namespace lldb_private::mcp; using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) static constexpr size_t kChunkSize = 1024; ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) : ProtocolServer(), m_debugger(debugger) { AddRequestHandler("initialize", std::bind(&ProtocolServerMCP::InitializeHandler, this, std::placeholders::_1)); AddRequestHandler("tools/list", std::bind(&ProtocolServerMCP::ToolsListHandler, this, std::placeholders::_1)); AddRequestHandler("tools/call", std::bind(&ProtocolServerMCP::ToolsCallHandler, this, std::placeholders::_1)); AddNotificationHandler( "notifications/initialized", [](const protocol::Notification &) { LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); }); AddTool(std::make_unique( "lldb_command", "Run an lldb command.", m_debugger)); } ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } void ProtocolServerMCP::Initialize() { PluginManager::RegisterPlugin(GetPluginNameStatic(), GetPluginDescriptionStatic(), CreateInstance); } void ProtocolServerMCP::Terminate() { PluginManager::UnregisterPlugin(CreateInstance); } lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { return std::make_shared(debugger); } llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } llvm::Expected ProtocolServerMCP::Handle(protocol::Request request) { auto it = m_request_handlers.find(request.method); if (it != m_request_handlers.end()) { llvm::Expected response = it->second(request); if (!response) return response; response->id = request.id; return *response; } return make_error( llvm::formatv("no handler for request: {0}", request.method).str(), 1); } void ProtocolServerMCP::Handle(protocol::Notification notification) { auto it = m_notification_handlers.find(notification.method); if (it != m_notification_handlers.end()) { it->second(notification); return; } LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", notification.method, notification.params); } void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", m_clients.size() + 1); lldb::IOObjectSP io_sp = std::move(socket); auto client_up = std::make_unique(); client_up->io_sp = io_sp; Client *client = client_up.get(); Status status; auto read_handle_up = m_loop.RegisterReadObject( io_sp, [this, client](MainLoopBase &loop) { if (Error error = ReadCallback(*client)) { LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); client->read_handle_up.reset(); } }, status); if (status.Fail()) return; client_up->read_handle_up = std::move(read_handle_up); m_clients.emplace_back(std::move(client_up)); } llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { char chunk[kChunkSize]; size_t bytes_read = sizeof(chunk); if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) return status.takeError(); client.buffer.append(chunk, bytes_read); for (std::string::size_type pos; (pos = client.buffer.find('\n')) != std::string::npos;) { llvm::Expected> message = HandleData(StringRef(client.buffer.data(), pos)); client.buffer = client.buffer.erase(0, pos + 1); if (!message) return message.takeError(); if (*message) { std::string Output; llvm::raw_string_ostream OS(Output); OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; size_t num_bytes = Output.size(); return client.io_sp->Write(Output.data(), num_bytes).takeError(); } } return llvm::Error::success(); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { std::lock_guard guard(m_server_mutex); if (m_running) return llvm::createStringError("server already running"); Status status; m_listener = Socket::Create(connection.protocol, status); if (status.Fail()) return status.takeError(); status = m_listener->Listen(connection.name, /*backlog=*/5); if (status.Fail()) return status.takeError(); auto handles = m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback, this, std::placeholders::_1)); if (llvm::Error error = handles.takeError()) return error; m_listen_handlers = std::move(*handles); m_loop_thread = std::thread([=] { llvm::set_thread_name( llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID())); m_loop.Run(); }); return llvm::Error::success(); } llvm::Error ProtocolServerMCP::Stop() { { std::lock_guard guard(m_server_mutex); m_running = false; } // Stop the main loop. m_loop.AddPendingCallback( [](MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); { std::lock_guard guard(m_server_mutex); m_listener.reset(); m_listen_handlers.clear(); m_clients.clear(); } return llvm::Error::success(); } llvm::Expected> ProtocolServerMCP::HandleData(llvm::StringRef data) { auto message = llvm::json::parse(/*JSON=*/data); if (!message) return message.takeError(); if (const protocol::Request *request = std::get_if(&(*message))) { llvm::Expected response = Handle(*request); // Handle failures by converting them into an Error message. if (!response) { protocol::Error protocol_error; llvm::handleAllErrors( response.takeError(), [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, [&](const llvm::ErrorInfoBase &err) { protocol_error.error.code = -1; protocol_error.error.message = err.message(); }); protocol_error.id = request->id; return protocol_error; } return *response; } if (const protocol::Notification *notification = std::get_if(&(*message))) { Handle(*notification); return std::nullopt; } if (std::get_if(&(*message))) return llvm::createStringError("unexpected MCP message: error"); if (std::get_if(&(*message))) return llvm::createStringError("unexpected MCP message: response"); llvm_unreachable("all message types handled"); } protocol::Capabilities ProtocolServerMCP::GetCapabilities() { protocol::Capabilities capabilities; capabilities.tools.listChanged = true; return capabilities; } void ProtocolServerMCP::AddTool(std::unique_ptr tool) { std::lock_guard guard(m_server_mutex); if (!tool) return; m_tools[tool->GetName()] = std::move(tool); } void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { std::lock_guard guard(m_server_mutex); m_request_handlers[method] = std::move(handler); } void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, NotificationHandler handler) { std::lock_guard guard(m_server_mutex); m_notification_handlers[method] = std::move(handler); } llvm::Expected ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { protocol::Response response; response.result.emplace(llvm::json::Object{ {"protocolVersion", protocol::kVersion}, {"capabilities", GetCapabilities()}, {"serverInfo", llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); return response; } llvm::Expected ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { protocol::Response response; llvm::json::Array tools; for (const auto &tool : m_tools) tools.emplace_back(toJSON(tool.second->GetDefinition())); response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); return response; } llvm::Expected ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { protocol::Response response; if (!request.params) return llvm::createStringError("no tool parameters"); const json::Object *param_obj = request.params->getAsObject(); if (!param_obj) return llvm::createStringError("no tool parameters"); const json::Value *name = param_obj->get("name"); if (!name) return llvm::createStringError("no tool name"); llvm::StringRef tool_name = name->getAsString().value_or(""); if (tool_name.empty()) return llvm::createStringError("no tool name"); auto it = m_tools.find(tool_name); if (it == m_tools.end()) return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); const json::Value *args = param_obj->get("arguments"); if (!args) return llvm::createStringError("no tool arguments"); llvm::Expected text_result = it->second->Call(*args); if (!text_result) return text_result.takeError(); response.result.emplace(toJSON(*text_result)); return response; }