[Offload] Use llvm::Error throughout liboffload internals (#140879)

This removes the `ol_impl_result_t` helper class, replacing it with
`llvm::Error`. In addition, some internal functions that returned
`ol_errc_t` now return `llvm::Error` (with a fancy message).
This commit is contained in:
Ross Brunton
2025-05-27 19:42:56 +01:00
committed by GitHub
parent 909212feec
commit 7e9d708be0
8 changed files with 344 additions and 330 deletions

View File

@@ -69,48 +69,31 @@ struct ErrPtrHash {
using ErrSetT = std::unordered_set<ErrPtrT, ErrPtrHash, ErrPtrEqual>;
ErrSetT &errors();
struct ol_impl_result_t {
ol_impl_result_t(std::nullptr_t) : Result(OL_SUCCESS) {}
ol_impl_result_t(ol_errc_t Code) {
if (Code == OL_ERRC_SUCCESS) {
Result = nullptr;
} else {
auto Err = std::unique_ptr<ol_error_struct_t>(
new ol_error_struct_t{Code, nullptr});
Result = errors().emplace(std::move(Err)).first->get();
}
namespace {
ol_errc_t GetErrorCode(std::error_code Code) {
if (Code.category() ==
error::make_error_code(error::ErrorCode::SUCCESS).category())
return static_cast<ol_errc_t>(Code.value());
return OL_ERRC_UNKNOWN;
}
} // namespace
inline ol_result_t llvmErrorToOffloadError(llvm::Error &&Err) {
if (!Err) {
// No error
return nullptr;
}
ol_impl_result_t(ol_errc_t Code, llvm::StringRef Details) {
assert(Code != OL_ERRC_SUCCESS);
Result = nullptr;
auto DetailsStr = errorStrs().insert(Details).first->getKeyData();
auto Err = std::unique_ptr<ol_error_struct_t>(
new ol_error_struct_t{Code, DetailsStr});
Result = errors().emplace(std::move(Err)).first->get();
}
ol_errc_t ErrCode;
llvm::StringRef Details;
static ol_impl_result_t fromError(llvm::Error &&Error) {
ol_errc_t ErrCode;
llvm::StringRef Details;
llvm::handleAllErrors(std::move(Error), [&](llvm::StringError &Err) {
ErrCode = GetErrorCode(Err.convertToErrorCode());
Details = errorStrs().insert(Err.getMessage()).first->getKeyData();
});
llvm::handleAllErrors(std::move(Err), [&](llvm::StringError &Err) {
ErrCode = GetErrorCode(Err.convertToErrorCode());
Details = errorStrs().insert(Err.getMessage()).first->getKeyData();
});
return ol_impl_result_t{ErrCode, Details};
}
operator ol_result_t() { return Result; }
private:
static ol_errc_t GetErrorCode(std::error_code Code) {
if (Code.category() ==
error::make_error_code(error::ErrorCode::SUCCESS).category()) {
return static_cast<ol_errc_t>(Code.value());
}
return OL_ERRC_UNKNOWN;
}
ol_result_t Result;
};
auto NewErr = std::unique_ptr<ol_error_struct_t>(
new ol_error_struct_t{ErrCode, Details.data()});
return errors().emplace(std::move(NewErr)).first->get();
}

View File

@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olInit_val() {
llvm::Error olInit_val() {
if (offloadConfig().ValidationEnabled) {
}
@@ -18,7 +18,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olInit() {
llvm::errs() << "---> olInit";
}
ol_result_t Result = olInit_val();
ol_result_t Result = llvmErrorToOffloadError(olInit_val());
if (offloadConfig().TracingEnabled) {
llvm::errs() << "()";
@@ -38,7 +38,7 @@ ol_result_t olInitWithCodeLoc(ol_code_location_t *CodeLocation) {
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olShutDown_val() {
llvm::Error olShutDown_val() {
if (offloadConfig().ValidationEnabled) {
}
@@ -49,7 +49,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olShutDown() {
llvm::errs() << "---> olShutDown";
}
ol_result_t Result = olShutDown_val();
ol_result_t Result = llvmErrorToOffloadError(olShutDown_val());
if (offloadConfig().TracingEnabled) {
llvm::errs() << "()";
@@ -69,20 +69,23 @@ ol_result_t olShutDownWithCodeLoc(ol_code_location_t *CodeLocation) {
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue) {
llvm::Error olGetPlatformInfo_val(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
void *PropValue) {
if (offloadConfig().ValidationEnabled) {
if (PropSize == 0) {
return OL_ERRC_INVALID_SIZE;
return createOffloadError(error::ErrorCode::INVALID_SIZE,
"validation failure: PropSize == 0");
}
if (NULL == Platform) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Platform");
}
if (NULL == PropValue) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == PropValue");
}
}
@@ -96,8 +99,8 @@ olGetPlatformInfo(ol_platform_handle_t Platform, ol_platform_info_t PropName,
llvm::errs() << "---> olGetPlatformInfo";
}
ol_result_t Result =
olGetPlatformInfo_val(Platform, PropName, PropSize, PropValue);
ol_result_t Result = llvmErrorToOffloadError(
olGetPlatformInfo_val(Platform, PropName, PropSize, PropValue));
if (offloadConfig().TracingEnabled) {
ol_get_platform_info_params_t Params = {&Platform, &PropName, &PropSize,
@@ -123,16 +126,18 @@ ol_result_t olGetPlatformInfoWithCodeLoc(ol_platform_handle_t Platform,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet) {
llvm::Error olGetPlatformInfoSize_val(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Platform) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Platform");
}
if (NULL == PropSizeRet) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == PropSizeRet");
}
}
@@ -146,8 +151,8 @@ olGetPlatformInfoSize(ol_platform_handle_t Platform,
llvm::errs() << "---> olGetPlatformInfoSize";
}
ol_result_t Result =
olGetPlatformInfoSize_val(Platform, PropName, PropSizeRet);
ol_result_t Result = llvmErrorToOffloadError(
olGetPlatformInfoSize_val(Platform, PropName, PropSizeRet));
if (offloadConfig().TracingEnabled) {
ol_get_platform_info_size_params_t Params = {&Platform, &PropName,
@@ -172,8 +177,8 @@ ol_result_t olGetPlatformInfoSizeWithCodeLoc(ol_platform_handle_t Platform,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olIterateDevices_val(ol_device_iterate_cb_t Callback,
void *UserData) {
llvm::Error olIterateDevices_val(ol_device_iterate_cb_t Callback,
void *UserData) {
if (offloadConfig().ValidationEnabled) {
}
@@ -185,7 +190,8 @@ olIterateDevices(ol_device_iterate_cb_t Callback, void *UserData) {
llvm::errs() << "---> olIterateDevices";
}
ol_result_t Result = olIterateDevices_val(Callback, UserData);
ol_result_t Result =
llvmErrorToOffloadError(olIterateDevices_val(Callback, UserData));
if (offloadConfig().TracingEnabled) {
ol_iterate_devices_params_t Params = {&Callback, &UserData};
@@ -208,20 +214,23 @@ ol_result_t olIterateDevicesWithCodeLoc(ol_device_iterate_cb_t Callback,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olGetDeviceInfo_val(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue) {
llvm::Error olGetDeviceInfo_val(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue) {
if (offloadConfig().ValidationEnabled) {
if (PropSize == 0) {
return OL_ERRC_INVALID_SIZE;
return createOffloadError(error::ErrorCode::INVALID_SIZE,
"validation failure: PropSize == 0");
}
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == PropValue) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == PropValue");
}
}
@@ -236,8 +245,8 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(ol_device_handle_t Device,
llvm::errs() << "---> olGetDeviceInfo";
}
ol_result_t Result =
olGetDeviceInfo_val(Device, PropName, PropSize, PropValue);
ol_result_t Result = llvmErrorToOffloadError(
olGetDeviceInfo_val(Device, PropName, PropSize, PropValue));
if (offloadConfig().TracingEnabled) {
ol_get_device_info_params_t Params = {&Device, &PropName, &PropSize,
@@ -262,16 +271,18 @@ ol_result_t olGetDeviceInfoWithCodeLoc(ol_device_handle_t Device,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olGetDeviceInfoSize_val(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t *PropSizeRet) {
llvm::Error olGetDeviceInfoSize_val(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t *PropSizeRet) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == PropSizeRet) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == PropSizeRet");
}
}
@@ -283,7 +294,8 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize(
llvm::errs() << "---> olGetDeviceInfoSize";
}
ol_result_t Result = olGetDeviceInfoSize_val(Device, PropName, PropSizeRet);
ol_result_t Result = llvmErrorToOffloadError(
olGetDeviceInfoSize_val(Device, PropName, PropSizeRet));
if (offloadConfig().TracingEnabled) {
ol_get_device_info_size_params_t Params = {&Device, &PropName,
@@ -308,19 +320,22 @@ ol_result_t olGetDeviceInfoSizeWithCodeLoc(ol_device_handle_t Device,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olMemAlloc_val(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut) {
llvm::Error olMemAlloc_val(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut) {
if (offloadConfig().ValidationEnabled) {
if (Size == 0) {
return OL_ERRC_INVALID_SIZE;
return createOffloadError(error::ErrorCode::INVALID_SIZE,
"validation failure: Size == 0");
}
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == AllocationOut) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == AllocationOut");
}
}
@@ -334,7 +349,8 @@ OL_APIEXPORT ol_result_t OL_APICALL olMemAlloc(ol_device_handle_t Device,
llvm::errs() << "---> olMemAlloc";
}
ol_result_t Result = olMemAlloc_val(Device, Type, Size, AllocationOut);
ol_result_t Result = llvmErrorToOffloadError(
olMemAlloc_val(Device, Type, Size, AllocationOut));
if (offloadConfig().TracingEnabled) {
ol_mem_alloc_params_t Params = {&Device, &Type, &Size, &AllocationOut};
@@ -358,10 +374,11 @@ ol_result_t olMemAllocWithCodeLoc(ol_device_handle_t Device,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olMemFree_val(void *Address) {
llvm::Error olMemFree_val(void *Address) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Address) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == Address");
}
}
@@ -372,7 +389,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olMemFree(void *Address) {
llvm::errs() << "---> olMemFree";
}
ol_result_t Result = olMemFree_val(Address);
ol_result_t Result = llvmErrorToOffloadError(olMemFree_val(Address));
if (offloadConfig().TracingEnabled) {
ol_mem_free_params_t Params = {&Address};
@@ -394,29 +411,35 @@ ol_result_t olMemFreeWithCodeLoc(void *Address,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olMemcpy_val(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
llvm::Error olMemcpy_val(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
if (offloadConfig().ValidationEnabled) {
if (Queue == NULL && EventOut != NULL) {
return OL_ERRC_INVALID_ARGUMENT;
return createOffloadError(
error::ErrorCode::INVALID_ARGUMENT,
"validation failure: Queue == NULL && EventOut != NULL");
}
if (NULL == DstDevice) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == DstDevice");
}
if (NULL == SrcDevice) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == SrcDevice");
}
if (NULL == DstPtr) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == DstPtr");
}
if (NULL == SrcPtr) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == SrcPtr");
}
}
@@ -431,8 +454,8 @@ olMemcpy(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice,
llvm::errs() << "---> olMemcpy";
}
ol_result_t Result =
olMemcpy_val(Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut);
ol_result_t Result = llvmErrorToOffloadError(olMemcpy_val(
Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut));
if (offloadConfig().TracingEnabled) {
ol_memcpy_params_t Params = {&Queue, &DstPtr, &DstDevice, &SrcPtr,
@@ -459,15 +482,17 @@ ol_result_t olMemcpyWithCodeLoc(ol_queue_handle_t Queue, void *DstPtr,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olCreateQueue_val(ol_device_handle_t Device,
ol_queue_handle_t *Queue) {
llvm::Error olCreateQueue_val(ol_device_handle_t Device,
ol_queue_handle_t *Queue) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == Queue) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == Queue");
}
}
@@ -479,7 +504,8 @@ OL_APIEXPORT ol_result_t OL_APICALL olCreateQueue(ol_device_handle_t Device,
llvm::errs() << "---> olCreateQueue";
}
ol_result_t Result = olCreateQueue_val(Device, Queue);
ol_result_t Result =
llvmErrorToOffloadError(olCreateQueue_val(Device, Queue));
if (offloadConfig().TracingEnabled) {
ol_create_queue_params_t Params = {&Device, &Queue};
@@ -502,10 +528,11 @@ ol_result_t olCreateQueueWithCodeLoc(ol_device_handle_t Device,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olDestroyQueue_val(ol_queue_handle_t Queue) {
llvm::Error olDestroyQueue_val(ol_queue_handle_t Queue) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Queue) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Queue");
}
}
@@ -516,7 +543,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueue(ol_queue_handle_t Queue) {
llvm::errs() << "---> olDestroyQueue";
}
ol_result_t Result = olDestroyQueue_val(Queue);
ol_result_t Result = llvmErrorToOffloadError(olDestroyQueue_val(Queue));
if (offloadConfig().TracingEnabled) {
ol_destroy_queue_params_t Params = {&Queue};
@@ -538,10 +565,11 @@ ol_result_t olDestroyQueueWithCodeLoc(ol_queue_handle_t Queue,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olWaitQueue_val(ol_queue_handle_t Queue) {
llvm::Error olWaitQueue_val(ol_queue_handle_t Queue) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Queue) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Queue");
}
}
@@ -552,7 +580,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olWaitQueue(ol_queue_handle_t Queue) {
llvm::errs() << "---> olWaitQueue";
}
ol_result_t Result = olWaitQueue_val(Queue);
ol_result_t Result = llvmErrorToOffloadError(olWaitQueue_val(Queue));
if (offloadConfig().TracingEnabled) {
ol_wait_queue_params_t Params = {&Queue};
@@ -574,10 +602,11 @@ ol_result_t olWaitQueueWithCodeLoc(ol_queue_handle_t Queue,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olDestroyEvent_val(ol_event_handle_t Event) {
llvm::Error olDestroyEvent_val(ol_event_handle_t Event) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Event) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Event");
}
}
@@ -588,7 +617,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olDestroyEvent(ol_event_handle_t Event) {
llvm::errs() << "---> olDestroyEvent";
}
ol_result_t Result = olDestroyEvent_val(Event);
ol_result_t Result = llvmErrorToOffloadError(olDestroyEvent_val(Event));
if (offloadConfig().TracingEnabled) {
ol_destroy_event_params_t Params = {&Event};
@@ -610,10 +639,11 @@ ol_result_t olDestroyEventWithCodeLoc(ol_event_handle_t Event,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olWaitEvent_val(ol_event_handle_t Event) {
llvm::Error olWaitEvent_val(ol_event_handle_t Event) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Event) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Event");
}
}
@@ -624,7 +654,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olWaitEvent(ol_event_handle_t Event) {
llvm::errs() << "---> olWaitEvent";
}
ol_result_t Result = olWaitEvent_val(Event);
ol_result_t Result = llvmErrorToOffloadError(olWaitEvent_val(Event));
if (offloadConfig().TracingEnabled) {
ol_wait_event_params_t Params = {&Event};
@@ -646,20 +676,23 @@ ol_result_t olWaitEventWithCodeLoc(ol_event_handle_t Event,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olCreateProgram_val(ol_device_handle_t Device,
const void *ProgData, size_t ProgDataSize,
ol_program_handle_t *Program) {
llvm::Error olCreateProgram_val(ol_device_handle_t Device, const void *ProgData,
size_t ProgDataSize,
ol_program_handle_t *Program) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == ProgData) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == ProgData");
}
if (NULL == Program) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == Program");
}
}
@@ -673,8 +706,8 @@ olCreateProgram(ol_device_handle_t Device, const void *ProgData,
llvm::errs() << "---> olCreateProgram";
}
ol_result_t Result =
olCreateProgram_val(Device, ProgData, ProgDataSize, Program);
ol_result_t Result = llvmErrorToOffloadError(
olCreateProgram_val(Device, ProgData, ProgDataSize, Program));
if (offloadConfig().TracingEnabled) {
ol_create_program_params_t Params = {&Device, &ProgData, &ProgDataSize,
@@ -701,10 +734,11 @@ ol_result_t olCreateProgramWithCodeLoc(ol_device_handle_t Device,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olDestroyProgram_val(ol_program_handle_t Program) {
llvm::Error olDestroyProgram_val(ol_program_handle_t Program) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Program) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Program");
}
}
@@ -716,7 +750,7 @@ olDestroyProgram(ol_program_handle_t Program) {
llvm::errs() << "---> olDestroyProgram";
}
ol_result_t Result = olDestroyProgram_val(Program);
ol_result_t Result = llvmErrorToOffloadError(olDestroyProgram_val(Program));
if (offloadConfig().TracingEnabled) {
ol_destroy_program_params_t Params = {&Program};
@@ -738,20 +772,22 @@ ol_result_t olDestroyProgramWithCodeLoc(ol_program_handle_t Program,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olGetKernel_val(ol_program_handle_t Program,
const char *KernelName,
ol_kernel_handle_t *Kernel) {
llvm::Error olGetKernel_val(ol_program_handle_t Program, const char *KernelName,
ol_kernel_handle_t *Kernel) {
if (offloadConfig().ValidationEnabled) {
if (NULL == Program) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Program");
}
if (NULL == KernelName) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == KernelName");
}
if (NULL == Kernel) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == Kernel");
}
}
@@ -764,7 +800,8 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetKernel(ol_program_handle_t Program,
llvm::errs() << "---> olGetKernel";
}
ol_result_t Result = olGetKernel_val(Program, KernelName, Kernel);
ol_result_t Result =
llvmErrorToOffloadError(olGetKernel_val(Program, KernelName, Kernel));
if (offloadConfig().TracingEnabled) {
ol_get_kernel_params_t Params = {&Program, &KernelName, &Kernel};
@@ -788,7 +825,7 @@ ol_result_t olGetKernelWithCodeLoc(ol_program_handle_t Program,
}
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t
llvm::Error
olLaunchKernel_val(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
@@ -796,23 +833,29 @@ olLaunchKernel_val(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_event_handle_t *EventOut) {
if (offloadConfig().ValidationEnabled) {
if (Queue == NULL && EventOut != NULL) {
return OL_ERRC_INVALID_ARGUMENT;
return createOffloadError(
error::ErrorCode::INVALID_ARGUMENT,
"validation failure: Queue == NULL && EventOut != NULL");
}
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Device");
}
if (NULL == Kernel) {
return OL_ERRC_INVALID_NULL_HANDLE;
return createOffloadError(error::ErrorCode::INVALID_NULL_HANDLE,
"validation failure: NULL == Kernel");
}
if (NULL == ArgumentsData) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == ArgumentsData");
}
if (NULL == LaunchSizeArgs) {
return OL_ERRC_INVALID_NULL_POINTER;
return createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"validation failure: NULL == LaunchSizeArgs");
}
}
@@ -829,9 +872,9 @@ OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernel(
llvm::errs() << "---> olLaunchKernel";
}
ol_result_t Result =
ol_result_t Result = llvmErrorToOffloadError(
olLaunchKernel_val(Queue, Device, Kernel, ArgumentsData, ArgumentsSize,
LaunchSizeArgs, EventOut);
LaunchSizeArgs, EventOut));
if (offloadConfig().TracingEnabled) {
ol_launch_kernel_params_t Params = {

View File

@@ -5,64 +5,56 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
ol_impl_result_t olInit_impl();
Error olInit_impl();
ol_impl_result_t olShutDown_impl();
Error olShutDown_impl();
ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue);
Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
void *PropValue);
ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet);
Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet);
ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback,
void *UserData);
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData);
ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t PropSize, void *PropValue);
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue);
ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t *PropSizeRet);
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet);
ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device,
ol_alloc_type_t Type, size_t Size,
void **AllocationOut);
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut);
ol_impl_result_t olMemFree_impl(void *Address);
Error olMemFree_impl(void *Address);
ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut);
ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device,
ol_queue_handle_t *Queue);
ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue);
ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue);
ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event);
ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event);
ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device,
const void *ProgData, size_t ProgDataSize,
ol_program_handle_t *Program);
ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program);
ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program,
const char *KernelName,
ol_kernel_handle_t *Kernel);
ol_impl_result_t
olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
const ol_kernel_launch_size_args_t *LaunchSizeArgs,
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut);
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue);
Error olDestroyQueue_impl(ol_queue_handle_t Queue);
Error olWaitQueue_impl(ol_queue_handle_t Queue);
Error olDestroyEvent_impl(ol_event_handle_t Event);
Error olWaitEvent_impl(ol_event_handle_t Event);
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
size_t ProgDataSize, ol_program_handle_t *Program);
Error olDestroyProgram_impl(ol_program_handle_t Program);
Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
ol_kernel_handle_t *Kernel);
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
const ol_kernel_launch_size_args_t *LaunchSizeArgs,
ol_event_handle_t *EventOut);

