[Offload] Implement olShutDown (#144055)

`olShutDown` was not properly calling deinit on the platforms, resulting
in random segfaults on AMD devices.

As part of this, `olInit` and `olShutDown` now allocate and free the
offload context rather than it being static. This
allows `olShutDown` to be called within a destructor of a static object
(like the tests do) without having to worry about destructor ordering.
This commit is contained in:
Ross Brunton
2025-06-30 12:14:00 +01:00
committed by GitHub
parent 6e6c61d696
commit 003145d0c8
3 changed files with 61 additions and 22 deletions

View File

@@ -176,7 +176,7 @@ def : Function {
let desc = "Release the resources in use by Offload";
let details = [
"This decrements an internal reference count. When this reaches 0, all resources will be released",
"Subsequent API calls made after this are not valid"
"Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED"
];
let params = [];
let returns = [];

View File

@@ -96,7 +96,10 @@ struct AllocInfo {
// Global shared state for liboffload
struct OffloadContext;
static OffloadContext *OffloadContextVal;
// This pointer is non-null if and only if the context is valid and fully
// initialized
static std::atomic<OffloadContext *> OffloadContextVal;
std::mutex OffloadContextValMutex;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
@@ -107,6 +110,7 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
@@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
Error initPlugins() {
auto *Context = new OffloadContext{};
Error initPlugins(OffloadContext &Context) {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
Context->Platforms.emplace_back(ol_platform_impl_t{ \
Context.Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
for (auto &Platform : Context->Platforms) {
for (auto &Platform : Context.Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
@@ -178,31 +180,56 @@ Error initPlugins() {
}
// Add the special host device
auto &HostPlatform = Context->Platforms.emplace_back(
auto &HostPlatform = Context.Platforms.emplace_back(
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
Context->HostDevice()->Platform = &HostPlatform;
Context.HostDevice()->Platform = &HostPlatform;
Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
OffloadContextVal = Context;
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
return Plugin::success();
}
// TODO: We can properly reference count here and manage the resources in a more
// clever way
// Implementation behind olInit.
// NOTE(review): this listing carries two bodies back to back - a
// std::call_once based one (down to the first `return Error::success();`) and
// a mutex + reference-count based one that follows it. They look like the
// removed and added sides of one change shown together; confirm which lines
// are live before editing.
Error olInit_impl() {
static std::once_flag InitFlag;
std::optional<Error> InitResult{};
std::call_once(InitFlag, [&] { InitResult = initPlugins(); });
// olShutDown_impl locks the same mutex, so init and shutdown do not overlap.
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
if (InitResult)
return std::move(*InitResult);
return Error::success();
// When isOffloadInitialized() is true, only take another reference on the
// existing context.
if (isOffloadInitialized()) {
OffloadContext::get().RefCount++;
return Plugin::success();
}
// Use a temporary to ensure that entry points querying OffloadContextVal do
// not get a partially initialized context
auto *NewContext = new OffloadContext{};
Error InitResult = initPlugins(*NewContext);
// NOTE(review): the new context is stored and its reference count bumped even
// when initPlugins returned a failure; confirm that this is intended.
OffloadContextVal.store(NewContext);
OffloadContext::get().RefCount++;
return InitResult;
}
/// Drops one reference to the global offload context and tears the context
/// down once the count reaches zero.
///
/// While other references remain, this only decrements the count and reports
/// success. On the final release the global pointer is cleared, deinit() is
/// called on every platform that has a plugin, and the context is deleted.
///
/// \returns success, or the joined errors of every plugin deinit() that
/// failed.
Error olShutDown_impl() {
  std::lock_guard<std::mutex> Guard{OffloadContextValMutex};

  const size_t Remaining = --OffloadContext::get().RefCount;
  if (Remaining != 0)
    return Error::success();

  // Clear the global pointer before tearing the context down.
  auto *Retired = OffloadContextVal.exchange(nullptr);

  llvm::Error Combined = Error::success();
  for (auto &Platform : Retired->Platforms) {
    // The host platform has a null plugin, so there is nothing to deinit.
    if (Platform.Plugin) {
      if (auto DeinitErr = Platform.Plugin->deinit())
        Combined = llvm::joinErrors(std::move(Combined), std::move(DeinitErr));
    }
  }

  delete Retired;
  return Combined;
}
Error olShutDown_impl() { return Error::success(); }
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,

View File

@@ -15,8 +15,20 @@
// Empty GoogleTest fixture; it only provides the suite name for the TEST_F
// cases below.
struct olInitTest : ::testing::Test {};
// One olInit followed by one olShutDown; each call is checked with
// ASSERT_SUCCESS.
TEST_F(olInitTest, Success) {
ASSERT_SUCCESS(olInit());
ASSERT_SUCCESS(olShutDown());
}
// With no olInit call in this test, olIterateDevices is expected to report
// OL_ERRC_UNINITIALIZED.
TEST_F(olInitTest, Uninitialized) {
  // Callback that returns false for any device; no user data is supplied.
  auto RejectAll = [](ol_device_handle_t, void *) { return false; };
  ASSERT_ERROR(OL_ERRC_UNINITIALIZED, olIterateDevices(RejectAll, nullptr));
}
// Runs ten back-to-back olInit/olShutDown pairs, checking every call with
// ASSERT_SUCCESS.
TEST_F(olInitTest, RepeatedInit) {
  constexpr size_t Rounds = 10;
  for (size_t Round = 0; Round != Rounds; ++Round) {
    ASSERT_SUCCESS(olInit());
    ASSERT_SUCCESS(olShutDown());
  }
}