//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Server.h" #include "src/__support/RPC/rpc.h" #include #include #include #include #include using namespace __llvm_libc; static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer), "Buffer size mismatch"); static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT, "Incorrect maximum port count"); struct Device { rpc::Server server; std::unordered_map callbacks; std::unordered_map callback_data; }; // A struct containing all the runtime state required to run the RPC server. struct State { State(uint32_t num_devices) : num_devices(num_devices), devices(std::unique_ptr(new Device[num_devices])), reference_count(0u) {} uint32_t num_devices; std::unique_ptr devices; std::atomic_uint32_t reference_count; }; static std::mutex startup_mutex; static State *state; rpc_status_t rpc_init(uint32_t num_devices) { std::scoped_lock lock(startup_mutex); if (!state) state = new State(num_devices); if (state->reference_count == std::numeric_limits::max()) return RPC_STATUS_ERROR; state->reference_count++; return RPC_STATUS_SUCCESS; } rpc_status_t rpc_shutdown(void) { if (state->reference_count-- == 1) delete state; return RPC_STATUS_SUCCESS; } rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports, uint32_t lane_size, rpc_alloc_ty alloc, void *data) { if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; uint64_t buffer_size = __llvm_libc::rpc::Server::allocation_size(num_ports, lane_size); void *buffer = alloc(buffer_size, data); if (!buffer) return RPC_STATUS_ERROR; state->devices[device_id].server.reset(num_ports, lane_size, buffer); return RPC_STATUS_SUCCESS; } rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc, void *data) { if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; dealloc(rpc_get_buffer(device_id), data); return RPC_STATUS_SUCCESS; } rpc_status_t rpc_handle_server(uint32_t device_id) { if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; for (;;) { auto port = state->devices[device_id].server.try_open(); if (!port) return RPC_STATUS_SUCCESS; switch (port->get_opcode()) { case RPC_WRITE_TO_STREAM: case RPC_WRITE_TO_STDERR: case RPC_WRITE_TO_STDOUT: { uint64_t sizes[rpc::MAX_LANE_SIZE] = {0}; void *strs[rpc::MAX_LANE_SIZE] = {nullptr}; FILE *files[rpc::MAX_LANE_SIZE] = {nullptr}; if (port->get_opcode() == RPC_WRITE_TO_STREAM) port->recv([&](rpc::Buffer *buffer, uint32_t id) { files[id] = reinterpret_cast(buffer->data[0]); }); port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; }); port->send([&](rpc::Buffer *buffer, uint32_t id) { FILE *file = port->get_opcode() == RPC_WRITE_TO_STDOUT ? stdout : (port->get_opcode() == RPC_WRITE_TO_STDERR ? stderr : files[id]); int ret = fwrite(strs[id], sizes[id], 1, file); reinterpret_cast(buffer->data)[0] = ret >= 0 ? sizes[id] : ret; }); for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) { if (strs[i]) delete[] reinterpret_cast(strs[i]); } break; } case RPC_EXIT: { port->recv([](rpc::Buffer *buffer) { exit(reinterpret_cast(buffer->data)[0]); }); break; } // TODO: Move handling of these test cases to the loader implementation. case RPC_TEST_INCREMENT: { port->recv_and_send([](rpc::Buffer *buffer) { reinterpret_cast(buffer->data)[0] += 1; }); break; } case RPC_TEST_INTERFACE: { uint64_t cnt = 0; bool end_with_recv; port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; }); port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); if (end_with_recv) port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); else port->send( [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); break; } case RPC_TEST_STREAM: { uint64_t sizes[rpc::MAX_LANE_SIZE] = {0}; void *dst[rpc::MAX_LANE_SIZE] = {nullptr}; port->recv_n(dst, sizes, [](uint64_t size) { return new char[size]; }); port->send_n(dst, sizes); for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) { if (dst[i]) delete[] reinterpret_cast(dst[i]); } break; } case RPC_NOOP: { port->recv([](rpc::Buffer *buffer) {}); break; } default: { auto handler = state->devices[device_id].callbacks.find( static_cast(port->get_opcode())); // We error out on an unhandled opcode. if (handler == state->devices[device_id].callbacks.end()) return RPC_STATUS_UNHANDLED_OPCODE; // Invoke the registered callback with a reference to the port. void *data = state->devices[device_id].callback_data.at( static_cast(port->get_opcode())); rpc_port_t port_ref{reinterpret_cast(&*port)}; (handler->second)(port_ref, data); } } port->close(); } } rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode, rpc_opcode_callback_ty callback, void *data) { if (device_id >= state->num_devices) return RPC_STATUS_OUT_OF_RANGE; state->devices[device_id].callbacks[opcode] = callback; state->devices[device_id].callback_data[opcode] = data; return RPC_STATUS_SUCCESS; } void *rpc_get_buffer(uint32_t device_id) { if (device_id >= state->num_devices) return nullptr; return state->devices[device_id].server.get_buffer_start(); } void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { rpc::Server::Port *port = reinterpret_cast(ref.handle); port->recv_and_send([=](rpc::Buffer *buffer) { callback(reinterpret_cast(buffer), data); }); }