Do not trigger UB during AffineExpr parsing. (#96896)

Currently, parsing expressions that are undefined will trigger UB during
compilation (e.g. `9223372036854775807 * 2`). This change instead
leaves the expressions as they were.

This change is an NFC for compilations that did not previously involve
UB.
This commit is contained in:
Johannes Reifferscheid
2024-06-28 07:31:33 +02:00
committed by GitHub
parent fcffb2c024
commit 52216349b6
3 changed files with 84 additions and 12 deletions

View File

@@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
}
/// Returns the integer ceil(Numerator / Denominator). Signed version.
/// Guaranteed to never overflow.
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
/// is -1.
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)
@@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
}
/// Returns the integer floor(Numerator / Denominator). Signed version.
/// Guaranteed to never overflow.
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
/// is -1.
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)

View File

@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <limits>
#include <utility>
#include "AffineExprDetail.h"
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
// Fold if both LHS, RHS are a constant.
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
lhs.getContext());
// Fold if both LHS, RHS are a constant and the sum does not overflow.
if (lhsConst && rhsConst) {
int64_t sum;
if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
return nullptr;
}
return getAffineConstantExpr(sum, lhs.getContext());
}
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expressions, make it the RHS.
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
lhs.getContext());
if (lhsConst && rhsConst) {
int64_t product;
if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
return nullptr;
}
return getAffineConstantExpr(product, lhs.getContext());
}
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
return nullptr;
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
if (lhsConst)
if (lhsConst) {
// divideFloorSigned can only overflow in this case:
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
rhsConst.getValue() == -1) {
return nullptr;
}
return getAffineConstantExpr(
divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
}
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
if (lhsConst)
if (lhsConst) {
// divideCeilSigned can only overflow in this case:
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
rhsConst.getValue() == -1) {
return nullptr;
}
return getAffineConstantExpr(
divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
}
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
if (lhsConst)
if (lhsConst) {
// mod never overflows.
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
}
// Fold modulo of an expression that is known to be a multiple of a constant
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)

View File

@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <limits>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "gtest/gtest.h"
@@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
}
TEST(AffineExprTest, constantFolding) {
MLIRContext ctx;
OpBuilder b(&ctx);
auto cn1 = b.getAffineConstantExpr(-1);
auto c0 = b.getAffineConstantExpr(0);
auto c1 = b.getAffineConstantExpr(1);
auto c2 = b.getAffineConstantExpr(2);
auto c3 = b.getAffineConstantExpr(3);
auto c6 = b.getAffineConstantExpr(6);
auto cmax = b.getAffineConstantExpr(std::numeric_limits<int64_t>::max());
auto cmin = b.getAffineConstantExpr(std::numeric_limits<int64_t>::min());
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
// Test division by zero:
auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0);
ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0);
ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0);
ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
// Test overflow:
auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1);
ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2);
ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
auto cminceildivcn1 =
getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1);
ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
auto cminfloordivcn1 =
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
}