[flang][cuda] Pass stream information to kernel launch functions (#135246)

Valentin Clement (バレンタイン クレメン)
2025-04-10 13:50:50 -07:00
committed by GitHub
parent 641de84d3b
commit 49f8ccd1eb
4 changed files with 29 additions and 15 deletions

View File

@@ -16,7 +16,7 @@ extern "C" {
 void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
     intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
-    int32_t smem, void **params, void **extra) {
+    intptr_t stream, int32_t smem, void **params, void **extra) {
   dim3 gridDim;
   gridDim.x = gridX;
   gridDim.y = gridY;
@@ -74,15 +74,15 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
     Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
     terminator.Crash("Too many invalid grid dimensions");
   }
-  cudaStream_t stream = 0; // TODO stream managment
+  cudaStream_t cuStream = 0; // TODO stream managment
   CUDA_REPORT_IF_ERROR(
-      cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
+      cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, cuStream));
 }
 
 void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
     intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
     intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
-    int32_t smem, void **params, void **extra) {
+    intptr_t stream, int32_t smem, void **params, void **extra) {
   cudaLaunchConfig_t config;
   config.gridDim.x = gridX;
   config.gridDim.y = gridY;
@@ -153,7 +153,8 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
 void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX,
     intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
-    intptr_t blockZ, int32_t smem, void **params, void **extra) {
+    intptr_t blockZ, intptr_t stream, int32_t smem, void **params,
+    void **extra) {
   dim3 gridDim;
   gridDim.x = gridX;
   gridDim.y = gridY;
@@ -211,9 +212,9 @@ void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX,
     Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
     terminator.Crash("Too many invalid grid dimensions");
   }
-  cudaStream_t stream = 0; // TODO stream managment
+  cudaStream_t cuStream = 0; // TODO stream managment
   CUDA_REPORT_IF_ERROR(cudaLaunchCooperativeKernel(
-      kernel, gridDim, blockDim, params, smem, stream));
+      kernel, gridDim, blockDim, params, smem, cuStream));
 }
 
 } // extern "C"
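
The runtime entry points now accept the stream as an intptr_t, but this commit still launches on the default stream (cuStream = 0, per the TODO above). Below is a hypothetical sketch, not part of this commit, of how the handle could later be mapped onto a cudaStream_t, assuming the sentinel value -1 emitted by the CUF lowering further down means "no stream supplied".

// Hypothetical helper (assumption, not in this commit): map the intptr_t
// stream handle onto a cudaStream_t once stream management is implemented.
// A value of -1 is assumed to mean "no async object", matching the constant
// materialized by the lowering below.
#include <cstdint>
#include <cuda_runtime.h>

static cudaStream_t resolveStream(std::intptr_t stream) {
  if (stream == -1)
    return 0; // default (legacy) stream -- the current behavior of this commit
  return reinterpret_cast<cudaStream_t>(stream);
}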

View File

@@ -21,16 +21,18 @@ extern "C" {
 void RTDEF(CUFLaunchKernel)(const void *kernelName, intptr_t gridX,
     intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
-    intptr_t blockZ, int32_t smem, void **params, void **extra);
+    intptr_t blockZ, intptr_t stream, int32_t smem, void **params,
+    void **extra);
 
 void RTDEF(CUFLaunchClusterKernel)(const void *kernelName, intptr_t clusterX,
     intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
     intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
-    int32_t smem, void **params, void **extra);
+    intptr_t stream, int32_t smem, void **params, void **extra);
 
 void RTDEF(CUFLaunchCooperativeKernel)(const void *kernelName, intptr_t gridX,
     intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
-    intptr_t blockZ, int32_t smem, void **params, void **extra);
+    intptr_t blockZ, intptr_t stream, int32_t smem, void **params,
+    void **extra);
 
 } // extern "C"
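
With the updated declarations, a host-side caller passes the stream handle between the block dimensions and the shared-memory size. The sketch below is illustrative only: the hand-written extern declaration mirrors the header above (assuming RTDEF expands to the _FortranA-prefixed C symbol, consistent with the symbol checked in the MLIR test further down), and the launch bounds are placeholders.

// Illustrative sketch: argument order of the updated launch entry point.
// _FortranACUFLaunchKernel is assumed to be the expanded runtime symbol;
// all launch values below are placeholders.
#include <cstdint>

extern "C" void _FortranACUFLaunchKernel(const void *kernel,
    std::intptr_t gridX, std::intptr_t gridY, std::intptr_t gridZ,
    std::intptr_t blockX, std::intptr_t blockY, std::intptr_t blockZ,
    std::intptr_t stream, std::int32_t smem, void **params, void **extra);

void launchOnStream(const void *kernel, void **params, std::intptr_t stream) {
  // 1x1x1 grid and block, no dynamic shared memory, on the given stream.
  _FortranACUFLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, stream,
                           /*smem=*/0, params, /*extra=*/nullptr);
}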

View File

@@ -121,7 +121,7 @@ struct GPULaunchKernelConversion
           voidTy,
           {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
            llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
-           llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
+           llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
           /*isVarArg=*/false);
       auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get(
           mod.getContext(), RTNAME_STRING(CUFLaunchClusterKernel));
@@ -133,6 +133,10 @@ struct GPULaunchKernelConversion
         launchKernelFuncOp.setVisibility(
             mlir::SymbolTable::Visibility::Private);
       }
+      mlir::Value stream = adaptor.getAsyncObject();
+      if (!stream)
+        stream = rewriter.create<mlir::LLVM::ConstantOp>(
+            loc, llvmIntPtrType, rewriter.getIntegerAttr(llvmIntPtrType, -1));
       rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
           op, funcTy, cufLaunchClusterKernel,
           mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX(),
@@ -140,7 +144,7 @@ struct GPULaunchKernelConversion
                            adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                            adaptor.getGridSizeZ(), adaptor.getBlockSizeX(),
                            adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(),
-                           dynamicMemorySize, kernelArgs, nullPtr});
+                           stream, dynamicMemorySize, kernelArgs, nullPtr});
     } else {
       auto procAttr =
           op->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
@@ -153,7 +157,8 @@ struct GPULaunchKernelConversion
       auto funcTy = mlir::LLVM::LLVMFunctionType::get(
          voidTy,
          {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
-          llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
+          llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
+          i32Ty, ptrTy, ptrTy},
          /*isVarArg=*/false);
       auto cufLaunchKernel =
           mlir::SymbolRefAttr::get(mod.getContext(), fctName);
@@ -165,12 +170,18 @@ struct GPULaunchKernelConversion
         launchKernelFuncOp.setVisibility(
             mlir::SymbolTable::Visibility::Private);
       }
+      mlir::Value stream = adaptor.getAsyncObject();
+      if (!stream)
+        stream = rewriter.create<mlir::LLVM::ConstantOp>(
+            loc, llvmIntPtrType, rewriter.getIntegerAttr(llvmIntPtrType, -1));
       rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
           op, funcTy, cufLaunchKernel,
           mlir::ValueRange{kernelPtr, adaptor.getGridSizeX(),
                            adaptor.getGridSizeY(), adaptor.getGridSizeZ(),
                            adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
-                           adaptor.getBlockSizeZ(), dynamicMemorySize,
-                           kernelArgs, nullPtr});
+                           adaptor.getBlockSizeZ(), stream, dynamicMemorySize,
+                           kernelArgs, nullPtr});
     }

View File

@@ -113,7 +113,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
 // -----
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git@github.com:clementval/llvm-project.git 4116c1370ff76adf1e58eb3c39d0a14721794c70)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
-  llvm.func @_FortranACUFLaunchClusterKernel(!llvm.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr, !llvm.ptr) attributes {sym_visibility = "private"}
+  llvm.func @_FortranACUFLaunchClusterKernel(!llvm.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr, !llvm.ptr) attributes {sym_visibility = "private"}
   llvm.func @_QMmod1Psub1() attributes {cuf.cluster_dims = #cuf.cluster_dims<x = 2 : i64, y = 2 : i64, z = 1 : i64>} {
     llvm.return
   }