Reland "[NVPTX] Unify and extend barrier{.cta} intrinsic support" (#141143)

Note: This relands #140615 adding a ".count" suffix to the non-".all"
variants.

Our current intrinsic support for barrier intrinsics is confusing and
incomplete, with multiple intrinsics mapping to the same instruction and
intrinsic names not clearly conveying intrinsic semantics. Further, we
lack support for some variants. This change unifies the IR
representation to a single consistently named set of intrinsics.

- llvm.nvvm.barrier.cta.sync.aligned.all(i32)
- llvm.nvvm.barrier.cta.sync.aligned.count(i32, i32)
- llvm.nvvm.barrier.cta.arrive.aligned.count(i32, i32)
- llvm.nvvm.barrier.cta.sync.all(i32)
- llvm.nvvm.barrier.cta.sync.count(i32, i32)
- llvm.nvvm.barrier.cta.arrive.count(i32, i32)

The following Auto-Upgrade rules are used to maintain compatibility with
IR using the legacy intrinsics:

* llvm.nvvm.barrier0 --> llvm.nvvm.barrier.cta.sync.aligned.all(0)
* llvm.nvvm.barrier.n --> llvm.nvvm.barrier.cta.sync.aligned.all(x)
* llvm.nvvm.bar.sync --> llvm.nvvm.barrier.cta.sync.aligned.all(x)
* llvm.nvvm.barrier --> llvm.nvvm.barrier.cta.sync.aligned.count(x, y)
* llvm.nvvm.barrier.sync --> llvm.nvvm.barrier.cta.sync.all(x)
* llvm.nvvm.barrier.sync.cnt --> llvm.nvvm.barrier.cta.sync.count(x, y)
This commit is contained in:
Alex MacLean
2025-05-22 19:38:10 -07:00
committed by GitHub
parent 5d76555f93
commit 3a84a4e55d
21 changed files with 349 additions and 196 deletions

View File

@@ -1160,6 +1160,22 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__nvvm_fence_sc_cluster:
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_fence_sc_cluster));
case NVPTX::BI__nvvm_bar_sync:
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_aligned_all),
EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__syncthreads:
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_aligned_all),
Builder.getInt32(0));
case NVPTX::BI__nvvm_barrier_sync:
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_all),
EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__nvvm_barrier_sync_cnt:
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
default:
return nullptr;
}

View File

