[flang][cuda] Support any_sync and ballot_sync (#134135)
This commit is contained in:
committed by
GitHub
parent
066787b9bd
commit
db21ae7803
@@ -442,6 +442,8 @@ struct IntrinsicLibrary {
|
||||
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
|
||||
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
|
||||
mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genVoteAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
mlir::Value genVoteBallotSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
|
||||
|
||||
/// Implement all conversion functions like DBLE, the first argument is
|
||||
/// the value to convert. There may be an additional KIND arguments that
|
||||
|
||||
@@ -273,6 +273,10 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
&I::genAny,
|
||||
{{{"mask", asAddr}, {"dim", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"any_sync",
|
||||
&I::genVoteAnySync,
|
||||
{{{"mask", asValue}, {"pred", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"asind", &I::genAsind},
|
||||
{"associated",
|
||||
&I::genAssociated,
|
||||
@@ -335,6 +339,10 @@ static constexpr IntrinsicHandler handlers[]{
|
||||
{"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
|
||||
{"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
|
||||
{"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
|
||||
{"ballot_sync",
|
||||
&I::genVoteBallotSync,
|
||||
{{{"mask", asValue}, {"pred", asValue}}},
|
||||
/*isElemental=*/false},
|
||||
{"bessel_jn",
|
||||
&I::genBesselJn,
|
||||
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
|
||||
@@ -6499,12 +6507,9 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
|
||||
return value;
|
||||
}
|
||||
|
||||
// ALL_SYNC
|
||||
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 2);
|
||||
|
||||
llvm::StringRef funcName = "llvm.nvvm.vote.all.sync";
|
||||
static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
llvm::StringRef funcName,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
mlir::MLIRContext *context = builder.getContext();
|
||||
mlir::Type i32Ty = builder.getI32Type();
|
||||
mlir::FunctionType ftype =
|
||||
@@ -6514,6 +6519,28 @@ mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
|
||||
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
|
||||
}
|
||||
|
||||
// ALL_SYNC: asserts exactly two arguments and forwards them to genVoteSync, which emits a fir.call to "llvm.nvvm.vote.all.sync"; resultType is not used here.
|
||||
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 2);
|
||||
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync", args);
|
||||
}
|
||||
|
||||
// ANY_SYNC: asserts exactly two arguments and forwards them to genVoteSync, which emits a fir.call to "llvm.nvvm.vote.any.sync"; resultType is not used here.
|
||||
mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 2);
|
||||
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync", args);
|
||||
}
|
||||
|
||||
// BALLOT_SYNC: asserts exactly two arguments and forwards them to genVoteSync, which emits a fir.call to "llvm.nvvm.vote.ballot.sync"; resultType is not used here.
|
||||
mlir::Value
|
||||
IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 2);
|
||||
return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync", args);
|
||||
}
|
||||
|
||||
// MATCH_ANY_SYNC
|
||||
mlir::Value
|
||||
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
|
||||
|
||||
@@ -1022,6 +1022,20 @@ implicit none
|
||||
end function
|
||||
end interface
|
||||
|
||||
interface any_sync  ! Device function any_sync(mask, pred): integer result, both arguments are integers passed by value.
|
||||
attributes(device) integer function any_sync(mask, pred)
|
||||
!dir$ ignore_tkr(d) mask, (td) pred
|
||||
integer, value :: mask, pred
|
||||
end function
|
||||
end interface
|
||||
|
||||
interface ballot_sync  ! Device function ballot_sync(mask, pred): integer result, both arguments are integers passed by value.
|
||||
attributes(device) integer function ballot_sync(mask, pred)
|
||||
!dir$ ignore_tkr(d) mask, (td) pred
|
||||
integer, value :: mask, pred
|
||||
end function
|
||||
end interface
|
||||
|
||||
! LDCG
|
||||
interface __ldcg
|
||||
attributes(device) pure integer(4) function __ldcg_i4(x) bind(c)
|
||||
|
||||
@@ -299,12 +299,14 @@ end
|
||||
attributes(device) subroutine testVote()  ! Calls all_sync, any_sync and ballot_sync once each with (mask, v32); the FileCheck directives that follow expect one fir.call per vote function.
|
||||
integer :: a, ipred, mask, v32
|
||||
a = all_sync(mask, v32)
|
||||
|
||||
a = any_sync(mask, v32)
|
||||
a = ballot_sync(mask, v32)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPtestvote()
|
||||
! CHECK: fir.call @llvm.nvvm.vote.all.sync
|
||||
|
||||
! CHECK: fir.call @llvm.nvvm.vote.any.sync
|
||||
! CHECK: fir.call @llvm.nvvm.vote.ballot.sync
|
||||
|
||||
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
|
||||
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
|
||||
|
||||
Reference in New Issue
Block a user