diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index c2a35a245e2a..6adebb25a2db 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -43,18 +43,19 @@ using namespace error; // interface. struct ol_device_impl_t { ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, - ol_platform_handle_t Platform) - : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {} + ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) + : DeviceNum(DeviceNum), Device(Device), Platform(Platform), + Info(std::forward(DevInfo)) {} int DeviceNum; GenericDeviceTy *Device; ol_platform_handle_t Platform; + InfoTreeNode Info; }; struct ol_platform_impl_t { ol_platform_impl_t(std::unique_ptr Plugin, - std::vector Devices, ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {} + : Plugin(std::move(Plugin)), BackendType(BackendType) {} std::unique_ptr Plugin; std::vector Devices; ol_platform_backend_t BackendType; @@ -144,7 +145,7 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" -void initPlugins() { +Error initPlugins() { auto *Context = new OffloadContext{}; // Attempt to create an instance of each supported plugin. @@ -152,7 +153,6 @@ void initPlugins() { do { \ Context->Platforms.emplace_back(ol_platform_impl_t{ \ std::unique_ptr(createPlugin_##Name()), \ - {}, \ pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" @@ -167,31 +167,39 @@ void initPlugins() { for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); DevNum++) { if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - Platform.Devices.emplace_back(ol_device_impl_t{ - DevNum, &Platform.Plugin->getDevice(DevNum), &Platform}); + auto Device = &Platform.Plugin->getDevice(DevNum); + auto Info = Device->obtainInfoImpl(); + if (auto Err = Info.takeError()) + return Err; + Platform.Devices.emplace_back(DevNum, Device, &Platform, + std::move(*Info)); } } } // Add the special host device auto &HostPlatform = Context->Platforms.emplace_back( - ol_platform_impl_t{nullptr, - {ol_device_impl_t{-1, nullptr, nullptr}}, - OL_PLATFORM_BACKEND_HOST}); + ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); + HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); Context->HostDevice()->Platform = &HostPlatform; Context->TracingEnabled = std::getenv("OFFLOAD_TRACE"); Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); OffloadContextVal = Context; + + return Plugin::success(); } // TODO: We can properly reference count here and manage the resources in a more // clever way Error olInit_impl() { static std::once_flag InitFlag; - std::call_once(InitFlag, initPlugins); + std::optional InitResult{}; + std::call_once(InitFlag, [&] { InitResult = initPlugins(); }); + if (InitResult) + return std::move(*InitResult); return Error::success(); } Error olShutDown_impl() { return Error::success(); } @@ -250,15 +258,8 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, if (Device == OffloadContext::get().HostDevice()) return "Host"; - if (!Device->Device) - return ""; - - auto Info = Device->Device->obtainInfoImpl(); - if (auto Err = Info.takeError()) - return ""; - for (auto Name : Names) { - if (auto Entry = Info->get(Name)) + if (auto Entry = Device->Info.get(Name)) return std::get((*Entry)->Value).c_str(); }