[AArch64][SVE] Fold fadda(ptrue, x, select(mask, y, -0.0)) into fadda(mask, x, y)
This patch adds an SVE pattern to recognize the use of a select with an fadda in the form fadda(ptrue, x, select(mask, y, -0.0)). In this case the select can be folded away, with the select mask used as the predicate for fadda. This improves the codegen when vectorizing loops with ordered fp reductions. Differential Revision: https://reviews.llvm.org/D129623
This commit is contained in:
@@ -1234,6 +1234,10 @@ def fpimm0 : FPImmLeaf<fAny, [{
|
||||
return Imm.isExactlyValue(+0.0);
|
||||
}]>;
|
||||
|
||||
def fpimm_minus0 : FPImmLeaf<fAny, [{
|
||||
return Imm.isExactlyValue(-0.0);
|
||||
}]>;
|
||||
|
||||
def fpimm_half : FPImmLeaf<fAny, [{
|
||||
return Imm.isExactlyValue(+0.5);
|
||||
}]>;
|
||||
|
||||
@@ -278,10 +278,18 @@ def AArch64scvtf_mt : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch
|
||||
def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
|
||||
def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
|
||||
|
||||
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
|
||||
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64fadda_p : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
|
||||
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3,
|
||||
[SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisVec<3>, SDTCisSameNumEltsAs<1,3>]>;
|
||||
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
|
||||
def AArch64fadda_p_node : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
|
||||
|
||||
def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
|
||||
[(AArch64fadda_p_node node:$op1, node:$op2, node:$op3),
|
||||
(AArch64fadda_p_node (SVEAllActive), node:$op2,
|
||||
(vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
|
||||
(AArch64fadda_p_node (SVEAllActive), node:$op2,
|
||||
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
|
||||
|
||||
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
|
||||
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
|
||||
|
||||
112
llvm/test/CodeGen/AArch64/sve-fadda-select.ll
Normal file
112
llvm/test/CodeGen/AArch64/sve-fadda-select.ll
Normal file
@@ -0,0 +1,112 @@
|
||||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
|
||||
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
|
||||
|
||||
; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)
|
||||
|
||||
define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv2f32:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
|
||||
; CHECK-NEXT: fadda s0, p0, s0, z1.s
|
||||
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
|
||||
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
|
||||
%fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
|
||||
ret float %fadda
|
||||
}
|
||||
|
||||
define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv4f32:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
|
||||
; CHECK-NEXT: fadda s0, p0, s0, z1.s
|
||||
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
|
||||
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
|
||||
%fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
|
||||
ret float %fadda
|
||||
}
|
||||
|
||||
define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv2f64:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
|
||||
; CHECK-NEXT: fadda d0, p0, d0, z1.d
|
||||
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
|
||||
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
|
||||
%fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
|
||||
ret double %fadda
|
||||
}
|
||||
|
||||
; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.
|
||||
|
||||
define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv2f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: adrp x8, .LCPI3_0
|
||||
; CHECK-NEXT: add x8, x8, :lo12:.LCPI3_0
|
||||
; CHECK-NEXT: ptrue p1.d
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
|
||||
; CHECK-NEXT: ld1rh { z2.d }, p1/z, [x8]
|
||||
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
|
||||
; CHECK-NEXT: fadda h0, p1, h0, z1.h
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
|
||||
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
|
||||
%fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
|
||||
ret half %fadda
|
||||
}
|
||||
|
||||
define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv4f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: adrp x8, .LCPI4_0
|
||||
; CHECK-NEXT: add x8, x8, :lo12:.LCPI4_0
|
||||
; CHECK-NEXT: ptrue p1.s
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
|
||||
; CHECK-NEXT: ld1rh { z2.s }, p1/z, [x8]
|
||||
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
|
||||
; CHECK-NEXT: fadda h0, p1, h0, z1.h
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
|
||||
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
|
||||
%fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
|
||||
ret half %fadda
|
||||
}
|
||||
|
||||
define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
|
||||
; CHECK-LABEL: pred_fadda_nxv8f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: adrp x8, .LCPI5_0
|
||||
; CHECK-NEXT: add x8, x8, :lo12:.LCPI5_0
|
||||
; CHECK-NEXT: ptrue p1.h
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
|
||||
; CHECK-NEXT: ld1rh { z2.h }, p1/z, [x8]
|
||||
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
|
||||
; CHECK-NEXT: fadda h0, p1, h0, z1.h
|
||||
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
|
||||
; CHECK-NEXT: ret
|
||||
%i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
|
||||
%minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
|
||||
%sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
|
||||
%fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
|
||||
ret half %fadda
|
||||
}
|
||||
|
||||
declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
|
||||
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
|
||||
declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
|
||||
declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
|
||||
declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
|
||||
declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)
|
||||
Reference in New Issue
Block a user