[flang][cuda] Make argument passed by value for sync functions (#125909)

`syncthreads_and`, `syncthreads_count`, `syncthreads_or`, `synwrap` must
take their argument by value. This patch updates the interfaces and
makes sure these functions can be called inside a cuff kernel as well.
This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2025-02-05 13:47:53 -08:00
committed by GitHub
parent 718b16a0fc
commit 69ccb1357f
2 changed files with 28 additions and 14 deletions

View File

@@ -29,28 +29,28 @@ implicit none
interface
attributes(device) integer function syncthreads_and(value)
integer :: value
integer, value :: value
end function
end interface
public :: syncthreads_and
interface
attributes(device) integer function syncthreads_count(value)
integer :: value
integer, value :: value
end function
end interface
public :: syncthreads_count
interface
attributes(device) integer function syncthreads_or(value)
integer :: value
integer, value :: value
end function
end interface
public :: syncthreads_or
interface
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
integer :: mask
integer, value :: mask
end subroutine
end interface
public :: syncwarp

View File

@@ -47,7 +47,7 @@ end
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -79,17 +79,9 @@ end
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: func.func private @llvm.nvvm.barrier0()
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
! CHECK: func.func private @llvm.nvvm.membar.gl()
! CHECK: func.func private @llvm.nvvm.membar.cta()
! CHECK: func.func private @llvm.nvvm.membar.sys()
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
subroutine host1()
integer, device :: a(32)
integer, device :: ret
integer :: i, j
block; use cudadevice
@@ -98,6 +90,28 @@ block; use cudadevice
a(i) = a(i) * 2.0
call syncthreads()
a(i) = a(i) + a(j) - 34.0
call syncwarp(1)
ret = syncthreads_and(1)
ret = syncthreads_count(1)
ret = syncthreads_or(1)
end do
end block
end
! CHECK-LABEL: func.func @_QPhost1()
! CHECK: cuf.kernel
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0()
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
! CHECK: func.func private @llvm.nvvm.membar.gl()
! CHECK: func.func private @llvm.nvvm.membar.cta()
! CHECK: func.func private @llvm.nvvm.membar.sys()
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32