diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index da5be383df15..cbad5dd35768 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -64,20 +64,6 @@ static cl::opt DumpReproducers( static int64_t MaxConstraintValue = std::numeric_limits::max(); static int64_t MinSignedConstraintValue = std::numeric_limits::min(); -// A helper to multiply 2 signed integers where overflowing is allowed. -static int64_t multiplyWithOverflow(int64_t A, int64_t B) { - int64_t Result; - MulOverflow(A, B, Result); - return Result; -} - -// A helper to add 2 signed integers where overflowing is allowed. -static int64_t addWithOverflow(int64_t A, int64_t B) { - int64_t Result; - AddOverflow(A, B, Result); - return Result; -} - static Instruction *getContextInstForUse(Use &U) { Instruction *UserI = cast(U.getUser()); if (auto *Phi = dyn_cast(UserI)) @@ -366,26 +352,42 @@ struct Decomposition { Decomposition(int64_t Offset, ArrayRef Vars) : Offset(Offset), Vars(Vars) {} - void add(int64_t OtherOffset) { - Offset = addWithOverflow(Offset, OtherOffset); + /// Add \p OtherOffset and return true if the operation overflows, i.e. the + /// new decomposition is invalid. + [[nodiscard]] bool add(int64_t OtherOffset) { + return AddOverflow(Offset, OtherOffset, Offset); } - void add(const Decomposition &Other) { - add(Other.Offset); + /// Add \p Other and return true if the operation overflows, i.e. the new + /// decomposition is invalid. + [[nodiscard]] bool add(const Decomposition &Other) { + if (add(Other.Offset)) + return true; append_range(Vars, Other.Vars); + return false; } - void sub(const Decomposition &Other) { + /// Subtract \p Other and return true if the operation overflows, i.e. the new + /// decomposition is invalid. + [[nodiscard]] bool sub(const Decomposition &Other) { Decomposition Tmp = Other; - Tmp.mul(-1); - add(Tmp.Offset); + if (Tmp.mul(-1)) + return true; + if (add(Tmp.Offset)) + return true; append_range(Vars, Tmp.Vars); + return false; } - void mul(int64_t Factor) { - Offset = multiplyWithOverflow(Offset, Factor); + /// Multiply all coefficients by \p Factor and return true if the operation + /// overflows, i.e. the new decomposition is invalid. + [[nodiscard]] bool mul(int64_t Factor) { + if (MulOverflow(Offset, Factor, Offset)) + return true; for (auto &Var : Vars) - Var.Coefficient = multiplyWithOverflow(Var.Coefficient, Factor); + if (MulOverflow(Var.Coefficient, Factor, Var.Coefficient)) + return true; + return false; } }; @@ -467,8 +469,10 @@ static Decomposition decomposeGEP(GEPOperator &GEP, Decomposition Result(ConstantOffset.getSExtValue(), DecompEntry(1, BasePtr)); for (auto [Index, Scale] : VariableOffsets) { auto IdxResult = decompose(Index, Preconditions, IsSigned, DL); - IdxResult.mul(Scale.getSExtValue()); - Result.add(IdxResult); + if (IdxResult.mul(Scale.getSExtValue())) + return &GEP; + if (Result.add(IdxResult)) + return &GEP; if (!NW.hasNoUnsignedWrap()) { // Try to prove nuw from nusw and nneg. @@ -488,11 +492,13 @@ static Decomposition decompose(Value *V, SmallVectorImpl &Preconditions, bool IsSigned, const DataLayout &DL) { - auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B, - bool IsSignedB) { + auto MergeResults = [&Preconditions, IsSigned, + &DL](Value *A, Value *B, + bool IsSignedB) -> std::optional { auto ResA = decompose(A, Preconditions, IsSigned, DL); auto ResB = decompose(B, Preconditions, IsSignedB, DL); - ResA.add(ResB); + if (ResA.add(ResB)) + return std::nullopt; return ResA; }; @@ -533,21 +539,26 @@ static Decomposition decompose(Value *V, V = Op0; } - if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) - return MergeResults(Op0, Op1, IsSigned); + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { + if (auto Decomp = MergeResults(Op0, Op1, IsSigned)) + return *Decomp; + return {V, IsKnownNonNegative}; + } if (match(V, m_NSWSub(m_Value(Op0), m_Value(Op1)))) { auto ResA = decompose(Op0, Preconditions, IsSigned, DL); auto ResB = decompose(Op1, Preconditions, IsSigned, DL); - ResA.sub(ResB); - return ResA; + if (!ResA.sub(ResB)) + return ResA; + return {V, IsKnownNonNegative}; } ConstantInt *CI; if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) { auto Result = decompose(Op0, Preconditions, IsSigned, DL); - Result.mul(CI->getSExtValue()); - return Result; + if (!Result.mul(CI->getSExtValue())) + return Result; + return {V, IsKnownNonNegative}; } // (shl nsw x, shift) is (mul nsw x, (1<getIntegerBitWidth() - 1) { assert(Shift < 64 && "Would overflow"); auto Result = decompose(Op0, Preconditions, IsSigned, DL); - Result.mul(int64_t(1) << Shift); - return Result; + if (!Result.mul(int64_t(1) << Shift)) + return Result; + return {V, IsKnownNonNegative}; } } @@ -593,8 +605,11 @@ static Decomposition decompose(Value *V, Value *Op1; ConstantInt *CI; if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) { - return MergeResults(Op0, Op1, IsSigned); + if (auto Decomp = MergeResults(Op0, Op1, IsSigned)) + return *Decomp; + return {V, IsKnownNonNegative}; } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { if (!isKnownNonNegative(Op0, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, @@ -603,7 +618,9 @@ static Decomposition decompose(Value *V, Preconditions.emplace_back(CmpInst::ICMP_SGE, Op1, ConstantInt::get(Op1->getType(), 0)); - return MergeResults(Op0, Op1, IsSigned); + if (auto Decomp = MergeResults(Op0, Op1, IsSigned)) + return *Decomp; + return {V, IsKnownNonNegative}; } if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && @@ -611,33 +628,41 @@ static Decomposition decompose(Value *V, Preconditions.emplace_back( CmpInst::ICMP_UGE, Op0, ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); - return MergeResults(Op0, CI, true); + if (auto Decomp = MergeResults(Op0, CI, true)) + return *Decomp; + return {V, IsKnownNonNegative}; } // Decompose or as an add if there are no common bits between the operands. - if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) - return MergeResults(Op0, CI, IsSigned); + if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) { + if (auto Decomp = MergeResults(Op0, CI, IsSigned)) + return *Decomp; + return {V, IsKnownNonNegative}; + } if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) { if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64) return {V, IsKnownNonNegative}; auto Result = decompose(Op1, Preconditions, IsSigned, DL); - Result.mul(int64_t{1} << CI->getSExtValue()); - return Result; + if (!Result.mul(int64_t{1} << CI->getSExtValue())) + return Result; + return {V, IsKnownNonNegative}; } if (match(V, m_NUWMul(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI) && (!CI->isNegative())) { auto Result = decompose(Op1, Preconditions, IsSigned, DL); - Result.mul(CI->getSExtValue()); - return Result; + if (!Result.mul(CI->getSExtValue())) + return Result; + return {V, IsKnownNonNegative}; } if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1)))) { auto ResA = decompose(Op0, Preconditions, IsSigned, DL); auto ResB = decompose(Op1, Preconditions, IsSigned, DL); - ResA.sub(ResB); - return ResA; + if (!ResA.sub(ResB)) + return ResA; + return {V, IsKnownNonNegative}; } return {V, IsKnownNonNegative}; diff --git a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll index 57b7b11be0cf..f36ac311878b 100644 --- a/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll +++ b/llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll @@ -52,3 +52,25 @@ entry: %c = icmp slt i64 0, %sub ret i1 %c } + +define i1 @pr140481(i32 %x) { +; CHECK-LABEL: define i1 @pr140481( +; CHECK-SAME: i32 [[X:%.*]]) { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: call void @llvm.assume(i1 [[COND]]) +; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[X]], 5001000 +; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i32 [[ADD]], -5001000 +; CHECK-NEXT: [[MUL2:%.*]] = mul nsw i32 [[MUL1]], 5001000 +; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MUL2]], 0 +; CHECK-NEXT: ret i1 [[CMP2]] +; +entry: + %cond = icmp slt i32 %x, 0 + call void @llvm.assume(i1 %cond) + %add = add nsw i32 %x, 5001000 + %mul1 = mul nsw i32 %add, -5001000 + %mul2 = mul nsw i32 %mul1, 5001000 + %cmp2 = icmp sgt i32 %mul2, 0 + ret i1 %cmp2 +}