[flang][cuda] Inline this_grid call for cooperative groups (#145796)
This commit is contained in:
committed by
GitHub
parent
f63bc84b0d
commit
2b2bd51f3b
@@ -442,6 +442,7 @@ struct IntrinsicLibrary {
|
||||
llvm::ArrayRef<fir::ExtendedValue>);
|
||||
fir::ExtendedValue genTranspose(mlir::Type,
|
||||
llvm::ArrayRef<fir::ExtendedValue>);
|
||||
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
|
||||
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
|
||||
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
|
||||
|
||||
@@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
{{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
|
||||
/*isElemental=*/false},
|
||||
{"tand", &I::genTand},
|
||||
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
|
||||
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
|
||||
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
|
||||
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
|
||||
@@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
|
||||
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
|
||||
}
|
||||
|
||||
// THIS_GRID
|
||||
mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 0);
|
||||
auto recTy = mlir::cast<fir::RecordType>(resultType);
|
||||
assert(recTy && "RecordType expepected");
|
||||
mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
|
||||
mlir::Type i32Ty = builder.getI32Type();
|
||||
|
||||
mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
|
||||
mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
|
||||
mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
|
||||
|
||||
mlir::Value blockIdX = builder.create<mlir::NVVM::BlockIdXOp>(loc, i32Ty);
|
||||
mlir::Value blockIdY = builder.create<mlir::NVVM::BlockIdYOp>(loc, i32Ty);
|
||||
mlir::Value blockIdZ = builder.create<mlir::NVVM::BlockIdZOp>(loc, i32Ty);
|
||||
|
||||
mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
|
||||
mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
|
||||
mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
|
||||
mlir::Value gridDimX = builder.create<mlir::NVVM::GridDimXOp>(loc, i32Ty);
|
||||
mlir::Value gridDimY = builder.create<mlir::NVVM::GridDimYOp>(loc, i32Ty);
|
||||
mlir::Value gridDimZ = builder.create<mlir::NVVM::GridDimZOp>(loc, i32Ty);
|
||||
|
||||
// this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
|
||||
// (blockDim.x * gridDim.x);
|
||||
mlir::Value resZ =
|
||||
builder.create<mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
|
||||
mlir::Value resY =
|
||||
builder.create<mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
|
||||
mlir::Value resX =
|
||||
builder.create<mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
|
||||
mlir::Value resZY = builder.create<mlir::arith::MulIOp>(loc, resZ, resY);
|
||||
mlir::Value size = builder.create<mlir::arith::MulIOp>(loc, resZY, resX);
|
||||
|
||||
// tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
|
||||
// blockIdx.x;
|
||||
// this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
|
||||
// ((threadIdx.z * blockDim.y) * blockDim.x) +
|
||||
// (threadIdx.y * blockDim.x) + threadIdx.x + 1;
|
||||
mlir::Value r1 = builder.create<mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
|
||||
mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, gridDimX);
|
||||
mlir::Value r3 = builder.create<mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
|
||||
mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
|
||||
mlir::Value tmp = builder.create<mlir::arith::AddIOp>(loc, r2r3, blockIdX);
|
||||
|
||||
mlir::Value bXbY =
|
||||
builder.create<mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
|
||||
mlir::Value bXbYbZ =
|
||||
builder.create<mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
|
||||
mlir::Value tZbY =
|
||||
builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
|
||||
mlir::Value tZbYbX =
|
||||
builder.create<mlir::arith::MulIOp>(loc, tZbY, blockDimX);
|
||||
mlir::Value tYbX =
|
||||
builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
|
||||
mlir::Value rank = builder.create<mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
|
||||
rank = builder.create<mlir::arith::AddIOp>(loc, rank, tZbYbX);
|
||||
rank = builder.create<mlir::arith::AddIOp>(loc, rank, tYbX);
|
||||
rank = builder.create<mlir::arith::AddIOp>(loc, rank, threadIdX);
|
||||
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
|
||||
rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
|
||||
|
||||
auto sizeFieldName = recTy.getTypeList()[1].first;
|
||||
mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
|
||||
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
|
||||
mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
|
||||
loc, fieldIndexType, sizeFieldName, recTy,
|
||||
/*typeParams=*/mlir::ValueRange{});
|
||||
mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
|
||||
loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
|
||||
builder.create<fir::StoreOp>(loc, size, sizeCoord);
|
||||
|
||||
auto rankFieldName = recTy.getTypeList()[2].first;
|
||||
mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
|
||||
mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
|
||||
loc, fieldIndexType, rankFieldName, recTy,
|
||||
/*typeParams=*/mlir::ValueRange{});
|
||||
mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
|
||||
loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
|
||||
builder.create<fir::StoreOp>(loc, rank, rankCoord);
|
||||
return res;
|
||||
}
|
||||
|
||||
// TRAILZ
|
||||
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
|
||||
30
flang/module/cooperative_groups.f90
Normal file
30
flang/module/cooperative_groups.f90
Normal file
@@ -0,0 +1,30 @@
|
||||
!===-- module/cooperative_groups.f90 ---------------------------------------===!
|
||||
!
|
||||
! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
! See https://llvm.org/LICENSE.txt for license information.
|
||||
! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
!
|
||||
!===------------------------------------------------------------------------===!
|
||||
|
||||
! CUDA Fortran cooperative groups
|
||||
|
||||
module cooperative_groups

  use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr

  implicit none

  ! Result type of this_grid(). The handle component is private; size and
  ! rank are public 32-bit integers.
  type :: grid_group
    type(c_devptr), private :: handle
    integer(kind=4) :: size
    integer(kind=4) :: rank
  end type grid_group

  ! Device function returning a grid_group value. Only the interface is
  ! declared in this module; no implementation is provided here.
  interface
    attributes(device) function this_grid()
      import :: grid_group
      type(grid_group) :: this_grid
    end function this_grid
  end interface

