From 67e73ba605ea78d757c293f85e32a42257f9c6ed Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Mon, 30 Jun 2025 15:00:43 +0100 Subject: [PATCH] [Offload] Refactor device/platform info queries (#146345) This makes several small changes to how the platform and device info queries are handled: * ReturnHelper has been replaced with InfoWriter which is more explicit in how it is invoked. * InfoWriter consumes `llvm::Expected` rather than values directly, and will early exit if it returns an error. * As a result of the above, `GetInfoString` now correctly returns errors rather than empty strings. * The host device now has its own dedicated "getInfo" function rather than being checked in multiple places. --- offload/liboffload/src/Helpers.hpp | 56 +++++++++-------- offload/liboffload/src/OffloadImpl.cpp | 86 +++++++++++++++++++------- 2 files changed, 91 insertions(+), 51 deletions(-) diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp index 425934b6760d..8b85945508b9 100644 --- a/offload/liboffload/src/Helpers.hpp +++ b/offload/liboffload/src/Helpers.hpp @@ -61,39 +61,41 @@ llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize, array_length * sizeof(T), memcpy); } -template <> -inline llvm::Error -getInfo(size_t ParamValueSize, void *ParamValue, - size_t *ParamValueSizeRet, const char *Value) { - return getInfoArray(strlen(Value) + 1, ParamValueSize, ParamValue, - ParamValueSizeRet, Value); +llvm::Error getInfoString(size_t ParamValueSize, void *ParamValue, + size_t *ParamValueSizeRet, llvm::StringRef Value) { + return getInfoArray(Value.size() + 1, ParamValueSize, ParamValue, + ParamValueSizeRet, Value.data()); } -class ReturnHelper { +class InfoWriter { public: - ReturnHelper(size_t ParamValueSize, void *ParamValue, - size_t *ParamValueSizeRet) - : ParamValueSize(ParamValueSize), ParamValue(ParamValue), - ParamValueSizeRet(ParamValueSizeRet) {} + InfoWriter(size_t Size, void *Target, size_t *SizeRet) + : Size(Size), Target(Target), SizeRet(SizeRet) {}; + InfoWriter() = delete; + InfoWriter(InfoWriter &) = delete; + ~InfoWriter() = default; - // A version where in/out info size is represented by a single pointer - // to a value which is updated on return - ReturnHelper(size_t *ParamValueSize, void *ParamValue) - : ParamValueSize(*ParamValueSize), ParamValue(ParamValue), - ParamValueSizeRet(ParamValueSize) {} - - // Scalar return Value - template llvm::Error operator()(const T &t) { - return getInfo(ParamValueSize, ParamValue, ParamValueSizeRet, t); + template llvm::Error write(llvm::Expected &&Val) { + if (Val) + return getInfo(Size, Target, SizeRet, *Val); + return Val.takeError(); } - // Array return Value - template llvm::Error operator()(const T *t, size_t s) { - return getInfoArray(s, ParamValueSize, ParamValue, ParamValueSizeRet, t); + template + llvm::Error writeArray(llvm::Expected &&Val, size_t Elems) { + if (Val) + return getInfoArray(Elems, Size, Target, SizeRet, *Val); + return Val.takeError(); } -protected: - size_t ParamValueSize; - void *ParamValue; - size_t *ParamValueSizeRet; + llvm::Error writeString(llvm::Expected &&Val) { + if (Val) + return getInfoString(Size, Target, SizeRet, *Val); + return Val.takeError(); + } + +private: + size_t Size; + void *Target; + size_t *SizeRet; }; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 9d4f4f54a821..e7da4eddce54 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -13,6 +13,7 @@ #include "OffloadImpl.hpp" #include "Helpers.hpp" +#include "OffloadPrint.hpp" #include "PluginManager.h" #include "llvm/Support/FormatVariadic.h" #include @@ -234,23 +235,22 @@ Error olShutDown_impl() { Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + InfoWriter Info(PropSize, PropValue, PropSizeRet); bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; switch (PropName) { case OL_PLATFORM_INFO_NAME: - return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); + return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); case OL_PLATFORM_INFO_VENDOR_NAME: // TODO: Implement this - return ReturnValue("Unknown platform vendor"); + return Info.writeString("Unknown platform vendor"); case OL_PLATFORM_INFO_VERSION: { - return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, - OL_VERSION_MINOR, OL_VERSION_PATCH) - .str() - .c_str()); + return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, + OL_VERSION_MINOR, OL_VERSION_PATCH) + .str()); } case OL_PLATFORM_INFO_BACKEND: { - return ReturnValue(Platform->BackendType); + return Info.write(Platform->BackendType); } default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, @@ -277,36 +277,68 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { + assert(Device != OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + auto makeError = [&](ErrorCode Code, StringRef Err) { + std::string ErrBuffer; + llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err; + return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); + }; // Find the info if it exists under any of the given names - auto GetInfoString = [&](std::vector Names) { - if (Device == OffloadContext::get().HostDevice()) - return "Host"; - - for (auto Name : Names) { - if (auto Entry = Device->Info.get(Name)) + auto getInfoString = + [&](std::vector Names) -> llvm::Expected { + for (auto &Name : Names) { + if (auto Entry = Device->Info.get(Name)) { + if (!std::holds_alternative((*Entry)->Value)) + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type"); return std::get((*Entry)->Value).c_str(); + } } - return ""; + return makeError(ErrorCode::UNIMPLEMENTED, + "plugin did not provide a response for this information"); }; switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return ReturnValue(Device->Platform); + return Info.write(Device->Platform); case OL_DEVICE_INFO_TYPE: - return Device == OffloadContext::get().HostDevice() - ? ReturnValue(OL_DEVICE_TYPE_HOST) - : ReturnValue(OL_DEVICE_TYPE_GPU); + return Info.write(OL_DEVICE_TYPE_GPU); case OL_DEVICE_INFO_NAME: - return ReturnValue(GetInfoString({"Device Name"})); + return Info.writeString(getInfoString({"Device Name"})); case OL_DEVICE_INFO_VENDOR: - return ReturnValue(GetInfoString({"Vendor Name"})); + return Info.writeString(getInfoString({"Vendor Name"})); case OL_DEVICE_INFO_DRIVER_VERSION: - return ReturnValue( - GetInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + return Info.writeString( + getInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + default: + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getDeviceInfo enum '%i' is invalid", PropName); + } + + return Error::success(); +} + +Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, + ol_device_info_t PropName, size_t PropSize, + void *PropValue, size_t *PropSizeRet) { + assert(Device == OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); + + switch (PropName) { + case OL_DEVICE_INFO_PLATFORM: + return Info.write(Device->Platform); + case OL_DEVICE_INFO_TYPE: + return Info.write(OL_DEVICE_TYPE_HOST); + case OL_DEVICE_INFO_NAME: + return Info.writeString("Virtual Host Device"); + case OL_DEVICE_INFO_VENDOR: + return Info.writeString("Liboffload"); + case OL_DEVICE_INFO_DRIVER_VERSION: + return Info.writeString(LLVM_VERSION_STRING); default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, "getDeviceInfo enum '%i' is invalid", PropName); @@ -317,12 +349,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, + nullptr); return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, nullptr); } Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, + PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); }