[flang][cuda] Lower syncwarp to NVVM intrinsic (#126164)
This commit is contained in:
committed by
GitHub
parent
b00b193728
commit
070c888292
@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
|
||||
mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
|
||||
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
|
||||
mlir::ArrayRef<fir::ExtendedValue> args);
|
||||
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
|
||||
|
||||
@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
|
||||
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
|
||||
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
|
||||
{"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
|
||||
{"system",
|
||||
&I::genSystem,
|
||||
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
|
||||
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
|
||||
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
|
||||
}
|
||||
|
||||
// SYNCWARP
|
||||
void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
  // Lowers the CUDA Fortran SYNCWARP(mask) call to a direct call of the
  // function named "llvm.nvvm.bar.warp.sync". Exactly one argument (the mask)
  // is expected; nothing is returned.
  assert(args.size() == 1);

  // Extract the base SSA value of the mask argument.
  mlir::Value warpMask = fir::getBase(args[0]);

  // The callee signature mirrors the mask's lowered type and has no results.
  mlir::Type maskType = warpMask.getType();
  mlir::FunctionType calleeType =
      mlir::FunctionType::get(builder.getContext(), {maskType}, {});

  // Declare (or look up) the callee through the builder, then emit the call
  // with the mask as the only operand.
  constexpr llvm::StringLiteral calleeName = "llvm.nvvm.bar.warp.sync";
  auto callee = builder.createFunction(loc, calleeName, calleeType);
  llvm::SmallVector<mlir::Value> callOperands{warpMask};
  builder.create<fir::CallOp>(loc, callee, callOperands);
}
|
||||
|
||||
// SYSTEM
|
||||
fir::ExtendedValue
|
||||
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
|
||||
|
||||
@@ -49,7 +49,7 @@ implicit none
|
||||
public :: syncthreads_or
|
||||
|
||||
interface
|
||||
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
|
||||
attributes(device) subroutine syncwarp(mask)
|
||||
integer, value :: mask
|
||||
end subroutine
|
||||
end interface
|
||||
|
||||
@@ -47,7 +47,7 @@ end
|
||||
|
||||
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
|
||||
@@ -102,13 +102,13 @@ end
|
||||
! CHECK-LABEL: func.func @_QPhost1()
|
||||
! CHECK: cuf.kernel
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
|
||||
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
|
||||
|
||||
! CHECK: func.func private @llvm.nvvm.barrier0()
|
||||
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
|
||||
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
|
||||
! CHECK: func.func private @llvm.nvvm.membar.gl()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.cta()
|
||||
! CHECK: func.func private @llvm.nvvm.membar.sys()
|
||||
|
||||
Reference in New Issue
Block a user