[KnownBits] Make {s,u}{add,sub}_sat optimal (#113096)
Changes are:
1) Make signed-overflow detection optimal
2) For signed-overflow, try to rule out direction even if we can't
totally rule out overflow.
3) Intersect add/sub assuming no overflow with possible overflow
clamping values as opposed to add/sub without the assumption.
This commit is contained in:
@@ -610,28 +610,82 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
|
||||
const KnownBits &RHS) {
|
||||
// We don't see NSW even for sadd/ssub as we want to check if the result has
|
||||
// signed overflow.
|
||||
KnownBits Res =
|
||||
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
|
||||
unsigned BitWidth = Res.getBitWidth();
|
||||
auto SignBitKnown = [&](const KnownBits &K) {
|
||||
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
|
||||
};
|
||||
std::optional<bool> Overflow;
|
||||
unsigned BitWidth = LHS.getBitWidth();
|
||||
|
||||
std::optional<bool> Overflow;
|
||||
// Even if we can't entirely rule out overflow, we may be able to rule out
|
||||
// overflow in one direction. This allows us to potentially keep some of the
|
||||
// add/sub bits. I.e if we can't overflow in the positive direction we won't
|
||||
// clamp to INT_MAX so we can keep low 0s from the add/sub result.
|
||||
bool MayNegClamp = true;
|
||||
bool MayPosClamp = true;
|
||||
if (Signed) {
|
||||
// If we can actually detect overflow do so. Otherwise leave Overflow as
|
||||
// nullopt (we assume it may have happened).
|
||||
if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
|
||||
// Easy cases we can rule out any overflow.
|
||||
if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
|
||||
(LHS.isNonNegative() && RHS.isNegative())))
|
||||
Overflow = false;
|
||||
else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
|
||||
(LHS.isNonNegative() && RHS.isNonNegative()))))
|
||||
Overflow = false;
|
||||
else {
|
||||
// Check if we may overflow. If we can't rule out overflow then check if
|
||||
// we can rule out a direction at least.
|
||||
KnownBits UnsignedLHS = LHS;
|
||||
KnownBits UnsignedRHS = RHS;
|
||||
// Get version of LHS/RHS with clearer signbit. This allows us to detect
|
||||
// how the addition/subtraction might overflow into the signbit. Then
|
||||
// using the actual known signbits of LHS/RHS, we can figure out which
|
||||
// overflows are/aren't possible.
|
||||
UnsignedLHS.One.clearSignBit();
|
||||
UnsignedLHS.Zero.setSignBit();
|
||||
UnsignedRHS.One.clearSignBit();
|
||||
UnsignedRHS.Zero.setSignBit();
|
||||
KnownBits Res =
|
||||
KnownBits::computeForAddSub(Add, /*NSW=*/false,
|
||||
/*NUW=*/false, UnsignedLHS, UnsignedRHS);
|
||||
if (Add) {
|
||||
// sadd.sat
|
||||
Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
|
||||
Res.isNonNegative() != LHS.isNonNegative());
|
||||
if (Res.isNegative()) {
|
||||
// Only overflow scenario is Pos + Pos.
|
||||
MayNegClamp = false;
|
||||
// Pos + Pos will overflow with extra signbit.
|
||||
if (LHS.isNonNegative() && RHS.isNonNegative())
|
||||
Overflow = true;
|
||||
} else if (Res.isNonNegative()) {
|
||||
// Only overflow scenario is Neg + Neg
|
||||
MayPosClamp = false;
|
||||
// Neg + Neg will overflow without extra signbit.
|
||||
if (LHS.isNegative() && RHS.isNegative())
|
||||
Overflow = true;
|
||||
}
|
||||
// We will never clamp to the opposite sign of N-bit result.
|
||||
if (LHS.isNegative() || RHS.isNegative())
|
||||
MayPosClamp = false;
|
||||
if (LHS.isNonNegative() || RHS.isNonNegative())
|
||||
MayNegClamp = false;
|
||||
} else {
|
||||
// ssub.sat
|
||||
Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
|
||||
Res.isNonNegative() != LHS.isNonNegative());
|
||||
if (Res.isNegative()) {
|
||||
// Only overflow scenario is Neg - Pos.
|
||||
MayPosClamp = false;
|
||||
// Neg - Pos will overflow with extra signbit.
|
||||
if (LHS.isNegative() && RHS.isNonNegative())
|
||||
Overflow = true;
|
||||
} else if (Res.isNonNegative()) {
|
||||
// Only overflow scenario is Pos - Neg.
|
||||
MayNegClamp = false;
|
||||
// Pos - Neg will overflow without extra signbit.
|
||||
if (LHS.isNonNegative() && RHS.isNegative())
|
||||
Overflow = true;
|
||||
}
|
||||
// We will never clamp to the opposite sign of N-bit result.
|
||||
if (LHS.isNegative() || RHS.isNonNegative())
|
||||
MayPosClamp = false;
|
||||
if (LHS.isNonNegative() || RHS.isNegative())
|
||||
MayNegClamp = false;
|
||||
}
|
||||
}
|
||||
// If we have ruled out all clamping, we will never overflow.
|
||||
if (!MayNegClamp && !MayPosClamp)
|
||||
Overflow = false;
|
||||
} else if (Add) {
|
||||
// uadd.sat
|
||||
bool Of;
|
||||
@@ -656,52 +710,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
|
||||
}
|
||||
}
|
||||
|
||||
if (Signed) {
|
||||
if (Add) {
|
||||
if (LHS.isNonNegative() && RHS.isNonNegative()) {
|
||||
// Pos + Pos -> Pos
|
||||
Res.One.clearSignBit();
|
||||
Res.Zero.setSignBit();
|
||||
}
|
||||
if (LHS.isNegative() && RHS.isNegative()) {
|
||||
// Neg + Neg -> Neg
|
||||
Res.One.setSignBit();
|
||||
Res.Zero.clearSignBit();
|
||||
}
|
||||
} else {
|
||||
if (LHS.isNegative() && RHS.isNonNegative()) {
|
||||
// Neg - Pos -> Neg
|
||||
Res.One.setSignBit();
|
||||
Res.Zero.clearSignBit();
|
||||
} else if (LHS.isNonNegative() && RHS.isNegative()) {
|
||||
// Pos - Neg -> Pos
|
||||
Res.One.clearSignBit();
|
||||
Res.Zero.setSignBit();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add: Leading ones of either operand are preserved.
|
||||
// Sub: Leading zeros of LHS and leading ones of RHS are preserved
|
||||
// as leading zeros in the result.
|
||||
unsigned LeadingKnown;
|
||||
if (Add)
|
||||
LeadingKnown =
|
||||
std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
|
||||
else
|
||||
LeadingKnown =
|
||||
std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
|
||||
|
||||
// We select between the operation result and all-ones/zero
|
||||
// respectively, so we can preserve known ones/zeros.
|
||||
APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
|
||||
if (Add) {
|
||||
Res.One |= Mask;
|
||||
Res.Zero &= ~Mask;
|
||||
} else {
|
||||
Res.Zero |= Mask;
|
||||
Res.One &= ~Mask;
|
||||
}
|
||||
}
|
||||
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
|
||||
/*NUW=*/!Signed, LHS, RHS);
|
||||
|
||||
if (Overflow) {
|
||||
// We know whether or not we overflowed.
|
||||
@@ -714,7 +724,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
|
||||
APInt C;
|
||||
if (Signed) {
|
||||
// sadd.sat / ssub.sat
|
||||
assert(SignBitKnown(LHS) &&
|
||||
assert(!LHS.isSignUnknown() &&
|
||||
"We somehow know overflow without knowing input sign");
|
||||
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
|
||||
: APInt::getSignedMaxValue(BitWidth);
|
||||
@@ -735,8 +745,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
|
||||
if (Signed) {
|
||||
// sadd.sat/ssub.sat
|
||||
// We can keep our information about the sign bits.
|
||||
Res.Zero.clearLowBits(BitWidth - 1);
|
||||
Res.One.clearLowBits(BitWidth - 1);
|
||||
if (MayPosClamp)
|
||||
Res.Zero.clearLowBits(BitWidth - 1);
|
||||
if (MayNegClamp)
|
||||
Res.One.clearLowBits(BitWidth - 1);
|
||||
} else if (Add) {
|
||||
// uadd.sat
|
||||
// We need to clear all the known zeros as we can only use the leading ones.
|
||||
|
||||
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
|
||||
|
||||
define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
|
||||
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
|
||||
; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
|
||||
; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
|
||||
; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
|
||||
; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
|
||||
; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
|
||||
; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
|
||||
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
|
||||
; CHECK-NEXT: ret i1 [[R]]
|
||||
; CHECK-NEXT: ret i1 false
|
||||
;
|
||||
%xx = and i8 %x, 15
|
||||
%yy = and i8 %y, 15
|
||||
|
||||
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
|
||||
"sadd_sat", KnownBits::sadd_sat,
|
||||
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
|
||||
return N1.sadd_sat(N2);
|
||||
},
|
||||
/*CheckOptimality=*/false);
|
||||
});
|
||||
testBinaryOpExhaustive(
|
||||
"uadd_sat", KnownBits::uadd_sat,
|
||||
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
|
||||
return N1.uadd_sat(N2);
|
||||
},
|
||||
/*CheckOptimality=*/false);
|
||||
});
|
||||
testBinaryOpExhaustive(
|
||||
"ssub_sat", KnownBits::ssub_sat,
|
||||
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
|
||||
return N1.ssub_sat(N2);
|
||||
},
|
||||
/*CheckOptimality=*/false);
|
||||
});
|
||||
testBinaryOpExhaustive(
|
||||
"usub_sat", KnownBits::usub_sat,
|
||||
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
|
||||
return N1.usub_sat(N2);
|
||||
},
|
||||
/*CheckOptimality=*/false);
|
||||
});
|
||||
testBinaryOpExhaustive(
|
||||
"shl",
|
||||
[](const KnownBits &Known1, const KnownBits &Known2) {
|
||||
|
||||
Reference in New Issue
Block a user