View File

@@ -13,20 +13,24 @@
//===----------------------------------------------------------------------===//
#include "OffloadAPI.h"
#include "Shared/OffloadError.h"
#include "llvm/Support/Error.h"
#include <cstring>
template <typename T, typename Assign>
ol_errc_t getInfoImpl(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet, T Value, size_t ValueSize,
Assign &&AssignFunc) {
llvm::Error getInfoImpl(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet, T Value, size_t ValueSize,
Assign &&AssignFunc) {
if (!ParamValue && !ParamValueSizeRet) {
return OL_ERRC_INVALID_NULL_POINTER;
return error::createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
"value and size outputs are nullptr");
}
if (ParamValue != nullptr) {
if (ParamValueSize < ValueSize) {
return OL_ERRC_INVALID_SIZE;
return error::createOffloadError(error::ErrorCode::INVALID_SIZE,
"provided size is invalid");
}
AssignFunc(ParamValue, Value, ValueSize);
}
@@ -35,12 +39,12 @@ ol_errc_t getInfoImpl(size_t ParamValueSize, void *ParamValue,
*ParamValueSizeRet = ValueSize;
}
return OL_ERRC_SUCCESS;
return llvm::Error::success();
}
template <typename T>
ol_errc_t getInfo(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet, T Value) {
llvm::Error getInfo(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet, T Value) {
auto Assignment = [](void *ParamValue, T Value, size_t) {
*static_cast<T *>(ParamValue) = Value;
};
@@ -50,17 +54,17 @@ ol_errc_t getInfo(size_t ParamValueSize, void *ParamValue,
}
template <typename T>
ol_errc_t getInfoArray(size_t array_length, size_t ParamValueSize,
void *ParamValue, size_t *ParamValueSizeRet,
const T *Value) {
llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize,
void *ParamValue, size_t *ParamValueSizeRet,
const T *Value) {
return getInfoImpl(ParamValueSize, ParamValue, ParamValueSizeRet, Value,
array_length * sizeof(T), memcpy);
}
template <>
inline ol_errc_t getInfo<const char *>(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet,
const char *Value) {
inline llvm::Error
getInfo<const char *>(size_t ParamValueSize, void *ParamValue,
size_t *ParamValueSizeRet, const char *Value) {
return getInfoArray(strlen(Value) + 1, ParamValueSize, ParamValue,
ParamValueSizeRet, Value);
}
@@ -79,12 +83,12 @@ public:
ParamValueSizeRet(ParamValueSize) {}
// Scalar return Value
template <class T> ol_errc_t operator()(const T &t) {
template <class T> llvm::Error operator()(const T &t) {
return getInfo(ParamValueSize, ParamValue, ParamValueSizeRet, t);
}
// Array return Value
template <class T> ol_errc_t operator()(const T *t, size_t s) {
template <class T> llvm::Error operator()(const T *t, size_t s) {
return getInfoArray(s, ParamValueSize, ParamValue, ParamValueSizeRet, t);
}

View File

@@ -36,6 +36,7 @@ ompt_function_lookup_t lookupCallbackByName = nullptr;
using namespace llvm::omp::target;
using namespace llvm::omp::target::plugin;
using namespace error;
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
// we add some additional data here for now to avoid churn in the plugin
@@ -109,9 +110,9 @@ ol_device_handle_t HostDevice() {
return &Platforms().back().Devices[0];
}
template <typename HandleT> ol_impl_result_t olDestroy(HandleT Handle) {
template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;
return OL_SUCCESS;
return Error::success();
}
constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
@@ -169,18 +170,17 @@ void initPlugins() {
// TODO: We can properly reference count here and manage the resources in a more
// clever way
ol_impl_result_t olInit_impl() {
Error olInit_impl() {
static std::once_flag InitFlag;
std::call_once(InitFlag, initPlugins);
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; }
Error olShutDown_impl() { return Error::success(); }
ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue,
size_t *PropSizeRet) {
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
@@ -200,30 +200,30 @@ ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
return ReturnValue(Platform->BackendType);
}
default:
return OL_ERRC_INVALID_ENUMERATION;
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getPlatformInfo enum '%i' is invalid", PropName);
}
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue) {
Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
void *PropValue) {
return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
nullptr);
}
ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet) {
Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet) {
return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
PropSizeRet);
}
ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t PropSize, void *PropValue,
size_t *PropSizeRet) {
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
@@ -261,27 +261,25 @@ ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return ReturnValue(
GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
default:
return OL_ERRC_INVALID_ENUMERATION;
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
}
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t PropSize, void *PropValue) {
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue) {
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
nullptr);
}
ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t *PropSizeRet) {
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet) {
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
}
ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback,
void *UserData) {
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : Platforms()) {
for (auto &Device : Platform.Devices) {
if (!Callback(&Device, UserData)) {
@@ -290,7 +288,7 @@ ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback,
}
}
return OL_SUCCESS;
return Error::success();
}
TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
@@ -305,98 +303,87 @@ TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
}
}
ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device,
ol_alloc_type_t Type, size_t Size,
void **AllocationOut) {
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut) {
auto Alloc =
Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
if (!Alloc)
return ol_impl_result_t::fromError(Alloc.takeError());
return Alloc.takeError();
*AllocationOut = *Alloc;
allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olMemFree_impl(void *Address) {
Error olMemFree_impl(void *Address) {
if (!allocInfoMap().contains(Address))
return {OL_ERRC_INVALID_ARGUMENT, "Address is not a known allocation"};
return createOffloadError(ErrorCode::INVALID_ARGUMENT,
"address is not a known allocation");
auto AllocInfo = allocInfoMap().at(Address);
auto Device = AllocInfo.Device;
auto Type = AllocInfo.Type;
auto Res =
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type));
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
if (auto Res =
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
return Res;
allocInfoMap().erase(Address);
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device,
ol_queue_handle_t *Queue) {
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo));
if (Err)
return ol_impl_result_t::fromError(std::move(Err));
if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)))
return Err;
*Queue = CreatedQueue.release();
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue) {
return olDestroy(Queue);
}
Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue) {
Error olWaitQueue_impl(ol_queue_handle_t Queue) {
// Host plugin doesn't have a queue set so it's not safe to call synchronize
// on it, but we have nothing to synchronize in that situation anyway.
if (Queue->AsyncInfo->Queue) {
auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo);
if (Err)
return ol_impl_result_t::fromError(std::move(Err));
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
return Err;
}
// Recreate the stream resource so the queue can be reused
// TODO: Would be easier for the synchronization to (optionally) not release
// it to begin with.
auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo);
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
return Res;
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event) {
auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo);
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
Error olWaitEvent_impl(ol_event_handle_t Event) {
if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
return Res;
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event) {
auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo);
if (Res)
return {OL_ERRC_INVALID_EVENT, "The event could not be destroyed"};
Error olDestroyEvent_impl(ol_event_handle_t Event) {
if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
return Res;
return olDestroy(Event);
}
ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue);
auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo);
if (Res) {
if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) {
llvm::consumeError(std::move(Res));
return nullptr;
}
Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
Queue->AsyncInfo);
if (Res) {
if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
Queue->AsyncInfo)) {
llvm::consumeError(std::move(Res));
return nullptr;
}
@@ -404,48 +391,45 @@ ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
return EventImpl.release();
}
ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size,
ol_event_handle_t *EventOut) {
if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
return OL_SUCCESS;
return Error::success();
} else {
return {OL_ERRC_INVALID_ARGUMENT,
"One of DstDevice and SrcDevice must be a non-host device if "
"Queue is specified"};
return createOffloadError(
ErrorCode::INVALID_ARGUMENT,
"ane of DstDevice and SrcDevice must be a non-host device if "
"queue is specified");
}
}
// If no queue is given the memcpy will be synchronous
auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
Error Res = Error::success();
if (DstDevice == HostDevice()) {
auto Res = SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl);
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
Res = SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl);
} else if (SrcDevice == HostDevice()) {
auto Res = DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl);
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
Res = DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl);
} else {
auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
DstPtr, Size, QueueImpl);
if (Res)
return ol_impl_result_t::fromError(std::move(Res));
Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, DstPtr,
Size, QueueImpl);
}
if (Res)
return Res;
if (EventOut)
*EventOut = makeEvent(Queue);
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device,
const void *ProgData, size_t ProgDataSize,
ol_program_handle_t *Program) {
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
size_t ProgDataSize, ol_program_handle_t *Program) {
// Make a copy of the program binary in case it is released by the caller.
auto ImageData = MemoryBuffer::getMemBufferCopy(
StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
@@ -462,47 +446,45 @@ ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device,
Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
if (!Res) {
delete Prog;
return ol_impl_result_t::fromError(Res.takeError());
return Res.takeError();
}
Prog->Image = *Res;
*Program = Prog;
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program) {
Error olDestroyProgram_impl(ol_program_handle_t Program) {
return olDestroy(Program);
}
ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program,
const char *KernelName,
ol_kernel_handle_t *Kernel) {
Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
ol_kernel_handle_t *Kernel) {
auto &Device = Program->Image->getDevice();
auto KernelImpl = Device.constructKernel(KernelName);
if (!KernelImpl)
return ol_impl_result_t::fromError(KernelImpl.takeError());
return KernelImpl.takeError();
auto Err = KernelImpl->init(Device, *Program->Image);
if (Err)
return ol_impl_result_t::fromError(std::move(Err));
if (auto Err = KernelImpl->init(Device, *Program->Image))
return Err;
*Kernel = &*KernelImpl;
return OL_SUCCESS;
return Error::success();
}
ol_impl_result_t
olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
const ol_kernel_launch_size_args_t *LaunchSizeArgs,
ol_event_handle_t *EventOut) {
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_kernel_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
const ol_kernel_launch_size_args_t *LaunchSizeArgs,
ol_event_handle_t *EventOut) {
auto *DeviceImpl = Device->Device;
if (Queue && Device != Queue->Device) {
return {OL_ERRC_INVALID_DEVICE,
"Device specified does not match the device of the given queue"};
return createOffloadError(
ErrorCode::INVALID_DEVICE,
"device specified does not match the device of the given queue");
}
auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
@@ -529,12 +511,12 @@ olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
AsyncInfoWrapper.finalize(Err);
if (Err)
return ol_impl_result_t::fromError(std::move(Err));
return Err;
if (EventOut)
*EventOut = makeEvent(Queue);
return OL_SUCCESS;
return Error::success();
}
} // namespace offload

