[flang][cuda] Support corner case of data transfer (#132451)

The flang runtime will complain when the number of elements in the two
descriptors involved in the data transfer are not matching.

In some cases, we can still perform the data transfer to match the
behavior of the reference compiler.

When the RHS elements count is bigger than the LHS elements count and
both descriptors are contiguous, we can perform the data transfer with
the bare pointers and the number of bytes from the LHS.

We don't really have unit tests set up for data transfer, this is why I
didn't include one here.
This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2025-03-21 15:39:05 -07:00
committed by GitHub
parent acdb0c1f99
commit ecaef010f3

View File

@@ -8,6 +8,7 @@
#include "flang/Runtime/CUDA/memory.h"
#include "flang-rt/runtime/assign-impl.h"
#include "flang-rt/runtime/descriptor.h"
#include "flang-rt/runtime/terminator.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
@@ -98,8 +99,21 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
} else {
terminator.Crash("host to host copy not supported");
}
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
if ((srcDesc->rank() > 0) && (dstDesc->Elements() < srcDesc->Elements())) {
// Special case when rhs is bigger than lhs and both are contiguous arrays.
// In this case we do a simple ptr to ptr transfer with the size of lhs.
// This is be allowed in the reference compiler and it avoids error
// triggered in the Assign runtime function used for the main case below.
if (!srcDesc->IsContiguous() || !dstDesc->IsContiguous())
terminator.Crash("Unsupported data transfer: mismatching element counts "
"with non-contiguous arrays");
RTNAME(CUFDataTransferPtrPtr)(dstDesc->raw().base_addr,
srcDesc->raw().base_addr, dstDesc->Elements() * dstDesc->ElementBytes(),
mode, sourceFile, sourceLine);
} else {
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}
}
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,