[mlir][math] Add missing trig math-to-llvm conversion patterns (#141069)

asin, acos, atan, and atan2 were being lowered to libm calls instead of
llvm intrinsics. Add the conversion patterns to handle these intrinsics
and update tests to expect this.
This commit is contained in:
Asher Mancinelli
2025-05-27 08:09:48 -07:00
committed by GitHub
parent d56deea1e4
commit 42b1df43e7
3 changed files with 250 additions and 10 deletions

View File

@@ -378,13 +378,167 @@ func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
func.func private @llvm.round.f32(f32) -> f32
func.func private @llvm.round.f64(f64) -> f64
//--- asin_fast.fir
// RUN: fir-opt %t/asin_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/asin_fast.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.asin({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.asin({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = math.asin %1 : f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = math.asin %1 : f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
//--- asin_relaxed.fir
// RUN: fir-opt %t/asin_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/asin_relaxed.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.asin({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.asin({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = math.asin %1 : f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = math.asin %1 : f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
//--- asin_precise.fir
// RUN: fir-opt %t/asin_precise.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/asin_precise.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @asinf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @asin({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = fir.call @asinf(%1) : (f32) -> f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = fir.call @asin(%1) : (f64) -> f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
func.func private @asinf(f32) -> f32
func.func private @asin(f64) -> f64
//--- acos_fast.fir
// RUN: fir-opt %t/acos_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/acos_fast.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.acos({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.acos({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = math.acos %1 : f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = math.acos %1 : f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
//--- acos_relaxed.fir
// RUN: fir-opt %t/acos_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/acos_relaxed.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.acos({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.acos({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = math.acos %1 : f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = math.acos %1 : f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
//--- acos_precise.fir
// RUN: fir-opt %t/acos_precise.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/acos_precise.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @acosf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @acos({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
%1 = fir.load %arg0 : !fir.ref<f32>
%2 = fir.call @acosf(%1) : (f32) -> f32
fir.store %2 to %0 : !fir.ref<f32>
%3 = fir.load %0 : !fir.ref<f32>
return %3 : f32
}
func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
%0 = fir.alloca f64 {bindc_name = "test_real8", uniq_name = "_QFtest_real8Etest_real8"}
%1 = fir.load %arg0 : !fir.ref<f64>
%2 = fir.call @acos(%1) : (f64) -> f64
fir.store %2 to %0 : !fir.ref<f64>
%3 = fir.load %0 : !fir.ref<f64>
return %3 : f64
}
func.func private @acosf(f32) -> f32
func.func private @acos(f64) -> f64
//--- atan_fast.fir
// RUN: fir-opt %t/atan_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/atan_fast.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atanf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -406,10 +560,10 @@ func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
//--- atan_relaxed.fir
// RUN: fir-opt %t/atan_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/atan_relaxed.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atanf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan({{%[A-Za-z0-9._]+}}) : (f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -458,10 +612,10 @@ func.func private @atan(f64) -> f64
//--- atan2_fast.fir
// RUN: fir-opt %t/atan2_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/atan2_fast.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan2f({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f32, f32) -> f32
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f32, f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f64, f64) -> f64
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f64, f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}, %arg1: !fir.ref<f32> {fir.bindc_name = "y"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -485,10 +639,10 @@ func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}, %arg1: !fi
//--- atan2_relaxed.fir
// RUN: fir-opt %t/atan2_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/atan2_relaxed.fir
// CHECK: @_QPtest_real4
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan2f({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f32, f32) -> f32
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f32, f32) -> f32
// CHECK: @_QPtest_real8
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f64, f64) -> f64
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.atan2({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (f64, f64) -> f64
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}, %arg1: !fir.ref<f32> {fir.bindc_name = "y"}) -> f32 {
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}

View File

