[flang][cuda] Make argument passed by value for sync functions (#125909)
`syncthreads_and`, `syncthreads_count`, `syncthreads_or`, `synwrap` must take their argument by value. This patch updates the interfaces and makes sure these functions can be called inside a cuff kernel as well.
This commit is contained in:
committed by
GitHub
parent
718b16a0fc
commit
69ccb1357f
@@ -29,28 +29,28 @@ implicit none
|
||||
|
||||
interface
|
||||
attributes(device) integer function syncthreads_and(value)
|
||||
integer :: value
|
||||
integer, value :: value
|
||||
end function
|
||||
end interface
|
||||
public :: syncthreads_and
|
||||
|
||||
interface
|
||||
attributes(device) integer function syncthreads_count(value)
|
||||
integer :: value
|
||||
integer, value :: value
|
||||
end function
|
||||
end interface
|
||||
public :: syncthreads_count
|
||||
|
||||
interface
|
||||
attributes(device) integer function syncthreads_or(value)
|
||||
integer :: value
|
||||
integer, value :: value
|
||||
end function
|
||||
end interface
|
||||
public :: syncthreads_or
|
||||
|
||||
interface
|
||||
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
|
||||
integer :: mask
|
||||
integer, value :: mask
|
||||
end subroutine
|
||||
end interface
|
||||
public :: syncwarp
|
||||
|
||||
@@ -47,7 +47,7 @@ end
|
||||
|
||||
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
|
||||
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
|
||||
@@ -79,17 +79,9 @@ end
|
||||
! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
|
||||
! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
|
||||
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0()
|
||||
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
|
||||
! CHECK: func.func private @llvm.nvvm.membar.gl()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.cta()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.sys()
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
|
||||
|
||||
subroutine host1()
|
||||
integer, device :: a(32)
|
||||
integer, device :: ret
|
||||
integer :: i, j
|
||||
|
||||
block; use cudadevice
|
||||
@@ -98,6 +90,28 @@ block; use cudadevice
|
||||
a(i) = a(i) * 2.0
|
||||
call syncthreads()
|
||||
a(i) = a(i) + a(j) - 34.0
|
||||
|
||||
call syncwarp(1)
|
||||
ret = syncthreads_and(1)
|
||||
ret = syncthreads_count(1)
|
||||
ret = syncthreads_or(1)
|
||||
end do
|
||||
end block
|
||||
end
|
||||
|
||||
! CHECK-LABEL: func.func @_QPhost1()
|
||||
! CHECK: cuf.kernel
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0()
|
||||
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
|
||||
! CHECK: func.func private @llvm.nvvm.membar.gl()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.cta()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.sys()
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
|
||||
|
||||
Reference in New Issue
Block a user