57 lines
1.9 KiB
C++
57 lines
1.9 KiB
C++
//===-- runtime/CUDA/kernel.cpp -------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Runtime/CUDA/kernel.h"
|
|
#include "../terminator.h"
|
|
#include "flang/Runtime/CUDA/common.h"
|
|
|
|
#include "cuda_runtime.h"
|
|
|
|
extern "C" {
|
|
|
|
// Launch `kernel` through cudaLaunchKernel on stream 0, using a grid of
// gridX x gridY x gridZ thread blocks, each blockX x blockY x blockZ threads,
// with `smem` bytes of dynamic shared memory. `params` is forwarded unchanged
// as the kernel argument array; `extra` is accepted but not used here.
// The status returned by the launch is checked with CUDA_REPORT_IF_ERROR
// (presumably the helper from flang/Runtime/CUDA/common.h).
//
// NOTE(review): the intptr_t extents are converted to dim3's unsigned
// components without any range check, so negative or oversized values wrap
// before the CUDA runtime sees them.
void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
    intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
    int32_t smem, void **params, void **extra) {
  const dim3 gridDim(static_cast<unsigned int>(gridX),
      static_cast<unsigned int>(gridY), static_cast<unsigned int>(gridZ));
  const dim3 blockDim(static_cast<unsigned int>(blockX),
      static_cast<unsigned int>(blockY), static_cast<unsigned int>(blockZ));

  // TODO: stream management; every launch currently goes to stream 0.
  cudaStream_t stream{nullptr};

  CUDA_REPORT_IF_ERROR(
      cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
}
|
|
|
|
// Launch `kernel` through cudaLaunchKernelExC as a thread-block-cluster
// launch. The cluster shape (clusterX x clusterY x clusterZ) is passed as the
// single cudaLaunchAttributeClusterDimension attribute. The grid, block and
// `smem` (dynamic shared memory bytes) arguments fill the corresponding
// cudaLaunchConfig_t fields. `params` is forwarded unchanged as the kernel
// argument array; `extra` is accepted but not used here.
// The launch status is checked with CUDA_REPORT_IF_ERROR (presumably the
// helper from flang/Runtime/CUDA/common.h).
//
// NOTE(review): cluster launches need hardware support (SM90+ per the CUDA
// programming guide) and compatible grid/cluster shapes. Those constraints
// are left to the CUDA runtime to reject; nothing is validated here.
void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
    intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
    intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
    int32_t smem, void **params, void **extra) {
  // The only launch attribute: the cluster dimensions.
  cudaLaunchAttribute launchAttr[1];
  launchAttr[0].id = cudaLaunchAttributeClusterDimension;
  auto &clusterShape = launchAttr[0].val.clusterDim;
  clusterShape.x = static_cast<unsigned int>(clusterX);
  clusterShape.y = static_cast<unsigned int>(clusterY);
  clusterShape.z = static_cast<unsigned int>(clusterZ);

  cudaLaunchConfig_t config;
  config.gridDim = dim3(static_cast<unsigned int>(gridX),
      static_cast<unsigned int>(gridY), static_cast<unsigned int>(gridZ));
  config.blockDim = dim3(static_cast<unsigned int>(blockX),
      static_cast<unsigned int>(blockY), static_cast<unsigned int>(blockZ));
  config.dynamicSmemBytes = static_cast<size_t>(smem);
  // TODO: stream management; every launch currently goes to stream 0.
  config.stream = nullptr;
  config.attrs = launchAttr;
  config.numAttrs = 1;

  CUDA_REPORT_IF_ERROR(cudaLaunchKernelExC(&config, kernel, params));
}
|
|
|
|
} // extern "C"
|