[flang][cuda] Lower syncwarp to NVVM intrinsic (#126164)

This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2025-02-06 19:43:21 -08:00
committed by GitHub
parent b00b193728
commit 070c888292
4 changed files with 18 additions and 4 deletions

View File

@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
mlir::ArrayRef<fir::ExtendedValue> args);
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);

View File

@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
{"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
{"system",
&I::genSystem,
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
}
// SYNCWARP
void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 1);
constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
mlir::Value mask = fir::getBase(args[0]);
mlir::FunctionType funcType =
mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {});
auto funcOp = builder.createFunction(loc, funcName, funcType);
llvm::SmallVector<mlir::Value> argsList{mask};
builder.create<fir::CallOp>(loc, funcOp, argsList);
}
// SYSTEM
fir::ExtendedValue
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,

View File

@@ -49,7 +49,7 @@ implicit none
public :: syncthreads_or
interface
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
attributes(device) subroutine syncwarp(mask)
integer, value :: mask
end subroutine
end interface

View File

@@ -47,7 +47,7 @@ end
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -102,13 +102,13 @@ end
! CHECK-LABEL: func.func @_QPhost1()
! CHECK: cuf.kernel
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0()
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
! CHECK: func.func private @llvm.nvvm.membar.gl()
! CHECK: func.func private @llvm.nvvm.membar.cta()
! CHECK: func.func private @llvm.nvvm.membar.sys()