@@ -32,10 +32,10 @@ __device__ void nvvm_sync(unsigned mask, int i, float f, int a, int b,
// CHECK: call void @llvm.nvvm.bar.warp.sync(i32
// expected-error@+1 {{'__nvvm_bar_warp_sync' needs target feature ptx60}}
__nvvm_bar_warp_sync(mask);
// CHECK: call void @llvm.nvvm.barrier.sync(i32
// CHECK: call void @llvm.nvvm.barrier.cta.sync.all(i32
// expected-error@+1 {{'__nvvm_barrier_sync' needs target feature ptx60}}
__nvvm_barrier_sync(mask);
// CHECK: call void @llvm.nvvm.barrier.sync.cnt(i32
// CHECK: call void @llvm.nvvm.barrier.cta.sync.count(i32
// expected-error@+1 {{'__nvvm_barrier_sync_cnt' needs target feature ptx60}}
__nvvm_barrier_sync_cnt(mask, i);

View File

@@ -198,7 +198,7 @@ __device__ int read_pms() {
__device__ void sync() {
// CHECK: call void @llvm.nvvm.bar.sync(i32 0)
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
__nvvm_bar_sync(0);
@@ -259,7 +259,7 @@ __device__ void nvvm_math(float f1, float f2, double d1, double d2) {
__nvvm_membar_gl();
// CHECK: call void @llvm.nvvm.membar.sys()
__nvvm_membar_sys();
// CHECK: call void @llvm.nvvm.barrier0()
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
__syncthreads();
}

View File

@@ -887,7 +887,7 @@ __gpu_kernel void foo() {
// NVPTX-LABEL: define internal void @__gpu_sync_threads(
// NVPTX-SAME: ) #[[ATTR0]] {
// NVPTX-NEXT: [[ENTRY:.*:]]
// NVPTX-NEXT: call void @llvm.nvvm.barrier0()
// NVPTX-NEXT: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
// NVPTX-NEXT: ret void
//
//

View File

@@ -199,21 +199,59 @@ map in the following way to CUDA builtins:
Barriers
--------
'``llvm.nvvm.barrier0``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^
'``llvm.nvvm.barrier.cta.*``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Syntax:
"""""""
.. code-block:: llvm
declare void @llvm.nvvm.barrier0()
declare void @llvm.nvvm.barrier.cta.sync.count(i32 %id, i32 %n)
declare void @llvm.nvvm.barrier.cta.sync.all(i32 %id)
declare void @llvm.nvvm.barrier.cta.arrive.count(i32 %id, i32 %n)
declare void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %id, i32 %n)
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %id)
declare void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32 %id, i32 %n)
Overview:
"""""""""
The '``@llvm.nvvm.barrier0()``' intrinsic emits a PTX ``bar.sync 0``
instruction, equivalent to the ``__syncthreads()`` call in CUDA.
The '``@llvm.nvvm.barrier.cta.*``' family of intrinsics perform barrier
synchronization and communication within a CTA. They can be used by the threads
within the CTA for synchronization and communication.
Semantics:
""""""""""
Operand %id specifies a logical barrier resource and must fall within the range
0 through 15. When present, operand %n specifies the number of threads
participating in the barrier. When specifying a thread count, the value must be
a multiple of the warp size. With the '``@llvm.nvvm.barrier.cta.sync.*``'
variants, the '``.all``' suffix indicates that all threads in the CTA should
participate in the barrier while the '``.count``' suffix indicates that only
the threads specified by the %n operand should participate in the barrier.
All forms of the '``@llvm.nvvm.barrier.cta.*``' intrinsic cause the executing
thread to wait for all non-exited threads from its warp and then marks the
warp's arrival at the barrier. In addition to signaling its arrival at the
barrier, the '``@llvm.nvvm.barrier.cta.sync.*``' intrinsics cause the executing
thread to wait for non-exited threads of all other warps participating in the
barrier to arrive. On the other hand, the '``@llvm.nvvm.barrier.cta.arrive.*``'
intrinsic does not cause the executing thread to wait for threads of other
participating warps.
When a barrier completes, the waiting threads are restarted without delay,
and the barrier is reinitialized so that it can be immediately reused.
The '``@llvm.nvvm.barrier.cta.*``' intrinsic has an optional '``.aligned``'
modifier to indicate textual alignment of the barrier. When specified, it
indicates that all threads in the CTA will execute the same
'``@llvm.nvvm.barrier.cta.*``' instruction. In conditionally executed code, an
aligned '``@llvm.nvvm.barrier.cta.*``' instruction should only be used if it is
known that all threads in the CTA evaluate the condition identically, otherwise
behavior is undefined.
Electing a thread
-----------------

View File

@@ -128,6 +128,12 @@
// * llvm.nvvm.swap.lo.hi.b64 --> llvm.fshl(x, x, 32)
// * llvm.nvvm.atomic.load.inc.32 --> atomicrmw uinc_wrap
// * llvm.nvvm.atomic.load.dec.32 --> atomicrmw udec_wrap
// * llvm.nvvm.barrier0 --> llvm.nvvm.barrier.cta.sync.aligned.all(0)
// * llvm.nvvm.barrier.n --> llvm.nvvm.barrier.cta.sync.aligned.all(x)
// * llvm.nvvm.bar.sync --> llvm.nvvm.barrier.cta.sync.aligned.all(x)
// * llvm.nvvm.barrier --> llvm.nvvm.barrier.cta.sync.aligned(x, y)
// * llvm.nvvm.barrier.sync --> llvm.nvvm.barrier.cta.sync.all(x)
// * llvm.nvvm.barrier.sync.cnt --> llvm.nvvm.barrier.cta.sync(x, y)
def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr
def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr
@@ -1278,18 +1284,6 @@ let TargetPrefix = "nvvm" in {
defm int_nvvm_atomic_cas_gen_i : PTXAtomicWithScope3<llvm_anyint_ty>;
// Bar.Sync
// The builtin for "bar.sync 0" is called __syncthreads. Unlike most of the
// intrinsics in this file, this one is a user-facing API.
def int_nvvm_barrier0 : ClangBuiltin<"__syncthreads">,
Intrinsic<[], [], [IntrConvergent, IntrNoCallback]>;
// Synchronize all threads in the CTA at barrier 'n'.
def int_nvvm_barrier_n : ClangBuiltin<"__nvvm_bar_n">,
Intrinsic<[], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
// Synchronize 'm', a multiple of warp size, (arg 2) threads in
// the CTA at barrier 'n' (arg 1).
def int_nvvm_barrier : ClangBuiltin<"__nvvm_bar">,
Intrinsic<[], [llvm_i32_ty, llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
def int_nvvm_barrier0_popc : ClangBuiltin<"__nvvm_bar0_popc">,
Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
def int_nvvm_barrier0_and : ClangBuiltin<"__nvvm_bar0_and">,
@@ -1297,16 +1291,21 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_barrier0_or : ClangBuiltin<"__nvvm_bar0_or">,
Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
def int_nvvm_bar_sync : NVVMBuiltin,
Intrinsic<[], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
def int_nvvm_bar_warp_sync : NVVMBuiltin,
Intrinsic<[], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
// barrier.sync id[, cnt]
def int_nvvm_barrier_sync : NVVMBuiltin,
Intrinsic<[], [llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
def int_nvvm_barrier_sync_cnt : NVVMBuiltin,
Intrinsic<[], [llvm_i32_ty, llvm_i32_ty], [IntrConvergent, IntrNoCallback]>;
// barrier{.cta}.sync{.aligned} a{, b};
// barrier{.cta}.arrive{.aligned} a, b;
let IntrProperties = [IntrConvergent, IntrNoCallback] in {
foreach align = ["", "_aligned"] in {
def int_nvvm_barrier_cta_sync # align # _all :
Intrinsic<[], [llvm_i32_ty]>;
def int_nvvm_barrier_cta_sync # align # _count :
Intrinsic<[], [llvm_i32_ty, llvm_i32_ty]>;
def int_nvvm_barrier_cta_arrive # align # _count :
Intrinsic<[], [llvm_i32_ty, llvm_i32_ty]>;
}
}
// barrier.cluster.[wait, arrive, arrive.relaxed]
def int_nvvm_barrier_cluster_arrive :

View File

@@ -1343,12 +1343,9 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
// nvvm.abs.{i,ii}
Expand =
Name == "i" || Name == "ll" || Name == "bf16" || Name == "bf16x2";
else if (Name == "fabs.f" || Name == "fabs.ftz.f" || Name == "fabs.d")
else if (Name.consume_front("fabs."))
// nvvm.fabs.{f,ftz.f,d}
Expand = true;
else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f" ||
Name == "swap.lo.hi.b64")
Expand = true;
Expand = Name == "f" || Name == "ftz.f" || Name == "d";
else if (Name.consume_front("max.") || Name.consume_front("min."))
// nvvm.{min,max}.{i,ii,ui,ull}
Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
@@ -1380,7 +1377,18 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
Expand = (Name.starts_with("i.") || Name.starts_with("f.") ||
Name.starts_with("p."));
else
Expand = false;
Expand = StringSwitch<bool>(Name)
.Case("barrier0", true)
.Case("barrier.n", true)
.Case("barrier.sync.cnt", true)
.Case("barrier.sync", true)
.Case("barrier", true)
.Case("bar.sync", true)
.Case("clz.ll", true)
.Case("popc.ll", true)
.Case("h2f", true)
.Case("swap.lo.hi.b64", true)
.Default(false);
if (Expand) {
NewFn = nullptr;
@@ -2478,6 +2486,21 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
MDNode *MD = MDNode::get(Builder.getContext(), {});
LD->setMetadata(LLVMContext::MD_invariant_load, MD);
return LD;
} else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") {
Value *Arg =
Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0);
Rep = Builder.CreateIntrinsic(Intrinsic::nvvm_barrier_cta_sync_aligned_all,
{}, {Arg});
} else if (Name == "barrier") {
Rep = Builder.CreateIntrinsic(
Intrinsic::nvvm_barrier_cta_sync_aligned_count, {},
{CI->getArgOperand(0), CI->getArgOperand(1)});
} else if (Name == "barrier.sync") {
Rep = Builder.CreateIntrinsic(Intrinsic::nvvm_barrier_cta_sync_all, {},
{CI->getArgOperand(0)});
} else if (Name == "barrier.sync.cnt") {
Rep = Builder.CreateIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count, {},
{CI->getArgOperand(0), CI->getArgOperand(1)});
} else {
Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic &&

View File

@@ -67,15 +67,6 @@ class THREADMASK_INFO<bit sync> {
// Synchronization and shuffle functions
//-----------------------------------
let isConvergent = true in {
def INT_BARRIER0 : NVPTXInst<(outs), (ins),
"bar.sync \t0;",
[(int_nvvm_barrier0)]>;
def INT_BARRIERN : NVPTXInst<(outs), (ins Int32Regs:$src1),
"bar.sync \t$src1;",
[(int_nvvm_barrier_n i32:$src1)]>;
def INT_BARRIER : NVPTXInst<(outs), (ins Int32Regs:$src1, Int32Regs:$src2),
"bar.sync \t$src1, $src2;",
[(int_nvvm_barrier i32:$src1, i32:$src2)]>;
def INT_BARRIER0_POPC : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$pred),
!strconcat("{{ \n\t",
".reg .pred \t%p1; \n\t",
@@ -102,9 +93,6 @@ def INT_BARRIER0_OR : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$pred),
"}}"),
[(set i32:$dst, (int_nvvm_barrier0_or i32:$pred))]>;
def INT_BAR_SYNC : NVPTXInst<(outs), (ins i32imm:$i), "bar.sync \t$i;",
[(int_nvvm_bar_sync imm:$i)]>;
def INT_BAR_WARP_SYNC_I : NVPTXInst<(outs), (ins i32imm:$i), "bar.warp.sync \t$i;",
[(int_nvvm_bar_warp_sync imm:$i)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
@@ -112,29 +100,44 @@ def INT_BAR_WARP_SYNC_R : NVPTXInst<(outs), (ins Int32Regs:$i), "bar.warp.sync \
[(int_nvvm_bar_warp_sync i32:$i)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def INT_BARRIER_SYNC_I : NVPTXInst<(outs), (ins i32imm:$i), "barrier.sync \t$i;",
[(int_nvvm_barrier_sync imm:$i)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def INT_BARRIER_SYNC_R : NVPTXInst<(outs), (ins Int32Regs:$i), "barrier.sync \t$i;",
[(int_nvvm_barrier_sync i32:$i)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
multiclass BARRIER1<string asmstr, Intrinsic intrinsic, list<Predicate> requires = []> {
def _i : BasicNVPTXInst<(outs), (ins i32imm:$i), asmstr,
[(intrinsic imm:$i)]>,
Requires<requires>;
def INT_BARRIER_SYNC_CNT_RR : NVPTXInst<(outs), (ins Int32Regs:$id, Int32Regs:$cnt),
"barrier.sync \t$id, $cnt;",
[(int_nvvm_barrier_sync_cnt i32:$id, i32:$cnt)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def INT_BARRIER_SYNC_CNT_RI : NVPTXInst<(outs), (ins Int32Regs:$id, i32imm:$cnt),
"barrier.sync \t$id, $cnt;",
[(int_nvvm_barrier_sync_cnt i32:$id, imm:$cnt)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def INT_BARRIER_SYNC_CNT_IR : NVPTXInst<(outs), (ins i32imm:$id, Int32Regs:$cnt),
"barrier.sync \t$id, $cnt;",
[(int_nvvm_barrier_sync_cnt imm:$id, i32:$cnt)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def INT_BARRIER_SYNC_CNT_II : NVPTXInst<(outs), (ins i32imm:$id, i32imm:$cnt),
"barrier.sync \t$id, $cnt;",
[(int_nvvm_barrier_sync_cnt imm:$id, imm:$cnt)]>,
Requires<[hasPTX<60>, hasSM<30>]>;
def _r : BasicNVPTXInst<(outs), (ins Int32Regs:$i), asmstr,
[(intrinsic i32:$i)]>,
Requires<requires>;
}
multiclass BARRIER2<string asmstr, Intrinsic intrinsic, list<Predicate> requires = []> {
def _rr : BasicNVPTXInst<(outs), (ins Int32Regs:$i, Int32Regs:$j), asmstr,
[(intrinsic i32:$i, i32:$j)]>,
Requires<requires>;
def _ri : BasicNVPTXInst<(outs), (ins Int32Regs:$i, i32imm:$j), asmstr,
[(intrinsic i32:$i, imm:$j)]>,
Requires<requires>;
def _ir : BasicNVPTXInst<(outs), (ins i32imm:$i, Int32Regs:$j), asmstr,
[(intrinsic imm:$i, i32:$j)]>,
Requires<requires>;
def _ii : BasicNVPTXInst<(outs), (ins i32imm:$i, i32imm:$j), asmstr,
[(intrinsic imm:$i, imm:$j)]>,
Requires<requires>;
}
// Note the "bar.sync" variants could be renamed to the equivalent corresponding
// "barrier.*.aligned" variants. We use the older syntax for compatibility with
// older versions of the PTX ISA.
defm BARRIER_CTA_SYNC_ALIGNED_ALL : BARRIER1<"bar.sync", int_nvvm_barrier_cta_sync_aligned_all>;
defm BARRIER_CTA_SYNC_ALIGNED : BARRIER2<"bar.sync", int_nvvm_barrier_cta_sync_aligned_count>;
defm BARRIER_CTA_ARRIVE_ALIGNED : BARRIER2<"bar.arrive", int_nvvm_barrier_cta_arrive_aligned_count>;
defm BARRIER_CTA_SYNC_ALL : BARRIER1<"barrier.sync", int_nvvm_barrier_cta_sync_all, [hasPTX<60>]>;
defm BARRIER_CTA_SYNC : BARRIER2<"barrier.sync", int_nvvm_barrier_cta_sync_count, [hasPTX<60>]>;
defm BARRIER_CTA_ARRIVE : BARRIER2<"barrier.arrive", int_nvvm_barrier_cta_arrive_count, [hasPTX<60>]>;
class INT_BARRIER_CLUSTER<string variant, Intrinsic Intr,
list<Predicate> Preds = [hasPTX<78>, hasSM<90>]>:

View File

@@ -2150,7 +2150,8 @@ struct AANoUnwindCallSite final
bool AANoSync::isAlignedBarrier(const CallBase &CB, bool ExecutedAligned) {
switch (CB.getIntrinsicID()) {
case Intrinsic::nvvm_barrier0:
case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
case Intrinsic::nvvm_barrier0_and:
case Intrinsic::nvvm_barrier0_or:
case Intrinsic::nvvm_barrier0_popc:

View File

@@ -11,28 +11,15 @@ target triple = "nvptx64-nvidia-cuda"
; CHECK-LABEL: @bar_sync
; CHECK: store
; CHECK: tail call void @llvm.nvvm.bar.sync(i32 0)
; CHECK: tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
; CHECK: load
define dso_local i32 @bar_sync(i32 %0) local_unnamed_addr {
store i32 %0, ptr addrspacecast (ptr addrspace(3) @s to ptr), align 4
tail call void @llvm.nvvm.bar.sync(i32 0)
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
%2 = load i32, ptr addrspacecast (ptr addrspace(3) @s to ptr), align 4
ret i32 %2
}
declare void @llvm.nvvm.bar.sync(i32) #0
; CHECK-LABEL: @barrier0
; CHECK: store
; CHECK: tail call void @llvm.nvvm.barrier0()
; CHECK: load
define dso_local i32 @barrier0(i32 %0) local_unnamed_addr {
store i32 %0, ptr addrspacecast (ptr addrspace(3) @s to ptr), align 4
tail call void @llvm.nvvm.barrier0()
%2 = load i32, ptr addrspacecast (ptr addrspace(3) @s to ptr), align 4
ret i32 %2
}
declare void @llvm.nvvm.barrier0() #0
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #0
attributes #0 = { convergent nounwind }

View File

@@ -78,6 +78,13 @@ declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.3d(ptr addrspace(3) %d,
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.4d(ptr addrspace(3) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i16 %im2col0, i16 %im2col1, i16 %mc, i64 %ch, i1 %f1, i1 %f2);
declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.im2col.5d(ptr addrspace(3) %d, ptr addrspace(3) %bar, ptr %tm, i32 %d0, i32 %d1, i32 %d2, i32 %d3, i32 %d4, i16 %im2col0, i16 %im2col1, i16 %im2col2, i16 %mc, i64 %ch, i1 %f1, i1 %f2);
declare void @llvm.nvvm.barrier0()
declare void @llvm.nvvm.barrier.n(i32)
declare void @llvm.nvvm.bar.sync(i32)
declare void @llvm.nvvm.barrier(i32, i32)
declare void @llvm.nvvm.barrier.sync(i32)
declare void @llvm.nvvm.barrier.sync.cnt(i32, i32)
; CHECK-LABEL: @simple_upgrade
define void @simple_upgrade(i32 %a, i64 %b, i16 %c) {
; CHECK: call i32 @llvm.bitreverse.i32(i32 %a)
@@ -324,3 +331,18 @@ define void @nvvm_cp_async_bulk_tensor_g2s_tile(ptr addrspace(3) %d, ptr addrspa
ret void
}
define void @cta_barriers(i32 %x, i32 %y) {
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %x)
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %x)
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %x, i32 %y)
; CHECK: call void @llvm.nvvm.barrier.cta.sync.all(i32 %x)
; CHECK: call void @llvm.nvvm.barrier.cta.sync.count(i32 %x, i32 %y)
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.n(i32 %x)
call void @llvm.nvvm.bar.sync(i32 %x)
call void @llvm.nvvm.barrier(i32 %x, i32 %y)
call void @llvm.nvvm.barrier.sync(i32 %x)
call void @llvm.nvvm.barrier.sync.cnt(i32 %x, i32 %y)
ret void
}

View File

@@ -1,33 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_30 -mattr=+ptx60 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_30 -mattr=+ptx60 | %ptxas-verify %}
declare void @llvm.nvvm.bar.warp.sync(i32)
declare void @llvm.nvvm.barrier.sync(i32)
declare void @llvm.nvvm.barrier.sync.cnt(i32, i32)
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32)
declare void @llvm.nvvm.barrier.cta.sync.aligned.count(i32, i32)
declare void @llvm.nvvm.barrier.cta.sync.all(i32)
declare void @llvm.nvvm.barrier.cta.sync.count(i32, i32)
declare void @llvm.nvvm.barrier.cta.arrive.count(i32, i32)
declare void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32, i32)
; CHECK-LABEL: .func{{.*}}barrier_sync
define void @barrier_sync(i32 %id, i32 %cnt) {
; CHECK: ld.param.b32 [[ID:%r[0-9]+]], [barrier_sync_param_0];
; CHECK: ld.param.b32 [[CNT:%r[0-9]+]], [barrier_sync_param_1];
; CHECK: barrier.sync [[ID]], [[CNT]];
call void @llvm.nvvm.barrier.sync.cnt(i32 %id, i32 %cnt)
; CHECK: barrier.sync [[ID]], 32;
call void @llvm.nvvm.barrier.sync.cnt(i32 %id, i32 32)
; CHECK: barrier.sync 3, [[CNT]];
call void @llvm.nvvm.barrier.sync.cnt(i32 3, i32 %cnt)
; CHECK: barrier.sync 4, 64;
call void @llvm.nvvm.barrier.sync.cnt(i32 4, i32 64)
; CHECK: barrier.sync [[ID]];
call void @llvm.nvvm.barrier.sync(i32 %id)
; CHECK: barrier.sync 1;
call void @llvm.nvvm.barrier.sync(i32 1)
; CHECK: bar.warp.sync [[ID]];
define void @barrier_warp_sync(i32 %id) {
; CHECK-LABEL: barrier_warp_sync(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_warp_sync_param_0];
; CHECK-NEXT: bar.warp.sync %r1;
; CHECK-NEXT: bar.warp.sync 6;
; CHECK-NEXT: ret;
call void @llvm.nvvm.bar.warp.sync(i32 %id)
; CHECK: bar.warp.sync 6;
call void @llvm.nvvm.bar.warp.sync(i32 6)
ret void;
ret void
}
define void @barrier_cta_sync_aligned_all(i32 %id) {
; CHECK-LABEL: barrier_cta_sync_aligned_all(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_sync_aligned_all_param_0];
; CHECK-NEXT: bar.sync %r1;
; CHECK-NEXT: bar.sync 3;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %id)
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 3)
ret void
}
define void @barrier_cta_sync_aligned(i32 %id, i32 %cnt) {
; CHECK-LABEL: barrier_cta_sync_aligned(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_sync_aligned_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [barrier_cta_sync_aligned_param_1];
; CHECK-NEXT: bar.sync %r1, %r2;
; CHECK-NEXT: bar.sync 3, %r2;
; CHECK-NEXT: bar.sync %r1, 64;
; CHECK-NEXT: bar.sync 4, 64;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %id, i32 %cnt)
call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 3, i32 %cnt)
call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %id, i32 64)
call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 4, i32 64)
ret void
}
define void @barrier_cta_arrive_aligned(i32 %id, i32 %cnt) {
; CHECK-LABEL: barrier_cta_arrive_aligned(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_arrive_aligned_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [barrier_cta_arrive_aligned_param_1];
; CHECK-NEXT: bar.arrive %r1, %r2;
; CHECK-NEXT: bar.arrive 3, %r2;
; CHECK-NEXT: bar.arrive %r1, 64;
; CHECK-NEXT: bar.arrive 4, 64;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32 %id, i32 %cnt)
call void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32 3, i32 %cnt)
call void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32 %id, i32 64)
call void @llvm.nvvm.barrier.cta.arrive.aligned.count(i32 4, i32 64)
ret void
}
define void @barrier_cta_sync_all(i32 %id) {
; CHECK-LABEL: barrier_cta_sync_all(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_sync_all_param_0];
; CHECK-NEXT: barrier.sync %r1;
; CHECK-NEXT: barrier.sync 3;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.sync.all(i32 %id)
call void @llvm.nvvm.barrier.cta.sync.all(i32 3)
ret void
}
define void @barrier_cta_sync(i32 %id, i32 %cnt) {
; CHECK-LABEL: barrier_cta_sync(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_sync_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [barrier_cta_sync_param_1];
; CHECK-NEXT: barrier.sync %r1, %r2;
; CHECK-NEXT: barrier.sync 3, %r2;
; CHECK-NEXT: barrier.sync %r1, 64;
; CHECK-NEXT: barrier.sync 4, 64;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.sync.count(i32 %id, i32 %cnt)
call void @llvm.nvvm.barrier.cta.sync.count(i32 3, i32 %cnt)
call void @llvm.nvvm.barrier.cta.sync.count(i32 %id, i32 64)
call void @llvm.nvvm.barrier.cta.sync.count(i32 4, i32 64)
ret void
}
define void @barrier_cta_arrive(i32 %id, i32 %cnt) {
; CHECK-LABEL: barrier_cta_arrive(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [barrier_cta_arrive_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [barrier_cta_arrive_param_1];
; CHECK-NEXT: barrier.arrive %r1, %r2;
; CHECK-NEXT: barrier.arrive 3, %r2;
; CHECK-NEXT: barrier.arrive %r1, 64;
; CHECK-NEXT: barrier.arrive 4, 64;
; CHECK-NEXT: ret;
call void @llvm.nvvm.barrier.cta.arrive.count(i32 %id, i32 %cnt)
call void @llvm.nvvm.barrier.cta.arrive.count(i32 3, i32 %cnt)
call void @llvm.nvvm.barrier.cta.arrive.count(i32 %id, i32 64)
call void @llvm.nvvm.barrier.cta.arrive.count(i32 4, i32 64)
ret void
}

View File

@@ -1,42 +0,0 @@
; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 | FileCheck %s
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 | %ptxas-verify %}
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
; Use bar.sync to arrive at a pre-computed barrier number and
; wait for all threads in CTA to also arrive:
define ptx_device void @test_barrier_named_cta() {
; CHECK: mov.b32 %r[[REG0:[0-9]+]], 0;
; CHECK: bar.sync %r[[REG0]];
; CHECK: mov.b32 %r[[REG1:[0-9]+]], 10;
; CHECK: bar.sync %r[[REG1]];
; CHECK: mov.b32 %r[[REG2:[0-9]+]], 15;
; CHECK: bar.sync %r[[REG2]];
; CHECK: ret;
call void @llvm.nvvm.barrier.n(i32 0)
call void @llvm.nvvm.barrier.n(i32 10)
call void @llvm.nvvm.barrier.n(i32 15)
ret void
}
; Use bar.sync to arrive at a pre-computed barrier number and
; wait for fixed number of cooperating threads to arrive:
define ptx_device void @test_barrier_named() {
; CHECK: mov.b32 %r[[REG0A:[0-9]+]], 32;
; CHECK: mov.b32 %r[[REG0B:[0-9]+]], 0;
; CHECK: bar.sync %r[[REG0B]], %r[[REG0A]];
; CHECK: mov.b32 %r[[REG1A:[0-9]+]], 352;
; CHECK: mov.b32 %r[[REG1B:[0-9]+]], 10;
; CHECK: bar.sync %r[[REG1B]], %r[[REG1A]];
; CHECK: mov.b32 %r[[REG2A:[0-9]+]], 992;
; CHECK: mov.b32 %r[[REG2B:[0-9]+]], 15;
; CHECK: bar.sync %r[[REG2B]], %r[[REG2A]];
; CHECK: ret;
call void @llvm.nvvm.barrier(i32 0, i32 32)
call void @llvm.nvvm.barrier(i32 10, i32 352)
call void @llvm.nvvm.barrier(i32 15, i32 992)
ret void
}
declare void @llvm.nvvm.barrier(i32, i32)
declare void @llvm.nvvm.barrier.n(i32)

View File

@@ -3,8 +3,8 @@
; Make sure the call to syncthreads is not duplicate here by the LLVM
; optimizations, because it has the noduplicate attribute set.
; CHECK: call void @llvm.nvvm.barrier0
; CHECK-NOT: call void @llvm.nvvm.barrier0
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all
; CHECK-NOT: call void @llvm.nvvm.barrier.cta.sync.aligned.all
; Function Attrs: nounwind
define void @foo(ptr %output) #1 {
@@ -36,7 +36,7 @@ if.else: ; preds = %entry
br label %if.end
if.end: ; preds = %if.else, %if.then
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
%6 = load ptr, ptr %output.addr, align 8
%7 = load float, ptr %6, align 4
%conv7 = fpext float %7 to double

View File

@@ -2,9 +2,9 @@
; REQUIRES: nvptx-registered-target
; Make sure LLVM knows about the convergent attribute on the
; llvm.nvvm.barrier0 intrinsic.
; llvm.nvvm.barrier.cta.sync.aligned.all intrinsic.
declare void @llvm.nvvm.barrier0()
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32)
; CHECK: declare void @llvm.nvvm.barrier0() #[[ATTRNUM:[0-9]+]]
; CHECK: declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #[[ATTRNUM:[0-9]+]]
; CHECK: attributes #[[ATTRNUM]] = { convergent nocallback nounwind }

View File

@@ -70,17 +70,17 @@ define i32 @indirect_non_convergent_call(ptr %f) convergent norecurse {
ret i32 %a
}
declare void @llvm.nvvm.barrier0() convergent
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) convergent
define i32 @intrinsic() convergent {
; Implicitly convergent, because the intrinsic is convergent.
; CHECK: Function Attrs: convergent norecurse nounwind
; CHECK-LABEL: define {{[^@]+}}@intrinsic
; CHECK-SAME: () #[[ATTR4:[0-9]+]] {
; CHECK-NEXT: call void @llvm.nvvm.barrier0()
; CHECK-NEXT: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
; CHECK-NEXT: ret i32 0
;
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
ret i32 0
}