end module cooperative_groups
|
||||
52
flang/test/Lower/CUDA/cuda-cooperative.cuf
Normal file
52
flang/test/Lower/CUDA/cuda-cooperative.cuf
Normal file
@@ -0,0 +1,52 @@
|
||||
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
|
||||
|
||||
! Test CUDA Fortran procedures available in cooperative_groups module.
|
||||
|
||||
! Kernel used as lowering input: it declares a local grid_group and assigns
! the result of this_grid() to it, which is the only statement in the body.
attributes(grid_global) subroutine g1()
|
||||
use cooperative_groups
|
||||
type(grid_group) :: gg
|
||||
! Single call to this_grid(); the expectations that follow in this file
! describe the operations generated for this assignment.
gg = this_grid()
|
||||
end subroutine
|
||||
|
||||
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
|
||||
! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
|
||||
! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
|
||||
! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
|
||||
! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
|
||||
! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
|
||||
! CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32
|
||||
! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32
|
||||
! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
|
||||
! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
|
||||
! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
|
||||
! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32
|
||||
! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32
|
||||
! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32
|
||||
|
||||
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32
|
||||
! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32
|
||||
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32
|
||||
! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32
|
||||
! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32
|
||||
|
||||
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32
|
||||
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32
|
||||
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32
|
||||
! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32
|
||||
! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32
|
||||
|
||||
! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32
|
||||
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32
|
||||
! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32
|
||||
! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32
|
||||
! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32
|
||||
! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32
|
||||
! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32
|
||||
! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32
|
||||
! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32
|
||||
! CHECK: %[[ONE:.*]] = arith.constant 1 : i32
|
||||
! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32
|
||||
! CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
|
||||
! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
|
||||
! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
|
||||
! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
|
||||
@@ -15,6 +15,7 @@ set(MODULES
|
||||
"mma"
|
||||
"__cuda_builtins"
|
||||
"__cuda_device"
|
||||
"cooperative_groups"
|
||||
"cudadevice"
|
||||
"ieee_arithmetic"
|
||||
"ieee_exceptions"
|
||||
@@ -60,12 +61,17 @@ if (NOT CMAKE_CROSSCOMPILING)
|
||||
elseif(${filename} STREQUAL "__ppc_intrinsics" OR
|
||||
${filename} STREQUAL "mma")
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
|
||||
elseif(${filename} STREQUAL "__cuda_device")
|
||||
elseif(${filename} STREQUAL "__cuda_device" OR
|
||||
${filename} STREQUAL "cudadevice" OR
|
||||
${filename} STREQUAL "cooperative_groups")
|
||||
set(opts -fc1 -xcuda)
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
|
||||
elseif(${filename} STREQUAL "cudadevice")
|
||||
set(opts -fc1 -xcuda)
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
|
||||
if(${filename} STREQUAL "__cuda_device")
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
|
||||
elseif(${filename} STREQUAL "cudadevice")
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
|
||||
elseif(${filename} STREQUAL "cooperative_groups")
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/cudadevice.mod)
|
||||
endif()
|
||||
else()
|
||||
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)
|
||||
if(${filename} STREQUAL "iso_fortran_env")
|
||||
|
||||
Reference in New Issue
Block a user