[MLIR] Add SyclRuntimeWrapper (#69648)
This commit is contained in:
@@ -126,6 +126,7 @@ add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ENABLE_ROCM_CONVERSIONS})
|
||||
set(MLIR_ENABLE_DEPRECATED_GPU_SERIALIZATION 0 CACHE BOOL "Enable deprecated GPU serialization passes")
|
||||
set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the mlir CUDA runner")
|
||||
set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the mlir ROCm runner")
|
||||
set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the mlir Sycl runner")
|
||||
set(MLIR_ENABLE_SPIRV_CPU_RUNNER 0 CACHE BOOL "Enable building the mlir SPIR-V cpu runner")
|
||||
set(MLIR_ENABLE_VULKAN_RUNNER 0 CACHE BOOL "Enable building the mlir Vulkan runner")
|
||||
set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
|
||||
|
||||
221
mlir/cmake/modules/FindLevelZero.cmake
Normal file
221
mlir/cmake/modules/FindLevelZero.cmake
Normal file
@@ -0,0 +1,221 @@
|
||||
# CMake find_package() module for level-zero
|
||||
#
|
||||
# Example usage:
|
||||
#
|
||||
# find_package(LevelZero)
|
||||
#
|
||||
# If successful, the following variables will be defined:
|
||||
# LevelZero_FOUND
|
||||
# LevelZero_INCLUDE_DIRS
|
||||
# LevelZero_LIBRARY
|
||||
# LevelZero_LIBRARIES_DIR
|
||||
#
|
||||
# By default, the module searches the standard paths to locate the "ze_api.h"
|
||||
# and the ze_loader shared library. When using a custom level-zero installation,
|
||||
# the environment variable "LEVEL_ZERO_DIR" should be specified telling the
|
||||
# module to get the level-zero library and headers from that location.
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
|
||||
# Search path priority
|
||||
# 1. CMake Variable LEVEL_ZERO_DIR
|
||||
# 2. Environment Variable LEVEL_ZERO_DIR
|
||||
|
||||
if(NOT LEVEL_ZERO_DIR)
|
||||
if(DEFINED ENV{LEVEL_ZERO_DIR})
|
||||
set(LEVEL_ZERO_DIR "$ENV{LEVEL_ZERO_DIR}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(LEVEL_ZERO_DIR)
|
||||
find_path(LevelZero_INCLUDE_DIR
|
||||
NAMES level_zero/ze_api.h
|
||||
PATHS ${LEVEL_ZERO_DIR}/include
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
|
||||
if(LINUX)
|
||||
find_library(LevelZero_LIBRARY
|
||||
NAMES ze_loader
|
||||
PATHS ${LEVEL_ZERO_DIR}/lib
|
||||
${LEVEL_ZERO_DIR}/lib/x86_64-linux-gnu
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
else()
|
||||
find_library(LevelZero_LIBRARY
|
||||
NAMES ze_loader
|
||||
PATHS ${LEVEL_ZERO_DIR}/lib
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
endif()
|
||||
else()
|
||||
find_path(LevelZero_INCLUDE_DIR
|
||||
NAMES level_zero/ze_api.h
|
||||
)
|
||||
|
||||
find_library(LevelZero_LIBRARY
|
||||
NAMES ze_loader
|
||||
)
|
||||
endif()
|
||||
|
||||
# Compares the two version string that are supposed to be in x.y.z format
|
||||
# and reports if the argument VERSION_STR1 is greater than or equal than
|
||||
# version_str2. The strings are compared lexicographically after conversion to
|
||||
# lists of equal lengths, with the shorter string getting zero-padded.
|
||||
function(compare_versions VERSION_STR1 VERSION_STR2 OUTPUT)
|
||||
# Convert the strings to list
|
||||
string(REPLACE "." ";" VL1 ${VERSION_STR1})
|
||||
string(REPLACE "." ";" VL2 ${VERSION_STR2})
|
||||
# get lengths of both lists
|
||||
list(LENGTH VL1 VL1_LEN)
|
||||
list(LENGTH VL2 VL2_LEN)
|
||||
set(LEN ${VL1_LEN})
|
||||
# If they differ in size pad the shorter list with 0s
|
||||
if(VL1_LEN GREATER VL2_LEN)
|
||||
math(EXPR DIFF "${VL1_LEN} - ${VL2_LEN}" OUTPUT_FORMAT DECIMAL)
|
||||
foreach(IDX RANGE 1 ${DIFF} 1)
|
||||
list(APPEND VL2 "0")
|
||||
endforeach()
|
||||
elseif(VL2_LEN GREATER VL2_LEN)
|
||||
math(EXPR DIFF "${VL1_LEN} - ${VL2_LEN}" OUTPUT_FORMAT DECIMAL)
|
||||
foreach(IDX RANGE 1 ${DIFF} 1)
|
||||
list(APPEND VL2 "0")
|
||||
endforeach()
|
||||
set(LEN ${VL2_LEN})
|
||||
endif()
|
||||
math(EXPR LEN_SUB_ONE "${LEN}-1")
|
||||
foreach(IDX RANGE 0 ${LEN_SUB_ONE} 1)
|
||||
list(GET VL1 ${IDX} VAL1)
|
||||
list(GET VL2 ${IDX} VAL2)
|
||||
|
||||
if(${VAL1} GREATER ${VAL2})
|
||||
set(${OUTPUT} TRUE PARENT_SCOPE)
|
||||
break()
|
||||
elseif(${VAL1} LESS ${VAL2})
|
||||
set(${OUTPUT} FALSE PARENT_SCOPE)
|
||||
break()
|
||||
else()
|
||||
set(${OUTPUT} TRUE PARENT_SCOPE)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
endfunction(compare_versions)
|
||||
|
||||
# Creates a small function to run and extract the LevelZero loader version.
|
||||
function(get_l0_loader_version)
|
||||
|
||||
set(L0_VERSIONEER_SRC
|
||||
[====[
|
||||
#include <iostream>
|
||||
#include <level_zero/loader/ze_loader.h>
|
||||
#include <string>
|
||||
int main() {
|
||||
ze_result_t result;
|
||||
std::string loader("loader");
|
||||
zel_component_version_t *versions;
|
||||
size_t size = 0;
|
||||
result = zeInit(0);
|
||||
if (result != ZE_RESULT_SUCCESS) {
|
||||
std::cerr << "Failed to init ze driver" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
zelLoaderGetVersions(&size, nullptr);
|
||||
versions = new zel_component_version_t[size];
|
||||
zelLoaderGetVersions(&size, versions);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
if (loader.compare(versions[i].component_name) == 0) {
|
||||
std::cout << versions[i].component_lib_version.major << "."
|
||||
<< versions[i].component_lib_version.minor << "."
|
||||
<< versions[i].component_lib_version.patch;
|
||||
break;
|
||||
}
|
||||
}
|
||||
delete[] versions;
|
||||
return 0;
|
||||
}
|
||||
]====]
|
||||
)
|
||||
|
||||
set(L0_VERSIONEER_FILE ${CMAKE_BINARY_DIR}/temp/l0_versioneer.cpp)
|
||||
|
||||
file(WRITE ${L0_VERSIONEER_FILE} "${L0_VERSIONEER_SRC}")
|
||||
|
||||
# We need both the directories in the include path as ze_loader.h
|
||||
# includes "ze_api.h" and not "level_zero/ze_api.h".
|
||||
list(APPEND INCLUDE_DIRS ${LevelZero_INCLUDE_DIR})
|
||||
list(APPEND INCLUDE_DIRS ${LevelZero_INCLUDE_DIR}/level_zero)
|
||||
list(JOIN INCLUDE_DIRS ";" INCLUDE_DIRS_STR)
|
||||
try_run(L0_VERSIONEER_RUN L0_VERSIONEER_COMPILE
|
||||
"${CMAKE_BINARY_DIR}"
|
||||
"${L0_VERSIONEER_FILE}"
|
||||
LINK_LIBRARIES ${LevelZero_LIBRARY}
|
||||
CMAKE_FLAGS
|
||||
"-DINCLUDE_DIRECTORIES=${INCLUDE_DIRS_STR}"
|
||||
RUN_OUTPUT_VARIABLE L0_VERSION
|
||||
)
|
||||
if(${L0_VERSIONEER_COMPILE} AND (DEFINED L0_VERSIONEER_RUN))
|
||||
set(LevelZero_VERSION ${L0_VERSION} PARENT_SCOPE)
|
||||
message(STATUS "Found Level Zero of version: ${L0_VERSION}")
|
||||
else()
|
||||
message(FATAL_ERROR
|
||||
"Could not compile a level-zero program to extract loader version"
|
||||
)
|
||||
endif()
|
||||
endfunction(get_l0_loader_version)
|
||||
|
||||
if(LevelZero_INCLUDE_DIR AND LevelZero_LIBRARY)
|
||||
list(APPEND LevelZero_LIBRARIES "${LevelZero_LIBRARY}")
|
||||
list(APPEND LevelZero_INCLUDE_DIRS ${LevelZero_INCLUDE_DIR})
|
||||
if(OpenCL_FOUND)
|
||||
list(APPEND LevelZero_INCLUDE_DIRS ${OpenCL_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
cmake_path(GET LevelZero_LIBRARY PARENT_PATH LevelZero_LIBRARIES_PATH)
|
||||
set(LevelZero_LIBRARIES_DIR ${LevelZero_LIBRARIES_PATH})
|
||||
|
||||
if(NOT TARGET LevelZero::LevelZero)
|
||||
add_library(LevelZero::LevelZero INTERFACE IMPORTED)
|
||||
set_target_properties(LevelZero::LevelZero
|
||||
PROPERTIES INTERFACE_LINK_LIBRARIES "${LevelZero_LIBRARIES}"
|
||||
)
|
||||
set_target_properties(LevelZero::LevelZero
|
||||
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LevelZero_INCLUDE_DIRS}"
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Check if a specific version of Level Zero is required
|
||||
if(LevelZero_FIND_VERSION)
|
||||
get_l0_loader_version()
|
||||
set(VERSION_GT_FIND_VERSION FALSE)
|
||||
compare_versions(
|
||||
${LevelZero_VERSION}
|
||||
${LevelZero_FIND_VERSION}
|
||||
VERSION_GT_FIND_VERSION
|
||||
)
|
||||
if(${VERSION_GT_FIND_VERSION})
|
||||
set(LevelZero_FOUND TRUE)
|
||||
else()
|
||||
set(LevelZero_FOUND FALSE)
|
||||
endif()
|
||||
else()
|
||||
set(LevelZero_FOUND TRUE)
|
||||
endif()
|
||||
|
||||
find_package_handle_standard_args(LevelZero
|
||||
REQUIRED_VARS
|
||||
LevelZero_FOUND
|
||||
LevelZero_INCLUDE_DIRS
|
||||
LevelZero_LIBRARY
|
||||
LevelZero_LIBRARIES_DIR
|
||||
HANDLE_COMPONENTS
|
||||
)
|
||||
mark_as_advanced(LevelZero_LIBRARY LevelZero_INCLUDE_DIRS)
|
||||
|
||||
if(LevelZero_FOUND)
|
||||
find_package_message(LevelZero "Found LevelZero: ${LevelZero_LIBRARY}"
|
||||
"(found version ${LevelZero_VERSION})"
|
||||
)
|
||||
else()
|
||||
find_package_message(LevelZero "Could not find LevelZero" "")
|
||||
endif()
|
||||
68
mlir/cmake/modules/FindSyclRuntime.cmake
Normal file
68
mlir/cmake/modules/FindSyclRuntime.cmake
Normal file
@@ -0,0 +1,68 @@
|
||||
# CMake find_package() module for SYCL Runtime
|
||||
#
|
||||
# Example usage:
|
||||
#
|
||||
# find_package(SyclRuntime)
|
||||
#
|
||||
# If successful, the following variables will be defined:
|
||||
# SyclRuntime_FOUND
|
||||
# SyclRuntime_INCLUDE_DIRS
|
||||
# SyclRuntime_LIBRARY
|
||||
# SyclRuntime_LIBRARIES_DIR
|
||||
#
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
|
||||
if(NOT DEFINED ENV{CMPLR_ROOT})
|
||||
message(WARNING "Please make sure to install Intel DPC++ Compiler and run setvars.(sh/bat)")
|
||||
message(WARNING "You can download standalone Intel DPC++ Compiler from https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#compilers")
|
||||
else()
|
||||
if(LINUX OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux"))
|
||||
set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux")
|
||||
elseif(WIN32)
|
||||
set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows")
|
||||
endif()
|
||||
list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include")
|
||||
list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include/sycl")
|
||||
|
||||
set(SyclRuntime_LIBRARY_DIR "${SyclRuntime_ROOT}/lib")
|
||||
|
||||
message(STATUS "SyclRuntime_LIBRARY_DIR: ${SyclRuntime_LIBRARY_DIR}")
|
||||
find_library(SyclRuntime_LIBRARY
|
||||
NAMES sycl
|
||||
PATHS ${SyclRuntime_LIBRARY_DIR}
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
endif()
|
||||
|
||||
if(SyclRuntime_LIBRARY)
|
||||
set(SyclRuntime_FOUND TRUE)
|
||||
if(NOT TARGET SyclRuntime::SyclRuntime)
|
||||
add_library(SyclRuntime::SyclRuntime INTERFACE IMPORTED)
|
||||
set_target_properties(SyclRuntime::SyclRuntime
|
||||
PROPERTIES INTERFACE_LINK_LIBRARIES "${SyclRuntime_LIBRARY}"
|
||||
)
|
||||
set_target_properties(SyclRuntime::SyclRuntime
|
||||
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SyclRuntime_INCLUDE_DIRS}"
|
||||
)
|
||||
endif()
|
||||
else()
|
||||
set(SyclRuntime_FOUND FALSE)
|
||||
endif()
|
||||
|
||||
find_package_handle_standard_args(SyclRuntime
|
||||
REQUIRED_VARS
|
||||
SyclRuntime_FOUND
|
||||
SyclRuntime_INCLUDE_DIRS
|
||||
SyclRuntime_LIBRARY
|
||||
SyclRuntime_LIBRARY_DIR
|
||||
HANDLE_COMPONENTS
|
||||
)
|
||||
|
||||
mark_as_advanced(SyclRuntime_LIBRARY SyclRuntime_INCLUDE_DIRS)
|
||||
|
||||
if(SyclRuntime_FOUND)
|
||||
find_package_message(SyclRuntime "Found SyclRuntime: ${SyclRuntime_LIBRARY}" "")
|
||||
else()
|
||||
find_package_message(SyclRuntime "Could not find SyclRuntime" "")
|
||||
endif()
|
||||
@@ -12,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||
RunnerUtils.cpp
|
||||
OptUtils.cpp
|
||||
JitRunner.cpp
|
||||
SyclRuntimeWrappers.cpp
|
||||
)
|
||||
|
||||
# Use a separate library for OptUtils, to avoid pulling in the entire JIT and
|
||||
@@ -328,4 +329,39 @@ if(LLVM_ENABLE_PIC)
|
||||
hip::host hip::amdhip64
|
||||
)
|
||||
endif()
|
||||
|
||||
if(MLIR_ENABLE_SYCL_RUNNER)
|
||||
find_package(SyclRuntime)
|
||||
|
||||
if(NOT SyclRuntime_FOUND)
|
||||
message(FATAL_ERROR "syclRuntime not found. Please set check oneapi installation and run setvars.sh.")
|
||||
endif()
|
||||
|
||||
find_package(LevelZero)
|
||||
|
||||
if(NOT LevelZero_FOUND)
|
||||
message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
|
||||
endif()
|
||||
|
||||
add_mlir_library(mlir_sycl_runtime
|
||||
SHARED
|
||||
SyclRuntimeWrappers.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
)
|
||||
|
||||
check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
|
||||
if(NOT CXX_HAS_FRTTI_FLAG)
|
||||
message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
|
||||
endif()
|
||||
target_compile_options (mlir_sycl_runtime PUBLIC -fexceptions -frtti)
|
||||
|
||||
target_include_directories(mlir_sycl_runtime PRIVATE
|
||||
${MLIR_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
target_link_libraries(mlir_sycl_runtime PRIVATE LevelZero::LevelZero SyclRuntime::SyclRuntime)
|
||||
|
||||
set_property(TARGET mlir_sycl_runtime APPEND PROPERTY BUILD_RPATH "${LevelZero_LIBRARIES_DIR}" "${SyclRuntime_LIBRARIES_DIR}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
209
mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
Normal file
209
mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Implements wrappers around the sycl runtime library with C linkage
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <CL/sycl.hpp>
|
||||
#include <level_zero/ze_api.h>
|
||||
#include <sycl/ext/oneapi/backend/level_zero.hpp>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define SYCL_RUNTIME_EXPORT
|
||||
#endif // _WIN32
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename F>
|
||||
auto catchAll(F &&func) {
|
||||
try {
|
||||
return func();
|
||||
} catch (const std::exception &e) {
|
||||
fprintf(stdout, "An exception was thrown: %s\n", e.what());
|
||||
fflush(stdout);
|
||||
abort();
|
||||
} catch (...) {
|
||||
fprintf(stdout, "An unknown exception was thrown\n");
|
||||
fflush(stdout);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
#define L0_SAFE_CALL(call) \
|
||||
{ \
|
||||
ze_result_t status = (call); \
|
||||
if (status != ZE_RESULT_SUCCESS) { \
|
||||
fprintf(stdout, "L0 error %d\n", status); \
|
||||
fflush(stdout); \
|
||||
abort(); \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static sycl::device getDefaultDevice() {
|
||||
static sycl::device syclDevice;
|
||||
static bool isDeviceInitialised = false;
|
||||
if (!isDeviceInitialised) {
|
||||
auto platformList = sycl::platform::get_platforms();
|
||||
for (const auto &platform : platformList) {
|
||||
auto platformName = platform.get_info<sycl::info::platform::name>();
|
||||
bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
|
||||
if (!isLevelZero)
|
||||
continue;
|
||||
|
||||
syclDevice = platform.get_devices()[0];
|
||||
isDeviceInitialised = true;
|
||||
return syclDevice;
|
||||
}
|
||||
throw std::runtime_error("getDefaultDevice failed");
|
||||
} else
|
||||
return syclDevice;
|
||||
}
|
||||
|
||||
static sycl::context getDefaultContext() {
|
||||
static sycl::context syclContext{getDefaultDevice()};
|
||||
return syclContext;
|
||||
}
|
||||
|
||||
static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
|
||||
void *memPtr = nullptr;
|
||||
if (isShared) {
|
||||
memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
|
||||
getDefaultContext());
|
||||
} else {
|
||||
memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
|
||||
getDefaultContext());
|
||||
}
|
||||
if (memPtr == nullptr) {
|
||||
throw std::runtime_error("mem allocation failed!");
|
||||
}
|
||||
return memPtr;
|
||||
}
|
||||
|
||||
static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
|
||||
sycl::free(ptr, *queue);
|
||||
}
|
||||
|
||||
static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
|
||||
assert(data);
|
||||
ze_module_handle_t zeModule;
|
||||
ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
|
||||
nullptr,
|
||||
ZE_MODULE_FORMAT_IL_SPIRV,
|
||||
dataSize,
|
||||
(const uint8_t *)data,
|
||||
nullptr,
|
||||
nullptr};
|
||||
auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
|
||||
getDefaultDevice());
|
||||
auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
|
||||
getDefaultContext());
|
||||
L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
|
||||
return zeModule;
|
||||
}
|
||||
|
||||
static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
|
||||
assert(zeModule);
|
||||
assert(name);
|
||||
ze_kernel_handle_t zeKernel;
|
||||
ze_kernel_desc_t desc = {};
|
||||
desc.pKernelName = name;
|
||||
|
||||
L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
|
||||
sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
|
||||
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
|
||||
sycl::bundle_state::executable>(
|
||||
{zeModule}, getDefaultContext());
|
||||
|
||||
auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
|
||||
{kernelBundle, zeKernel}, getDefaultContext());
|
||||
return new sycl::kernel(kernel);
|
||||
}
|
||||
|
||||
static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
|
||||
size_t gridY, size_t gridZ, size_t blockX,
|
||||
size_t blockY, size_t blockZ, size_t sharedMemBytes,
|
||||
void **params, size_t paramsCount) {
|
||||
auto syclGlobalRange =
|
||||
sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
|
||||
auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
|
||||
sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
|
||||
|
||||
queue->submit([&](sycl::handler &cgh) {
|
||||
for (size_t i = 0; i < paramsCount; i++) {
|
||||
cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
|
||||
}
|
||||
cgh.parallel_for(syclNdRange, *kernel);
|
||||
});
|
||||
}
|
||||
|
||||
// Wrappers
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
|
||||
|
||||
return catchAll([&]() {
|
||||
sycl::queue *queue =
|
||||
new sycl::queue(getDefaultContext(), getDefaultDevice());
|
||||
return queue;
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
|
||||
catchAll([&]() { delete queue; });
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void *
|
||||
mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
|
||||
return catchAll([&]() {
|
||||
return allocDeviceMemory(queue, static_cast<size_t>(size), true);
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
|
||||
catchAll([&]() {
|
||||
if (ptr) {
|
||||
deallocDeviceMemory(queue, ptr);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
|
||||
mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
|
||||
return catchAll([&]() { return loadModule(data, gpuBlobSize); });
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
|
||||
mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
|
||||
return catchAll([&]() { return getKernel(module, name); });
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void
|
||||
mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
|
||||
size_t blockX, size_t blockY, size_t blockZ,
|
||||
size_t sharedMemBytes, sycl::queue *queue, void **params,
|
||||
void ** /*extra*/, size_t paramsCount) {
|
||||
return catchAll([&]() {
|
||||
launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
|
||||
sharedMemBytes, params, paramsCount);
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
|
||||
|
||||
catchAll([&]() { queue->wait(); });
|
||||
}
|
||||
|
||||
extern "C" SYCL_RUNTIME_EXPORT void
|
||||
mgpuModuleUnload(ze_module_handle_t module) {
|
||||
|
||||
catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
|
||||
}
|
||||
Reference in New Issue
Block a user