[AArch64][SVE] Fold fadda(ptrue, x, select(mask, y, -0.0)) into fadda(mask, x, y)

This patch adds an SVE pattern to recognize the use of a select with an
fadda in the form fadda(ptrue, x, select(mask, y, -0.0)). In this case
the select can be folded away, with the select mask used as the
predicate for fadda. This improves the codegen when vectorizing loops
with ordered fp reductions.

Differential Revision: https://reviews.llvm.org/D129623
This commit is contained in:
Rosie Sumpter
2022-07-12 08:40:59 +01:00
parent 106d695287
commit 05d424d165
3 changed files with 128 additions and 4 deletions

View File

@@ -1234,6 +1234,10 @@ def fpimm0 : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(+0.0);
}]>;
def fpimm_minus0 : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(-0.0);
}]>;
def fpimm_half : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(+0.5);
}]>;

View File

@@ -278,10 +278,18 @@ def AArch64scvtf_mt : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch
def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
def AArch64fadda_p : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3,
[SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisVec<3>, SDTCisSameNumEltsAs<1,3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
def AArch64fadda_p_node : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
[(AArch64fadda_p_node node:$op1, node:$op2, node:$op3),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;

View File

@@ -0,0 +1,112 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)
define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f32:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
%fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
ret float %fadda
}
define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
%fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
ret float %fadda
}
define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
; CHECK-NEXT: fadda d0, p0, d0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
%fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
ret double %fadda
}
; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.
define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI3_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI3_0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.d }, p1/z, [x8]
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
ret half %fadda
}
define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv4f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI4_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI4_0
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.s }, p1/z, [x8]
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
ret half %fadda
}
define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI5_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI5_0
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.h }, p1/z, [x8]
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
%sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
ret half %fadda
}
declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)