[Offload] Refactor device/platform info queries (#146345)

This makes several small changes to how the platform and device info
queries are handled:
* ReturnHelper has been replaced with InfoWriter which is more explicit
  in how it is invoked.
* InfoWriter consumes `llvm::Expected` rather than values directly, and
  will early exit if it returns an error.
* As a result of the above, `GetInfoString` now correctly returns errors
  rather than empty strings.
* The host device now has its own dedicated "getInfo" function rather
  than being checked in multiple places.
This commit is contained in:
Ross Brunton
2025-06-30 15:00:43 +01:00
committed by GitHub
parent 619f7afd71
commit 67e73ba605
2 changed files with 91 additions and 51 deletions

View File

@@ -61,39 +61,41 @@ llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize,
array_length * sizeof(T), memcpy); array_length * sizeof(T), memcpy);
} }
template <> llvm::Error getInfoString(size_t ParamValueSize, void *ParamValue,
inline llvm::Error size_t *ParamValueSizeRet, llvm::StringRef Value) {
getInfo<const char *>(size_t ParamValueSize, void *ParamValue, return getInfoArray(Value.size() + 1, ParamValueSize, ParamValue,
size_t *ParamValueSizeRet, const char *Value) { ParamValueSizeRet, Value.data());
return getInfoArray(strlen(Value) + 1, ParamValueSize, ParamValue,
ParamValueSizeRet, Value);
} }
class ReturnHelper { class InfoWriter {
public: public:
ReturnHelper(size_t ParamValueSize, void *ParamValue, InfoWriter(size_t Size, void *Target, size_t *SizeRet)
size_t *ParamValueSizeRet) : Size(Size), Target(Target), SizeRet(SizeRet) {};
: ParamValueSize(ParamValueSize), ParamValue(ParamValue), InfoWriter() = delete;
ParamValueSizeRet(ParamValueSizeRet) {} InfoWriter(InfoWriter &) = delete;
~InfoWriter() = default;
// A version where in/out info size is represented by a single pointer template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
// to a value which is updated on return if (Val)
ReturnHelper(size_t *ParamValueSize, void *ParamValue) return getInfo(Size, Target, SizeRet, *Val);
: ParamValueSize(*ParamValueSize), ParamValue(ParamValue), return Val.takeError();
ParamValueSizeRet(ParamValueSize) {}
// Scalar return Value
template <class T> llvm::Error operator()(const T &t) {
return getInfo(ParamValueSize, ParamValue, ParamValueSizeRet, t);
} }
// Array return Value template <typename T>
template <class T> llvm::Error operator()(const T *t, size_t s) { llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
return getInfoArray(s, ParamValueSize, ParamValue, ParamValueSizeRet, t); if (Val)
return getInfoArray(Elems, Size, Target, SizeRet, *Val);
return Val.takeError();
} }
protected: llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
size_t ParamValueSize; if (Val)
void *ParamValue; return getInfoString(Size, Target, SizeRet, *Val);
size_t *ParamValueSizeRet; return Val.takeError();
}
private:
size_t Size;
void *Target;
size_t *SizeRet;
}; };

View File