View File

@@ -23,7 +23,7 @@ using namespace offload::tblgen;
static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
OS << CommentsHeader;
// Emit preamble
OS << formatv("{0}_impl_result_t {1}_val(\n ", PrefixLower, F.getName());
OS << formatv("llvm::Error {0}_val(\n ", F.getName());
// Emit arguments
std::string ParamNameList = "";
for (auto &Param : F.getParams()) {
@@ -42,7 +42,9 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
if (Condition.starts_with("`") && Condition.ends_with("`")) {
auto ConditionString = Condition.substr(1, Condition.size() - 2);
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
OS << formatv(TAB_3 "return {0};\n", Return.getValue());
OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
"\"validation failure: {1}\");\n",
Return.getUnprefixedValue(), ConditionString);
OS << TAB_2 "}\n\n";
}
}
@@ -78,8 +80,9 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
// Perform actual function call to the validation wrapper
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower,
F.getName(), ParamNameList);
OS << formatv(
TAB_1 "{0}_result_t Result = llvmErrorToOffloadError({1}_val({2}));\n\n",
PrefixLower, F.getName(), ParamNameList);
// Emit post-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";

View File

@@ -61,7 +61,7 @@ void EmitOffloadImplFuncDecls(const RecordKeeper &Records, raw_ostream &OS) {
OS << GenericHeader;
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
FunctionRec F{R};
OS << formatv("{0}_impl_result_t {1}_impl(", PrefixLower, F.getName());
OS << formatv("Error {0}_impl(", F.getName());
auto Params = F.getParams();
for (auto &Param : Params) {
OS << Param.getType() << " " << Param.getName();

View File

@@ -184,6 +184,13 @@ class ReturnRec {
public:
ReturnRec(const Record *rec) : rec(rec) {}
StringRef getValue() const { return rec->getValueAsString("value"); }
// Strip the "OL_ERRC_" from the value, resulting in just "FOO" from
// "OL_ERRC_FOO"
StringRef getUnprefixedValue() const {
constexpr const char *ERRC = "ERRC_";
auto Start = getValue().find(ERRC) + strlen(ERRC);
return getValue().substr(Start);
}
std::vector<StringRef> getConditions() const {
return rec->getValueAsListOfStrings("conditions");
}