[flang] Fold BTEST

Implements constant folding for BTEST intrinsic function.

Differential Revision: https://reviews.llvm.org/D111849
This commit is contained in:
peter klausler
2021-10-12 18:20:18 -07:00
parent 1053e0b27c
commit 2f80b73e0c
4 changed files with 54 additions and 2 deletions

View File

@@ -50,6 +50,7 @@ template <TypeCategory CATEGORY, int KIND = 0> class Type;
using SubscriptInteger = Type<TypeCategory::Integer, 8>;
using CInteger = Type<TypeCategory::Integer, 4>;
using LargestInt = Type<TypeCategory::Integer, 16>;
using LogicalResult = Type<TypeCategory::Logical, 4>;
using LargestReal = Type<TypeCategory::Real, 16>;
using Ascii = Type<TypeCategory::Character, 1>;

View File

@@ -40,6 +40,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
CHECK(intrinsic);
std::string name{intrinsic->name};
using SameInt = Type<TypeCategory::Integer, KIND>;
if (name == "all") {
return FoldAllAny(
context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
@@ -59,7 +60,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
}
return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
} else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
using LargestInt = Type<TypeCategory::Integer, 16>;
static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
// Arguments do not have to be of the same integer type. Convert all
// arguments to the biggest integer type before comparing them to
@@ -89,6 +89,26 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
[&fptr](const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
return Scalar<T>{std::invoke(fptr, i, j)};
}));
} else if (name == "btest") {
if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
return std::visit(
[&](const auto &x) {
using IT = ResultType<decltype(x)>;
return FoldElementalIntrinsic<T, IT, SameInt>(context,
std::move(funcRef),
ScalarFunc<T, IT, SameInt>(
[&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
auto posVal{pos.ToInt64()};
if (posVal < 0 || posVal >= x.bits) {
context.messages().Say(
"POS=%jd out of range for BTEST"_err_en_US,
static_cast<std::intmax_t>(posVal));
}
return Scalar<T>{x.BTEST(posVal)};
}));
},
ix->u);
}
} else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
// A warning about an invalid argument is discarded from converting
// the argument of isnan() / IEEE_IS_NAN().
@@ -139,7 +159,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
name == "__builtin_ieee_support_underflow_control") {
return Expr<T>{true};
}
// TODO: btest, dot_product, is_iostat_end,
// TODO: dot_product, is_iostat_end,
// is_iostat_eor, logical, matmul, out_of_range,
// parity, transfer
return Expr<T>{std::move(funcRef)};

View File

@@ -0,0 +1,21 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of BTEST
module m1
integer, parameter :: ia1(*) = [(j, j=0, 15)]
logical, parameter :: test_ia1a = all(btest(ia1, 0) .eqv. [(.false., .true., j=1, 8)])
logical, parameter :: test_ia1b = all(btest(ia1, 1) .eqv. [(.false., .false., .true., .true., j=1, 4)])
logical, parameter :: test_ia1c = all(btest(ia1, 2) .eqv. [(modulo(j/4, 2) == 1, j=0, 15)])
logical, parameter :: test_ia1d = all(btest(ia1, 3) .eqv. [(j > 8, j=1, 16)])
logical, parameter :: test_shft1 = all([(btest(ishft(1_1, j), j), j=0, 7)])
logical, parameter :: test_shft2 = all([(btest(ishft(1_2, j), j), j=0, 15)])
logical, parameter :: test_shft4 = all([(btest(ishft(1_4, j), j), j=0, 31)])
logical, parameter :: test_shft8 = all([(btest(ishft(1_8, j), j), j=0, 63)])
logical, parameter :: test_shft16 = all([(btest(ishft(1_16, j), j), j=0, 127)])
logical, parameter :: test_set1 = all([(btest(ibset(0_1, j), j), j=0, 7)])
logical, parameter :: test_set2 = all([(btest(ibset(0_2, j), j), j=0, 15)])
logical, parameter :: test_set4 = all([(btest(ibset(0_4, j), j), j=0, 31)])
logical, parameter :: test_set8 = all([(btest(ibset(0_8, j), j), j=0, 63)])
logical, parameter :: test_set16 = all([(btest(ibset(0_16, j), j), j=0, 127)])
logical, parameter :: test_z = .not. any([(btest(0_4, j), j=0, 31)])
logical, parameter :: test_shft1e = all(btest([(ishft(1_1, j), j=0, 7)], [(j, j=0, 7)]))
end module

View File

@@ -50,4 +50,14 @@ module m
!CHECK: error: Invalid 'vector=' argument in UNPACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
x = unpack([1,2], mask, 0)
end subroutine
subroutine s6
!CHECK: error: POS=-1 out of range for BTEST
logical, parameter :: bad1 = btest(0, -1)
!CHECK: error: POS=32 out of range for BTEST
logical, parameter :: bad2 = btest(0, 32)
!CHECK-NOT: error: POS=33 out of range for BTEST
logical, parameter :: bad3 = btest(0_8, 33)
!CHECK: error: POS=64 out of range for BTEST
logical, parameter :: bad4 = btest(0_8, 64)
end subroutine
end module