@@ -13,6 +13,7 @@
#include "OffloadImpl.hpp" #include "OffloadImpl.hpp"
#include "Helpers.hpp" #include "Helpers.hpp"
#include "OffloadPrint.hpp"
#include "PluginManager.h" #include "PluginManager.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include <OffloadAPI.h> #include <OffloadAPI.h>
@@ -234,23 +235,22 @@ Error olShutDown_impl() {
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize, ol_platform_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) { void *PropValue, size_t *PropSizeRet) {
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); InfoWriter Info(PropSize, PropValue, PropSizeRet);
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
switch (PropName) { switch (PropName) {
case OL_PLATFORM_INFO_NAME: case OL_PLATFORM_INFO_NAME:
return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
case OL_PLATFORM_INFO_VENDOR_NAME: case OL_PLATFORM_INFO_VENDOR_NAME:
// TODO: Implement this // TODO: Implement this
return ReturnValue("Unknown platform vendor"); return Info.writeString("Unknown platform vendor");
case OL_PLATFORM_INFO_VERSION: { case OL_PLATFORM_INFO_VERSION: {
return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
OL_VERSION_MINOR, OL_VERSION_PATCH) OL_VERSION_MINOR, OL_VERSION_PATCH)
.str() .str());
.c_str());
} }
case OL_PLATFORM_INFO_BACKEND: { case OL_PLATFORM_INFO_BACKEND: {
return ReturnValue(Platform->BackendType); return Info.write<ol_platform_backend_t>(Platform->BackendType);
} }
default: default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION, return createOffloadError(ErrorCode::INVALID_ENUMERATION,
@@ -277,36 +277,68 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize, ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) { void *PropValue, size_t *PropSizeRet) {
assert(Device != OffloadContext::get().HostDevice());
InfoWriter Info(PropSize, PropValue, PropSizeRet);
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); auto makeError = [&](ErrorCode Code, StringRef Err) {
std::string ErrBuffer;
llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};
// Find the info if it exists under any of the given names // Find the info if it exists under any of the given names
auto GetInfoString = [&](std::vector<std::string> Names) { auto getInfoString =
if (Device == OffloadContext::get().HostDevice()) [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
return "Host"; for (auto &Name : Names) {
if (auto Entry = Device->Info.get(Name)) {
for (auto Name : Names) { if (!std::holds_alternative<std::string>((*Entry)->Value))
if (auto Entry = Device->Info.get(Name)) return makeError(ErrorCode::BACKEND_FAILURE,
"plugin returned incorrect type");
return std::get<std::string>((*Entry)->Value).c_str(); return std::get<std::string>((*Entry)->Value).c_str();
}
} }
return ""; return makeError(ErrorCode::UNIMPLEMENTED,
"plugin did not provide a response for this information");
}; };
switch (PropName) { switch (PropName) {
case OL_DEVICE_INFO_PLATFORM: case OL_DEVICE_INFO_PLATFORM:
return ReturnValue(Device->Platform); return Info.write<void *>(Device->Platform);
case OL_DEVICE_INFO_TYPE: case OL_DEVICE_INFO_TYPE:
return Device == OffloadContext::get().HostDevice() return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME: case OL_DEVICE_INFO_NAME:
return ReturnValue(GetInfoString({"Device Name"})); return Info.writeString(getInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR: case OL_DEVICE_INFO_VENDOR:
return ReturnValue(GetInfoString({"Vendor Name"})); return Info.writeString(getInfoString({"Vendor Name"}));
case OL_DEVICE_INFO_DRIVER_VERSION: case OL_DEVICE_INFO_DRIVER_VERSION:
return ReturnValue( return Info.writeString(
GetInfoString({"CUDA Driver Version", "HSA Runtime Version"})); getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
}
return Error::success();
}
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
assert(Device == OffloadContext::get().HostDevice());
InfoWriter Info(PropSize, PropValue, PropSizeRet);
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
return Info.write<void *>(Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
case OL_DEVICE_INFO_NAME:
return Info.writeString("Virtual Host Device");
case OL_DEVICE_INFO_VENDOR:
return Info.writeString("Liboffload");
case OL_DEVICE_INFO_DRIVER_VERSION:
return Info.writeString(LLVM_VERSION_STRING);
default: default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION, return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName); "getDeviceInfo enum '%i' is invalid", PropName);
@@ -317,12 +349,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue) { size_t PropSize, void *PropValue) {
if (Device == OffloadContext::get().HostDevice())
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
nullptr);
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
nullptr); nullptr);
} }
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet) { ol_device_info_t PropName, size_t *PropSizeRet) {
if (Device == OffloadContext::get().HostDevice())
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
PropSizeRet);
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
} }