//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "rpc_server.h" #include "src/__support/RPC/rpc.h" #include "src/stdio/gpu/file.h" #include #include #include #include #include #include #include #include using namespace __llvm_libc; static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer), "Buffer size mismatch"); static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT, "Incorrect maximum port count"); // The client needs to support different lane sizes for the SIMT model. Because // of this we need to select between the possible sizes that the client can use. struct Server { template Server(std::unique_ptr> &&server) : server(std::move(server)) {} void reset(uint64_t port_count, void *buffer) { std::visit([&](auto &server) { server->reset(port_count, buffer); }, server); } uint64_t allocation_size(uint64_t port_count) { uint64_t ret = 0; std::visit([&](auto &server) { ret = server->allocation_size(port_count); }, server); return ret; } void *get_buffer_start() const { void *ret = nullptr; std::visit([&](auto &server) { ret = server->get_buffer_start(); }, server); return ret; } rpc_status_t handle_server( std::unordered_map &callbacks, std::unordered_map &callback_data) { rpc_status_t ret = RPC_STATUS_SUCCESS; std::visit( [&](auto &server) { ret = handle_server(*server, callbacks, callback_data); }, server); return ret; } private: template rpc_status_t handle_server( rpc::Server &server, std::unordered_map &callbacks, std::unordered_map &callback_data) { auto port = server.try_open(); if (!port) return RPC_STATUS_SUCCESS; switch (port->get_opcode()) { case RPC_WRITE_TO_STREAM: case RPC_WRITE_TO_STDERR: case RPC_WRITE_TO_STDOUT: { uint64_t sizes[lane_size] = {0}; void *strs[lane_size] = {nullptr}; FILE *files[lane_size] = {nullptr}; if (port->get_opcode() == RPC_WRITE_TO_STREAM) port->recv([&](rpc::Buffer *buffer, uint32_t id) { files[id] = reinterpret_cast(buffer->data[0]); }); port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; }); port->send([&](rpc::Buffer *buffer, uint32_t id) { FILE *file = port->get_opcode() == RPC_WRITE_TO_STDOUT ? stdout : (port->get_opcode() == RPC_WRITE_TO_STDERR ? stderr : files[id]); uint64_t ret = fwrite(strs[id], 1, sizes[id], file); std::memcpy(buffer->data, &ret, sizeof(uint64_t)); delete[] reinterpret_cast(strs[id]); }); break; } case RPC_READ_FROM_STREAM: case RPC_READ_FROM_STDIN: { uint64_t sizes[lane_size] = {0}; void *data[lane_size] = {nullptr}; uint64_t rets[lane_size] = {0}; port->recv([&](rpc::Buffer *buffer, uint32_t id) { sizes[id] = buffer->data[0]; data[id] = new char[sizes[id]]; FILE *file = port->get_opcode() == RPC_READ_FROM_STREAM ? reinterpret_cast(buffer->data[1]) : stdin; rets[id] = fread(data[id], 1, sizes[id], file); }); port->send_n(data, sizes); port->send([&](rpc::Buffer *buffer, uint32_t id) { delete[] reinterpret_cast(data[id]); std::memcpy(buffer->data, &rets[id], sizeof(uint64_t)); }); break; } case RPC_OPEN_FILE: { uint64_t sizes[lane_size] = {0}; void *paths[lane_size] = {nullptr}; port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; }); port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) { FILE *file = fopen(reinterpret_cast(paths[id]), reinterpret_cast(buffer->data)); buffer->data[0] = reinterpret_cast(file); }); break; } case RPC_CLOSE_FILE: { port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) { FILE *file = reinterpret_cast(buffer->data[0]); buffer->data[0] = fclose(file); }); break; } case RPC_EXIT: { // Send a response to the client to signal that we are ready to exit. port->recv_and_send([](rpc::Buffer *) {}); port->recv([](rpc::Buffer *buffer) { int status = 0; std::memcpy(&status, buffer->data, sizeof(int)); exit(status); }); break; } case RPC_ABORT: { // Send a response to the client to signal that we are ready to abort. port->recv_and_send([](rpc::Buffer *) {}); port->recv([](rpc::Buffer *) {}); abort(); break; } case RPC_HOST_CALL: { uint64_t sizes[lane_size] = {0}; void *args[lane_size] = {nullptr}; port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; }); port->recv([&](rpc::Buffer *buffer, uint32_t id) { reinterpret_cast(buffer->data[0])(args[id]); }); port->send([&](rpc::Buffer *, uint32_t id) { delete[] reinterpret_cast(args[id]); }); break; } case RPC_FEOF: { port->recv_and_send([](rpc::Buffer *buffer) { buffer->data[0] = feof(file::to_stream(buffer->data[0])); }); break; } case RPC_FERROR: { port->recv_and_send([](rpc::Buffer *buffer) { buffer->data[0] = ferror(file::to_stream(buffer->data[0])); }); break; } case RPC_CLEARERR: { port->recv_and_send([](rpc::Buffer *buffer) { clearerr(file::to_stream(buffer->data[0])); }); break; } case RPC_NOOP: { port->recv([](rpc::Buffer *) {}); break; } default: { auto handler = callbacks.find(static_cast(port->get_opcode())); // We error out on an unhandled opcode. if (handler == callbacks.end()) return RPC_STATUS_UNHANDLED_OPCODE; // Invoke the registered callback with a reference to the port. void *data = callback_data.at(static_cast(port->get_opcode())); rpc_port_t port_ref{reinterpret_cast(&*port), lane_size}; (handler->second)(port_ref, data); } } port->close(); return RPC_STATUS_CONTINUE; } std::variant>, std::unique_ptr>, std::unique_ptr>> server; }; struct Device { template Device(std::unique_ptr &&server) : server(std::move(server)) {} Server server; rpc::Client client; std::unordered_map callbacks; std::unordered_map callback_data; }; // A struct containing all the runtime state required to run the RPC server. struct State { State(uint32_t num_devices) : num_devices(num_devices), devices(num_devices), reference_count(0u) {} uint32_t num_devices; std::vector> devices; std::atomic_uint32_t reference_count; }; static std::mutex startup_mutex; static State *state; rpc_status_t rpc_init(uint32_t num_devices) { std::scoped_lock lock(startup_mutex); if (!state) state = new State(num_devices); if (state->reference_count == std::numeric_limits::max()) return RPC_STATUS_ERROR; state->reference_count++; return RPC_STATUS_SUCCESS; } rpc_status_t rpc_shutdown(void) { if (state && state->reference_count-- == 1) delete state; return RPC_STATUS_SUCCESS; } rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports, uint32_t lane_size, rpc_alloc_ty alloc, void *data) { if (!state) return RPC_STATUS_NOT_INITIALIZED; if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; if (!state->devices[device_id]) { switch (lane_size) { case 1: state->devices[device_id] = std::make_unique(std::make_unique>()); break; case 32: state->devices[device_id] = std::make_unique(std::make_unique>()); break; case 64: state->devices[device_id] = std::make_unique(std::make_unique>()); break; default: return RPC_STATUS_INVALID_LANE_SIZE; } } uint64_t size = state->devices[device_id]->server.allocation_size(num_ports); void *buffer = alloc(size, data); if (!buffer) return RPC_STATUS_ERROR; state->devices[device_id]->server.reset(num_ports, buffer); state->devices[device_id]->client.reset(num_ports, buffer); return RPC_STATUS_SUCCESS; } rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc, void *data) { if (!state) return RPC_STATUS_NOT_INITIALIZED; if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; if (!state->devices[device_id]) return RPC_STATUS_ERROR; dealloc(rpc_get_buffer(device_id), data); if (state->devices[device_id]) state->devices[device_id].release(); return RPC_STATUS_SUCCESS; } rpc_status_t rpc_handle_server(uint32_t device_id) { if (!state) return RPC_STATUS_NOT_INITIALIZED; if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; if (!state->devices[device_id]) return RPC_STATUS_ERROR; for (;;) { auto &device = *state->devices[device_id]; rpc_status_t status = device.server.handle_server(device.callbacks, device.callback_data); if (status != RPC_STATUS_CONTINUE) return status; } } rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode, rpc_opcode_callback_ty callback, void *data) { if (!state) return RPC_STATUS_NOT_INITIALIZED; if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; if (!state->devices[device_id]) return RPC_STATUS_ERROR; state->devices[device_id]->callbacks[opcode] = callback; state->devices[device_id]->callback_data[opcode] = data; return RPC_STATUS_SUCCESS; } void *rpc_get_buffer(uint32_t device_id) { if (!state || device_id >= state->num_devices || !state->devices[device_id]) return nullptr; return state->devices[device_id]->server.get_buffer_start(); } const void *rpc_get_client_buffer(uint32_t device_id) { if (!state || device_id >= state->num_devices || !state->devices[device_id]) return nullptr; return &state->devices[device_id]->client; } uint64_t rpc_get_client_size() { return sizeof(rpc::Client); } using ServerPort = std::variant::Port *, rpc::Server<32>::Port *, rpc::Server<64>::Port *>; ServerPort get_port(rpc_port_t ref) { if (ref.lane_size == 1) return reinterpret_cast::Port *>(ref.handle); else if (ref.lane_size == 32) return reinterpret_cast::Port *>(ref.handle); else if (ref.lane_size == 64) return reinterpret_cast::Port *>(ref.handle); else __builtin_unreachable(); } void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { auto port = get_port(ref); std::visit( [=](auto &port) { port->send([=](rpc::Buffer *buffer) { callback(reinterpret_cast(buffer), data); }); }, port); } void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) { auto port = get_port(ref); std::visit([=](auto &port) { port->send_n(src, size); }, port); } void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { auto port = get_port(ref); std::visit( [=](auto &port) { port->recv([=](rpc::Buffer *buffer) { callback(reinterpret_cast(buffer), data); }); }, port); } void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc, void *data) { auto port = get_port(ref); auto alloc_fn = [=](uint64_t size) { return alloc(size, data); }; std::visit([=](auto &port) { port->recv_n(dst, size, alloc_fn); }, port); } void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { auto port = get_port(ref); std::visit( [=](auto &port) { port->recv_and_send([=](rpc::Buffer *buffer) { callback(reinterpret_cast(buffer), data); }); }, port); }