From ecaef010f31e2557d94b4d98774ca4b4d5fe2149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Fri, 21 Mar 2025 15:39:05 -0700 Subject: [PATCH] [flang][cuda] Support corner case of data transfer (#132451) The flang runtime will complain when the number of elements in the two descriptors involved in the data transfer are not matching. In some cases, we can still perform the data transfer to match the behavior of the reference compiler. When the RHS elements count is bigger than the LHS elements count and both descriptors are contiguous, we can perform the data transfer with the bare pointers and the number of bytes from the LHS. We don't really have unit tests set up for data transfer, this is why I didn't include one here. --- flang-rt/lib/cuda/memory.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp index 1ebe5059b941..adc24ff22372 100644 --- a/flang-rt/lib/cuda/memory.cpp +++ b/flang-rt/lib/cuda/memory.cpp @@ -8,6 +8,7 @@ #include "flang/Runtime/CUDA/memory.h" #include "flang-rt/runtime/assign-impl.h" +#include "flang-rt/runtime/descriptor.h" #include "flang-rt/runtime/terminator.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/CUDA/descriptor.h" @@ -98,8 +99,21 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc, } else { terminator.Crash("host to host copy not supported"); } - Fortran::runtime::Assign( - *dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct); + if ((srcDesc->rank() > 0) && (dstDesc->Elements() < srcDesc->Elements())) { + // Special case when rhs is bigger than lhs and both are contiguous arrays. + // In this case we do a simple ptr to ptr transfer with the size of lhs. + // This is be allowed in the reference compiler and it avoids error + // triggered in the Assign runtime function used for the main case below. + if (!srcDesc->IsContiguous() || !dstDesc->IsContiguous()) + terminator.Crash("Unsupported data transfer: mismatching element counts " + "with non-contiguous arrays"); + RTNAME(CUFDataTransferPtrPtr)(dstDesc->raw().base_addr, + srcDesc->raw().base_addr, dstDesc->Elements() * dstDesc->ElementBytes(), + mode, sourceFile, sourceLine); + } else { + Fortran::runtime::Assign( + *dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct); + } } void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,