@@ -42,6 +42,7 @@ using CopySignOpLowering =
ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
using CtPopFOpLowering =
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
@@ -62,12 +63,15 @@ using RoundOpLowering =
ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
using ATan2OpLowering =
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
@@ -353,6 +357,7 @@ void mlir::populateMathToLLVMConversionPatterns(
CopySignOpLowering,
CosOpLowering,
CoshOpLowering,
AcosOpLowering,
CountLeadingZerosOpLowering,
CountTrailingZerosOpLowering,
CtPopFOpLowering,
@@ -371,10 +376,13 @@ void mlir::populateMathToLLVMConversionPatterns(
RsqrtOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,
SqrtOpLowering,
FTruncOpLowering,
TanOpLowering,
TanhOpLowering
TanhOpLowering,
ATanOpLowering,
ATan2OpLowering
>(converter, benefit);
// clang-format on
}

View File

@@ -177,6 +177,84 @@ func.func @trigonometrics(%arg0: f32) {
// -----
// CHECK-LABEL: func @inverse_trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics(%arg0: f32) {
// CHECK: llvm.intr.asin([[ARG0]]) : (f32) -> f32
%0 = math.asin %arg0 : f32
// CHECK: llvm.intr.acos([[ARG0]]) : (f32) -> f32
%1 = math.acos %arg0 : f32
// CHECK: llvm.intr.atan([[ARG0]]) : (f32) -> f32
%2 = math.atan %arg0 : f32
func.return
}
// -----
// CHECK-LABEL: func @atan2
// CHECK-SAME: [[ARG0:%.+]]: f32, [[ARG1:%.+]]: f32
func.func @atan2(%arg0: f32, %arg1: f32) {
// CHECK: llvm.intr.atan2([[ARG0]], [[ARG1]]) : (f32, f32) -> f32
%0 = math.atan2 %arg0, %arg1 : f32
func.return
}
// -----
// CHECK-LABEL: func @inverse_trigonometrics_vector
// CHECK-SAME: [[ARG0:%.+]]: vector<4xf32>
func.func @inverse_trigonometrics_vector(%arg0: vector<4xf32>) {
// CHECK: llvm.intr.asin([[ARG0]]) : (vector<4xf32>) -> vector<4xf32>
%0 = math.asin %arg0 : vector<4xf32>
// CHECK: llvm.intr.acos([[ARG0]]) : (vector<4xf32>) -> vector<4xf32>
%1 = math.acos %arg0 : vector<4xf32>
// CHECK: llvm.intr.atan([[ARG0]]) : (vector<4xf32>) -> vector<4xf32>
%2 = math.atan %arg0 : vector<4xf32>
func.return
}
// -----
// CHECK-LABEL: func @atan2_vector
// CHECK-SAME: [[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>
func.func @atan2_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>) {
// CHECK: llvm.intr.atan2([[ARG0]], [[ARG1]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%0 = math.atan2 %arg0, %arg1 : vector<4xf32>
func.return
}
// -----
// CHECK-LABEL: func @inverse_trigonometrics_fmf
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics_fmf(%arg0: f32) {
// CHECK: llvm.intr.asin([[ARG0]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
%0 = math.asin %arg0 fastmath<fast> : f32
// CHECK: llvm.intr.acos([[ARG0]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
%1 = math.acos %arg0 fastmath<fast> : f32
// CHECK: llvm.intr.atan([[ARG0]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
%2 = math.atan %arg0 fastmath<fast> : f32
func.return
}
// -----
// CHECK-LABEL: func @atan2_fmf
// CHECK-SAME: [[ARG0:%.+]]: f32, [[ARG1:%.+]]: f32
func.func @atan2_fmf(%arg0: f32, %arg1: f32) {
// CHECK: llvm.intr.atan2([[ARG0]], [[ARG1]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
%0 = math.atan2 %arg0, %arg1 fastmath<fast> : f32
func.return
}
// -----
// CHECK-LABEL: func @hyperbolics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @hyperbolics(%arg0: f32) {