From 75175e72308536dff3225dc885db71343ae85267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Fri, 27 Jun 2025 14:59:29 -0700 Subject: [PATCH] [flang][cuda] Inline this_thread_block() calls (#146144) --- .../flang/Optimizer/Builder/IntrinsicCall.h | 1 + flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 55 +++++++++++++++++++ flang/module/cooperative_groups.f90 | 13 +++++ flang/test/Lower/CUDA/cuda-cooperative.cuf | 26 +++++++++ 4 files changed, 95 insertions(+) diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index 1e8c1198fb94..cb703134a454 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -443,6 +443,7 @@ struct IntrinsicLibrary { fir::ExtendedValue genTranspose(mlir::Type, llvm::ArrayRef); mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef); + mlir::Value genThisThreadBlock(mlir::Type, llvm::ArrayRef); mlir::Value genThisWarp(mlir::Type, llvm::ArrayRef); void genThreadFence(llvm::ArrayRef); void genThreadFenceBlock(llvm::ArrayRef); diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 42dd78cd0e4c..651c198e7133 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -933,6 +933,7 @@ static constexpr IntrinsicHandler handlers[]{ /*isElemental=*/false}, {"tand", &I::genTand}, {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false}, + {"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false}, {"this_warp", &I::genThisWarp, {}, /*isElemental=*/false}, {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false}, {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false}, @@ -8195,6 +8196,60 @@ mlir::Value 
IntrinsicLibrary::genThisGrid(mlir::Type resultType,
   return res;
 }
 
+// THIS_THREAD_BLOCK
+// Builds the thread_group result of this_thread_block() inline: allocates the
+// derived type, then stores the product of the block dimensions in `size` and
+// the calling thread's 1-based linear index within the block in `rank`.
+mlir::Value
+IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
+                                     llvm::ArrayRef<mlir::Value> args) {
+  assert(args.empty() && "this_thread_block takes no arguments");
+  // dyn_cast (not cast) so the assert below reports a non-record result type.
+  auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  assert(recTy.getTypeList().size() >= 3 && "expected handle, size and rank");
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x;
+  mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
+  mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
+  mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
+  mlir::Value size =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimZ, blockDimY);
+  size = builder.create<mlir::arith::MulIOp>(loc, size, blockDimX);
+
+  // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) +
+  //   (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+  mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
+  mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
+  mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
+  mlir::Value r1 =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
+  mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, blockDimX);
+  mlir::Value r3 =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
+  mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
+  mlir::Value rank = builder.create<mlir::arith::AddIOp>(loc, r2r3, threadIdX);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
+
+  // Stores `value` into component `idx` of the result record.
+  mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+  auto storeComponent = [&](unsigned idx, mlir::Value value) {
+    const auto &component = recTy.getTypeList()[idx];
+    mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
+        loc, fieldIndexType, component.first, recTy,
+        /*typeParams=*/mlir::ValueRange{});
+    mlir::Value coord = builder.create<fir::CoordinateOp>(
+        loc, builder.getRefType(component.second), res, fieldIndex);
+    builder.create<fir::StoreOp>(loc, value, coord);
+  };
+  storeComponent(/*idx=*/1, size);
+  storeComponent(/*idx=*/2, rank);
+  return res;
+}
+
 // THIS_WARP
 mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType,
                                           llvm::ArrayRef args) {
diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
index e3c4b53afd8f..b8875f72f807 100644
--- a/flang/module/cooperative_groups.f90
+++ b/flang/module/cooperative_groups.f90
@@ -26,6 +26,12 @@ type :: coalesced_group
   integer(4) :: rank
 end type coalesced_group
 
+type :: thread_group
+  type(c_devptr), private :: handle
+  integer(4) :: size
+  integer(4) :: rank
+end type thread_group
+
 interface
   attributes(device) function this_grid()
     import
@@ -33,6 +39,13 @@ interface
   end function
 end interface
 
+interface
+  attributes(device) function this_thread_block()
+    import
+    type(thread_group) :: this_thread_block
+  end function
+end interface
+
 interface this_warp
   attributes(device) function this_warp()
     import
diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
index 3dc1a5e85f84..657a87c0b5a0 100644
--- a/flang/test/Lower/CUDA/cuda-cooperative.cuf
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -70,4 +70,30 @@ end subroutine
 ! CHECK: %[[AND:.*]] = arith.andi %[[THREAD_ID]], %[[C31]] : i32
 ! CHECK: %[[RANK:.*]] = arith.addi %[[AND]], %[[C1]] : i32
 ! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %{{.*}}, rank : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref
+
+attributes(grid_global) subroutine t1()
+  use cooperative_groups
+  type(thread_group) :: gg
+  gg = this_thread_block()
+end subroutine
+! 
CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}> +! CHECK: %[[THREAD_GROUP:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}> +! CHECK: %[[NTID_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32 +! CHECK: %[[NTID_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32 +! CHECK: %[[NTID_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32 +! CHECK: %[[SIZE_ZY:.*]] = arith.muli %[[NTID_Z]], %[[NTID_Y]] : i32 +! CHECK: %[[SIZE:.*]] = arith.muli %[[SIZE_ZY]], %[[NTID_X]] : i32 +! CHECK: %[[TID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32 +! CHECK: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32 +! CHECK: %[[TID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32 +! CHECK: %[[RANK_ZY:.*]] = arith.muli %[[TID_Z]], %[[NTID_Y]] : i32 +! CHECK: %[[RANK_ZYX:.*]] = arith.muli %[[RANK_ZY]], %[[NTID_X]] : i32 +! CHECK: %[[RANK_YX:.*]] = arith.muli %[[TID_Y]], %[[NTID_X]] : i32 +! CHECK: %[[RANK_SUM1:.*]] = arith.addi %[[RANK_ZYX]], %[[RANK_YX]] : i32 +! CHECK: %[[RANK_SUM2:.*]] = arith.addi %[[RANK_SUM1]], %[[TID_X]] : i32 +! CHECK: %[[C1:.*]] = arith.constant 1 : i32 +! CHECK: %[[RANK:.*]] = arith.addi %[[RANK_SUM2]], %[[C1]] : i32 +! CHECK: %[[SIZE_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], size : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref +! CHECK: fir.store %[[SIZE]] to %[[SIZE_COORD]] : !fir.ref +! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], rank : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref ! CHECK: fir.store %[[RANK]] to %[[RANK_COORD]] : !fir.ref