From 9992668404cfb2302f7a62f01884c210642caea1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?=
Date: Thu, 12 Jun 2025 20:47:58 -0700
Subject: [PATCH] [flang][cuda] Add runtime check for passing device arrays (#144003)

---
 flang-rt/lib/cuda/descriptor.cpp              |  8 +++++++
 flang/include/flang/Lower/LoweringOptions.def |  3 +++
 .../Builder/Runtime/CUDA/Descriptor.h         |  5 +++++
 flang/include/flang/Runtime/CUDA/descriptor.h |  4 ++++
 flang/lib/Lower/ConvertCall.cpp               | 14 ++++++++++++
 .../Builder/Runtime/CUDA/Descriptor.cpp       | 15 +++++++++++++
 flang/test/Lower/CUDA/cuda-runtime-check.cuf  | 22 +++++++++++++++++++
 flang/tools/bbc/bbc.cpp                       |  2 ++
 8 files changed, 73 insertions(+)
 create mode 100644 flang/test/Lower/CUDA/cuda-runtime-check.cuf

diff --git a/flang-rt/lib/cuda/descriptor.cpp b/flang-rt/lib/cuda/descriptor.cpp
index 7b768f91af29..aa75d4eff051 100644
--- a/flang-rt/lib/cuda/descriptor.cpp
+++ b/flang-rt/lib/cuda/descriptor.cpp
@@ -54,6 +54,14 @@ void RTDEF(CUFSyncGlobalDescriptor)(
   ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
 }
 
+void RTDEF(CUFDescriptorCheckSection)(
+    const Descriptor *desc, const char *sourceFile, int sourceLine) {
+  if (desc && !desc->IsContiguous()) { // A null descriptor is not checked.
+    Terminator terminator{sourceFile, sourceLine};
+    terminator.Crash("device array section argument is not contiguous");
+  }
+}
+
 RT_EXT_API_GROUP_END
 }
 } // namespace Fortran::runtime::cuda
diff --git a/flang/include/flang/Lower/LoweringOptions.def b/flang/include/flang/Lower/LoweringOptions.def
index b062ea1a805a..d97abf4d864b 100644
--- a/flang/include/flang/Lower/LoweringOptions.def
+++ b/flang/include/flang/Lower/LoweringOptions.def
@@ -63,5 +63,8 @@ ENUM_LOWERINGOPT(StackRepackArrays, unsigned, 1, 0)
 /// in the leading dimension.
 ENUM_LOWERINGOPT(RepackArraysWhole, unsigned, 1, 0)
 
+/// If true, CUDA Fortran runtime checks are inserted during lowering.
+ENUM_LOWERINGOPT(CUDARuntimeCheck, unsigned, 1, 0)
+
 #undef LOWERINGOPT
 #undef ENUM_LOWERINGOPT
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h b/flang/include/flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h
index 14d262bf22a7..bdeb7574012c 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h
@@ -26,6 +26,11 @@ namespace fir::runtime::cuda {
 void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value hostPtr);
 
+/// Generate runtime call to check the section of a descriptor and raise an
+/// error if it is not contiguous.
+void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Value desc);
+
 } // namespace fir::runtime::cuda
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
diff --git a/flang/include/flang/Runtime/CUDA/descriptor.h b/flang/include/flang/Runtime/CUDA/descriptor.h
index 0ee7feca10e4..06e4a4649db1 100644
--- a/flang/include/flang/Runtime/CUDA/descriptor.h
+++ b/flang/include/flang/Runtime/CUDA/descriptor.h
@@ -37,6 +37,10 @@ void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
 void RTDECL(CUFSyncGlobalDescriptor)(
     void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
 
+/// Check a device array descriptor argument; crash if it is not contiguous.
+void RTDECL(CUFDescriptorCheckSection)(
+    const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
+
 } // extern "C"
 
 } // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 7378118cfef7..864499e6c343 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -26,6 +26,7 @@
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
+#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
 #include "flang/Optimizer/Builder/Runtime/Derived.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
@@ -543,6 +544,19 @@ Fortran::lower::genCallOpAndResult(
   fir::FortranProcedureFlagsEnumAttr procAttrs =
       caller.getProcedureAttrs(builder.getContext());
 
+  if (converter.getLoweringOptions().getCUDARuntimeCheck()) {
+    if (caller.getCallDescription().chevrons().empty()) { // not a kernel launch
+      for (auto [oper, arg] :
+           llvm::zip(operands, caller.getPassedArguments())) {
+        if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(oper.getType())) {
+          const Fortran::semantics::Symbol *sym = caller.getDummySymbol(arg);
+          if (sym && Fortran::evaluate::IsCUDADeviceSymbol(*sym))
+            fir::runtime::cuda::genDescriptorCheckSection(builder, loc, oper);
+        }
+      }
+    }
+  }
+
   if (!caller.getCallDescription().chevrons().empty()) {
     // A call to a CUDA kernel with the chevron syntax.
 
diff --git a/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp b/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp
index 90662c094c65..a943469a7672 100644
--- a/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp
@@ -32,3 +32,18 @@ void fir::runtime::cuda::genSyncGlobalDescriptor(fir::FirOpBuilder &builder,
       builder, loc, fTy, hostPtr, sourceFile, sourceLine)};
   builder.create<fir::CallOp>(loc, callee, args);
 }
+
+void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
+                                                   mlir::Location loc,
+                                                   mlir::Value desc) {
+  mlir::func::FuncOp func =
+      fir::runtime::getRuntimeFunc<mkRTKey(CUFDescriptorCheckSection)>(loc,
+                                                                       builder);
+  auto fTy = func.getFunctionType();
+  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+  mlir::Value sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+      builder, loc, fTy, desc, sourceFile, sourceLine)};
+  builder.create<fir::CallOp>(loc, func, args);
+}
diff --git a/flang/test/Lower/CUDA/cuda-runtime-check.cuf b/flang/test/Lower/CUDA/cuda-runtime-check.cuf
new file mode 100644
index 000000000000..f26d372769ca
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-runtime-check.cuf
@@ -0,0 +1,22 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Check that a runtime check is inserted when passing a device array section.
+
+interface
+  subroutine foo(a)
+    real, device, dimension(:,:) :: a
+  end subroutine
+end interface
+
+  real, device, allocatable, dimension(:,:) :: a
+  allocate(a(10,10))
+  call foo(a(1:10,1:10:2))
+end
+
+subroutine foo(a)
+  real, device, dimension(:,:) :: a
+end subroutine
+
+! CHECK-LABEL: func.func @_QQmain()
+! CHECK: fir.call @_FortranACUFDescriptorCheckSection
+! CHECK: fir.call @_QPfoo
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c80872108ac8..015c86604a1f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -434,6 +434,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
   loweringOptions.setStackRepackArrays(stackRepackArrays);
   loweringOptions.setRepackArrays(repackArrays);
   loweringOptions.setRepackArraysWhole(repackArraysWhole);
+  if (enableCUDA)
+    loweringOptions.setCUDARuntimeCheck(true);
   std::vector<Fortran::lower::EnvironmentDefault> envDefaults = {};
   Fortran::frontend::TargetOptions targetOpts;
   Fortran::frontend::CodeGenOptions cgOpts;