View File

@@ -12,7 +12,7 @@ define i32 @wrapped_tid() #0 comdat align 32 {
ret i32 %1
}
declare void @llvm.nvvm.barrier0() #1
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #1
; We had a bug where we duplicated basic blocks containing convergent
; functions like @llvm.nvvm.barrier0 below. Verify that we don't do
@@ -32,9 +32,9 @@ define void @foo() local_unnamed_addr #2 comdat align 32 {
br label %6
6:
; CHECK: call void @llvm.nvvm.barrier0()
; CHECK-NOT: call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier0()
; CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
; CHECK-NOT: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
%7 = icmp eq i32 %2, 0
br i1 %7, label %11, label %8

View File

@@ -8,7 +8,7 @@ target triple = "amdgcn-amd-amdhsa"
declare void @useI32(i32)
declare void @unknown()
declare void @aligned_barrier() "llvm.assume"="ompx_aligned_barrier"
declare void @llvm.nvvm.barrier0()
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32)
declare i32 @llvm.nvvm.barrier0.and(i32)
declare i32 @llvm.nvvm.barrier0.or(i32)
declare i32 @llvm.nvvm.barrier0.popc(i32)
@@ -58,7 +58,7 @@ define amdgpu_kernel void @pos_empty_3() "kernel" {
; CHECK-SAME: () #[[ATTR4]] {
; CHECK-NEXT: ret void
;
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
ret void
}
define amdgpu_kernel void @pos_empty_4() "kernel" {
@@ -393,12 +393,12 @@ define amdgpu_kernel void @pos_multiple() "kernel" {
; CHECK-SAME: () #[[ATTR4]] {
; CHECK-NEXT: ret void
;
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
call void @aligned_barrier()
call void @aligned_barrier()
call void @llvm.amdgcn.s.barrier()
call void @aligned_barrier()
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
call void @aligned_barrier()
call void @aligned_barrier()
ret void
@@ -422,7 +422,7 @@ define amdgpu_kernel void @multiple_blocks_kernel_1(i1 %c0, i1 %c1) "kernel" {
; CHECK-NEXT: ret void
;
fence acquire
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
fence release
call void @aligned_barrier()
fence seq_cst
@@ -441,7 +441,7 @@ f0:
fence release
call void @aligned_barrier()
fence acquire
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
fence acquire
br i1 %c1, label %t1, label %f1
t1:
@@ -473,7 +473,7 @@ define amdgpu_kernel void @multiple_blocks_kernel_2(i1 %c0, i1 %c1, ptr %p) "ker
; CHECK-NEXT: br label [[M:%.*]]
; CHECK: f0:
; CHECK-NEXT: store i32 4, ptr [[P]], align 4
; CHECK-NEXT: call void @llvm.nvvm.barrier0()
; CHECK-NEXT: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
; CHECK-NEXT: br i1 [[C1]], label [[T1:%.*]], label [[F1:%.*]]
; CHECK: t1:
; CHECK-NEXT: br label [[M]]
@@ -483,7 +483,7 @@ define amdgpu_kernel void @multiple_blocks_kernel_2(i1 %c0, i1 %c1, ptr %p) "ker
; CHECK-NEXT: store i32 4, ptr [[P]], align 4
; CHECK-NEXT: ret void
;
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
store i32 4, ptr %p
call void @aligned_barrier()
br i1 %c0, label %t0, label %f0
@@ -496,7 +496,7 @@ t0b:
f0:
call void @aligned_barrier()
store i32 4, ptr %p
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
br i1 %c1, label %t1, label %f1
t1:
call void @aligned_barrier()
@@ -527,7 +527,7 @@ define void @multiple_blocks_non_kernel_1(i1 %c0, i1 %c1) "kernel" {
; CHECK: m:
; CHECK-NEXT: ret void
;
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
call void @aligned_barrier()
br i1 %c0, label %t0, label %f0
t0:
@@ -538,7 +538,7 @@ t0b:
br label %m
f0:
call void @aligned_barrier()
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
br i1 %c1, label %t1, label %f1
t1:
call void @aligned_barrier()
@@ -577,7 +577,7 @@ t0b:
br label %m
f0:
call void @aligned_barrier()
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
br i1 %c1, label %t1, label %f1
t1:
call void @aligned_barrier()
@@ -614,7 +614,7 @@ t0b:
br label %m
f0:
call void @aligned_barrier()
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
br i1 %c1, label %t1, label %f1
t1:
call void @aligned_barrier()
@@ -665,7 +665,7 @@ t0b:
br label %m
f0:
call void @aligned_barrier()
call void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
store i32 2, ptr %p
br i1 %c1, label %t1, label %f1
t1:

View File

@@ -535,8 +535,13 @@ def NVVM_MBarrierTestWaitSharedOp : NVVM_Op<"mbarrier.test.wait.shared">,
// NVVM synchronization op definitions
//===----------------------------------------------------------------------===//
def NVVM_Barrier0Op : NVVM_IntrOp<"barrier0"> {
def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
let assemblyFormat = "attr-dict";
string llvmBuilder = [{
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all,
{builder.getInt32(0)});
}];
}
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
@@ -544,15 +549,14 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
Optional<I32>:$barrierId,
Optional<I32>:$numberOfThreads);
string llvmBuilder = [{
if ($numberOfThreads && $barrierId) {
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier,
{$barrierId, $numberOfThreads});
} else if($barrierId) {
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_n,
{$barrierId});
} else {
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
}
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
if ($numberOfThreads)
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
{id, $numberOfThreads});
else
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
}];
let hasVerifier = 1;

