[flang][cuda] Add runtime check for passing device arrays (#144003)
This commit is contained in:
committed by
GitHub
parent
07dad4ecba
commit
9992668404
@@ -54,6 +54,14 @@ void RTDEF(CUFSyncGlobalDescriptor)(
|
||||
((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
|
||||
}
|
||||
|
||||
void RTDEF(CUFDescriptorCheckSection)(
|
||||
const Descriptor *desc, const char *sourceFile, int sourceLine) {
|
||||
if (desc && !desc->IsContiguous()) {
|
||||
Terminator terminator{sourceFile, sourceLine};
|
||||
terminator.Crash("device array section argument is not contiguous");
|
||||
}
|
||||
}
|
||||
|
||||
RT_EXT_API_GROUP_END
|
||||
}
|
||||
} // namespace Fortran::runtime::cuda
|
||||
|
||||
@@ -63,5 +63,8 @@ ENUM_LOWERINGOPT(StackRepackArrays, unsigned, 1, 0)
|
||||
/// in the leading dimension.
|
||||
ENUM_LOWERINGOPT(RepackArraysWhole, unsigned, 1, 0)
|
||||
|
||||
/// If true, CUDA Fortran runtime check is inserted.
|
||||
ENUM_LOWERINGOPT(CUDARuntimeCheck, unsigned, 1, 0)
|
||||
|
||||
#undef LOWERINGOPT
|
||||
#undef ENUM_LOWERINGOPT
|
||||
|
||||
@@ -26,6 +26,11 @@ namespace fir::runtime::cuda {
|
||||
void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Value hostPtr);
|
||||
|
||||
/// Generate runtime call to check the section of a descriptor and raise an
|
||||
/// error if it is not contiguous.
|
||||
void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Value desc);
|
||||
|
||||
} // namespace fir::runtime::cuda
|
||||
|
||||
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
|
||||
|
||||
@@ -37,6 +37,10 @@ void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
|
||||
void RTDECL(CUFSyncGlobalDescriptor)(
|
||||
void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
|
||||
|
||||
/// Check descriptor passed to a kernel.
|
||||
void RTDECL(CUFDescriptorCheckSection)(
|
||||
const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
|
||||
|
||||
} // extern "C"
|
||||
|
||||
} // namespace Fortran::runtime::cuda
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "flang/Optimizer/Builder/IntrinsicCall.h"
|
||||
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
|
||||
#include "flang/Optimizer/Builder/MutableBox.h"
|
||||
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
|
||||
#include "flang/Optimizer/Builder/Runtime/Derived.h"
|
||||
#include "flang/Optimizer/Builder/Todo.h"
|
||||
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
|
||||
@@ -543,6 +544,19 @@ Fortran::lower::genCallOpAndResult(
|
||||
fir::FortranProcedureFlagsEnumAttr procAttrs =
|
||||
caller.getProcedureAttrs(builder.getContext());
|
||||
|
||||
if (converter.getLoweringOptions().getCUDARuntimeCheck()) {
|
||||
if (caller.getCallDescription().chevrons().empty()) {
|
||||
for (auto [oper, arg] :
|
||||
llvm::zip(operands, caller.getPassedArguments())) {
|
||||
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(oper.getType())) {
|
||||
const Fortran::semantics::Symbol *sym = caller.getDummySymbol(arg);
|
||||
if (sym && Fortran::evaluate::IsCUDADeviceSymbol(*sym))
|
||||
fir::runtime::cuda::genDescriptorCheckSection(builder, loc, oper);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!caller.getCallDescription().chevrons().empty()) {
|
||||
// A call to a CUDA kernel with the chevron syntax.
|
||||
|
||||
|
||||
@@ -32,3 +32,18 @@ void fir::runtime::cuda::genSyncGlobalDescriptor(fir::FirOpBuilder &builder,
|
||||
builder, loc, fTy, hostPtr, sourceFile, sourceLine)};
|
||||
builder.create<fir::CallOp>(loc, callee, args);
|
||||
}
|
||||
|
||||
void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
|
||||
mlir::Location loc,
|
||||
mlir::Value desc) {
|
||||
mlir::func::FuncOp func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(CUFDescriptorCheckSection)>(loc,
|
||||
builder);
|
||||
auto fTy = func.getFunctionType();
|
||||
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
mlir::Value sourceLine =
|
||||
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
|
||||
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
|
||||
builder, loc, fTy, desc, sourceFile, sourceLine)};
|
||||
builder.create<fir::CallOp>(loc, func, args);
|
||||
}
|
||||
|
||||
22
flang/test/Lower/CUDA/cuda-runtime-check.cuf
Normal file
22
flang/test/Lower/CUDA/cuda-runtime-check.cuf
Normal file
@@ -0,0 +1,22 @@
|
||||
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
|
||||
|
||||
! Check insertion of runtime checks
|
||||
|
||||
interface
|
||||
subroutine foo(a)
|
||||
real, device, dimension(:,:) :: a
|
||||
end subroutine
|
||||
end interface
|
||||
|
||||
real, device, allocatable, dimension(:,:) :: a
|
||||
allocate(a(10,10))
|
||||
call foo(a(1:10,1:10:2))
|
||||
end
|
||||
|
||||
subroutine foo(a)
|
||||
real, device, dimension(:,:) :: a
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QQmain()
|
||||
! CHECK: fir.call @_FortranACUFDescriptorCheckSection
|
||||
! CHECK: fir.call @_QPfoo
|
||||
@@ -434,6 +434,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
|
||||
loweringOptions.setStackRepackArrays(stackRepackArrays);
|
||||
loweringOptions.setRepackArrays(repackArrays);
|
||||
loweringOptions.setRepackArraysWhole(repackArraysWhole);
|
||||
if (enableCUDA)
|
||||
loweringOptions.setCUDARuntimeCheck(true);
|
||||
std::vector<Fortran::lower::EnvironmentDefault> envDefaults = {};
|
||||
Fortran::frontend::TargetOptions targetOpts;
|
||||
Fortran::frontend::CodeGenOptions cgOpts;
|
||||
|
||||
Reference in New Issue
Block a user