The GPU port of the LLVM C library needs to export a few extensions to its interface so that users can interact with it. This patch adds the logic needed to define a GPU extension. Currently, it only exports an `rpc_reset_client` function. This allows us to use the server in D147054 to set up the RPC interface outside of `libc`. Depends on https://reviews.llvm.org/D147054. Reviewed By: sivachandra. Differential Revision: https://reviews.llvm.org/D152283
220 lines
7.2 KiB
C++
220 lines
7.2 KiB
C++
//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Server.h"

#include "src/__support/RPC/rpc.h"
#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <memory>
#include <mutex>
#include <unordered_map>
|
|
|
|
using namespace __llvm_libc;
|
|
|
|
static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
|
|
"Buffer size mismatch");
|
|
|
|
static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
|
|
"Incorrect maximum port count");
|
|
struct Device {
|
|
rpc::Server server;
|
|
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
|
|
std::unordered_map<rpc_opcode_t, void *> callback_data;
|
|
};
|
|
|
|
// A struct containing all the runtime state required to run the RPC server.
|
|
struct State {
|
|
State(uint32_t num_devices)
|
|
: num_devices(num_devices),
|
|
devices(std::unique_ptr<Device[]>(new Device[num_devices])),
|
|
reference_count(0u) {}
|
|
uint32_t num_devices;
|
|
std::unique_ptr<Device[]> devices;
|
|
std::atomic_uint32_t reference_count;
|
|
};
|
|
|
|
static std::mutex startup_mutex;
|
|
|
|
static State *state;
|
|
|
|
// Acquires a reference to the global server state, creating it for
// `num_devices` devices on first use.
//
// If the state already exists, `num_devices` is ignored and the existing state
// is reused. Returns RPC_STATUS_ERROR when the reference count is already at
// its maximum, RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_init(uint32_t num_devices) {
  const std::lock_guard<std::mutex> guard(startup_mutex);

  if (state == nullptr)
    state = new State(num_devices);

  // Refuse to wrap the counter around.
  const bool saturated =
      state->reference_count == std::numeric_limits<uint32_t>::max();
  if (saturated)
    return RPC_STATUS_ERROR;

  ++state->reference_count;
  return RPC_STATUS_SUCCESS;
}
|
|
|
|
// Drops one reference to the global server state and destroys the state when
// the last reference goes away.
//
// Calling this when no state exists (before rpc_init, or after the final
// shutdown) is a harmless no-op. Always returns RPC_STATUS_SUCCESS.
//
// NOTE(review): unlike rpc_init, this does not hold `startup_mutex`, so it is
// not safe to run concurrently with rpc_init — confirm whether callers
// serialize init/shutdown.
rpc_status_t rpc_shutdown(void) {
  // Nothing to release if the state was never created or is already gone.
  if (!state)
    return RPC_STATUS_SUCCESS;

  // Post-decrement yields the previous value: 1 means this was the last
  // reference.
  if (state->reference_count-- == 1) {
    delete state;
    // Clear the global so later calls (including a fresh rpc_init) never see
    // a dangling pointer.
    state = nullptr;
  }

  return RPC_STATUS_SUCCESS;
}
|
|
|
|
// Allocates the RPC buffer for one device and binds that device's server to
// it.
//
// device_id  Index of the device; must be below the device count given to
//            rpc_init.
// num_ports  Port count forwarded to the server.
// lane_size  Lane size forwarded to the server.
// alloc      Callback used to obtain the buffer; invoked as
//            alloc(size, data).
// data       Opaque pointer forwarded to `alloc`.
//
// Returns RPC_STATUS_ERROR if the global state has not been created, `alloc`
// is null, or `alloc` returns null; RPC_STATUS_OUT_OF_RANGE for an unknown
// device; RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
                             uint32_t lane_size, rpc_alloc_ty alloc,
                             void *data) {
  // Without these guards a call made before rpc_init, or with a missing
  // allocator, would dereference / call through a null pointer.
  if (!state || !alloc)
    return RPC_STATUS_ERROR;

  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  // Ask the server how large a buffer this port/lane configuration needs.
  uint64_t buffer_size =
      __llvm_libc::rpc::Server::allocation_size(num_ports, lane_size);
  void *buffer = alloc(buffer_size, data);

  if (!buffer)
    return RPC_STATUS_ERROR;

  state->devices[device_id].server.reset(num_ports, lane_size, buffer);

  return RPC_STATUS_SUCCESS;
}
|
|
|
|
// Hands the device's RPC buffer back to the caller-provided deallocator.
//
// device_id  Index of the device whose buffer is released.
// dealloc    Callback invoked as dealloc(buffer, data).
// data       Opaque pointer forwarded to `dealloc`.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device and
// RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
                                 void *data) {
  const bool known_device = device_id < state->num_devices;
  if (!known_device)
    return RPC_STATUS_OUT_OF_RANGE;

  void *device_buffer = rpc_get_buffer(device_id);
  dealloc(device_buffer, data);
  return RPC_STATUS_SUCCESS;
}
|
|
|
|
// Services RPC requests for one device until no further port can be opened.
//
// Each opened port is dispatched on its opcode: the built-in opcodes are
// handled inline below, anything else is forwarded to the callback registered
// through rpc_register_callback for that device.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device,
// RPC_STATUS_UNHANDLED_OPCODE when an opcode has no registered callback, and
// RPC_STATUS_SUCCESS once try_open() yields no port.
rpc_status_t rpc_handle_server(uint32_t device_id) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  for (;;) {
    auto port = state->devices[device_id].server.try_open();
    if (!port)
      return RPC_STATUS_SUCCESS;

    switch (port->get_opcode()) {
    case RPC_WRITE_TO_STREAM:
    case RPC_WRITE_TO_STDERR:
    case RPC_WRITE_TO_STDOUT: {
      // Scratch arrays indexed by the `id` the port passes to each callback.
      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
      void *strs[rpc::MAX_LANE_SIZE] = {nullptr};
      FILE *files[rpc::MAX_LANE_SIZE] = {nullptr};

      // Only the generic stream opcode carries an explicit FILE handle.
      if (port->get_opcode() == RPC_WRITE_TO_STREAM)
        port->recv([&](rpc::Buffer *buffer, uint32_t id) {
          files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
        });

      port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
      port->send([&](rpc::Buffer *buffer, uint32_t id) {
        FILE *file =
            port->get_opcode() == RPC_WRITE_TO_STDOUT
                ? stdout
                : (port->get_opcode() == RPC_WRITE_TO_STDERR ? stderr
                                                             : files[id]);
        // Write byte-wise so fwrite's return value is the number of bytes
        // actually written; a short or failed write is then reported to the
        // client instead of being masked as full success.
        const size_t written = fwrite(strs[id], 1, sizes[id], file);
        reinterpret_cast<int *>(buffer->data)[0] = static_cast<int>(written);
      });

      // The buffers were allocated as char[]; release them as the same type.
      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
        if (strs[i])
          delete[] reinterpret_cast<char *>(strs[i]);
      }
      break;
    }
    case RPC_EXIT: {
      // Terminates the host process with the status sent by the client.
      port->recv([](rpc::Buffer *buffer) {
        exit(reinterpret_cast<uint32_t *>(buffer->data)[0]);
      });
      break;
    }
    // TODO: Move handling of these test cases to the loader implementation.
    case RPC_TEST_INCREMENT: {
      port->recv_and_send([](rpc::Buffer *buffer) {
        reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
      });
      break;
    }
    case RPC_TEST_INTERFACE: {
      uint64_t cnt = 0;
      // Initialized so the flag is never read indeterminate, even if the
      // first recv callback does not run.
      bool end_with_recv = false;
      port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      if (end_with_recv)
        port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      else
        port->send(
            [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      break;
    }
    case RPC_TEST_STREAM: {
      // Echo test: receive the streamed data and send it straight back.
      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
      void *dst[rpc::MAX_LANE_SIZE] = {nullptr};
      port->recv_n(dst, sizes, [](uint64_t size) { return new char[size]; });
      port->send_n(dst, sizes);
      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
        if (dst[i])
          delete[] reinterpret_cast<char *>(dst[i]);
      }
      break;
    }
    case RPC_NOOP: {
      port->recv([](rpc::Buffer *buffer) {});
      break;
    }
    default: {
      auto handler = state->devices[device_id].callbacks.find(
          static_cast<rpc_opcode_t>(port->get_opcode()));

      // We error out on an unhandled opcode.
      // NOTE(review): this path returns without calling port->close();
      // confirm whether the port is released some other way.
      if (handler == state->devices[device_id].callbacks.end())
        return RPC_STATUS_UNHANDLED_OPCODE;

      // Invoke the registered callback with a reference to the port.
      void *data = state->devices[device_id].callback_data.at(
          static_cast<rpc_opcode_t>(port->get_opcode()));
      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port)};
      (handler->second)(port_ref, data);
    }
    }
    port->close();
  }
}
|
|
|
|
// Installs (or replaces) the handler for `opcode` on the given device.
//
// device_id  Index of the device the handler applies to.
// opcode     Opcode the handler services.
// callback   Function invoked by rpc_handle_server for that opcode.
// data       Opaque pointer stored alongside the callback and passed back to
//            it on invocation.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device and
// RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
                                   rpc_opcode_callback_ty callback,
                                   void *data) {
  if (!(device_id < state->num_devices))
    return RPC_STATUS_OUT_OF_RANGE;

  auto &device = state->devices[device_id];
  device.callbacks[opcode] = callback;
  device.callback_data[opcode] = data;
  return RPC_STATUS_SUCCESS;
}
|
|
|
|
void *rpc_get_buffer(uint32_t device_id) {
|
|
if (device_id >= state->num_devices)
|
|
return nullptr;
|
|
return state->devices[device_id].server.get_buffer_start();
|
|
}
|
|
|
|
// Runs a receive-then-send exchange on the port identified by `ref`, calling
// `callback(buffer, data)` on each buffer handed out by the port.
//
// ref       Handle wrapping a pointer to the server port.
// callback  User function applied to the buffer, viewed as `rpc_buffer_t`.
// data      Opaque pointer forwarded to `callback`.
void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                       void *data) {
  // The handle is the address of the live server port.
  auto *server_port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
  server_port->recv_and_send([callback, data](rpc::Buffer *raw) {
    callback(reinterpret_cast<rpc_buffer_t *>(raw), data);
  });
}
|