View File

@@ -73,12 +73,11 @@ define float @nvvm_rcp(float %0) {
; CHECK-LABEL: @llvm_nvvm_barrier0()
define void @llvm_nvvm_barrier0() {
; CHECK: nvvm.barrier0
; CHECK: llvm.nvvm.barrier.cta.sync.aligned.all
call void @llvm.nvvm.barrier0()
ret void
}
; TODO: Support the intrinsics below once they derive from NVVM_IntrOp rather than from NVVM_Op.
;
; define i32 @nvvm_shfl(i32 %0, i32 %1, i32 %2, i32 %3, float %4) {

View File

@@ -162,7 +162,7 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
// CHECK-LABEL: @llvm_nvvm_barrier0
llvm.func @llvm_nvvm_barrier0() {
// CHECK: call void @llvm.nvvm.barrier0()
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier0
llvm.return
}
@@ -170,11 +170,11 @@ llvm.func @llvm_nvvm_barrier0() {
// CHECK-LABEL: @llvm_nvvm_barrier(
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
// CHECK: call void @llvm.nvvm.barrier0()
nvvm.barrier
// CHECK: call void @llvm.nvvm.barrier.n(i32 %[[barId]])
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
nvvm.barrier id = %barID
// CHECK: call void @llvm.nvvm.barrier(i32 %[[barId]], i32 %[[numThreads]])
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
llvm.return
}