Do not trigger UB during AffineExpr parsing. (#96896)
Currently, parsing expressions that are undefined will trigger UB during compilation (e.g. `9223372036854775807 * 2`). This change instead leaves the expressions as they were. This change is an NFC for compilations that did not previously involve UB.
This commit is contained in:
committed by
GitHub
parent
fcffb2c024
commit
52216349b6
@@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
|
||||
}
|
||||
|
||||
/// Returns the integer ceil(Numerator / Denominator). Signed version.
|
||||
/// Guaranteed to never overflow.
|
||||
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
|
||||
/// is -1.
|
||||
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
|
||||
assert(Denominator && "Division by zero");
|
||||
if (!Numerator)
|
||||
@@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
|
||||
}
|
||||
|
||||
/// Returns the integer floor(Numerator / Denominator). Signed version.
|
||||
/// Guaranteed to never overflow.
|
||||
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
|
||||
/// is -1.
|
||||
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
|
||||
assert(Denominator && "Division by zero");
|
||||
if (!Numerator)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "AffineExprDetail.h"
|
||||
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
|
||||
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
|
||||
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
// Fold if both LHS, RHS are a constant.
|
||||
if (lhsConst && rhsConst)
|
||||
return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
|
||||
lhs.getContext());
|
||||
// Fold if both LHS, RHS are a constant and the sum does not overflow.
|
||||
if (lhsConst && rhsConst) {
|
||||
int64_t sum;
|
||||
if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
|
||||
return nullptr;
|
||||
}
|
||||
return getAffineConstantExpr(sum, lhs.getContext());
|
||||
}
|
||||
|
||||
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
|
||||
// If only one of them is a symbolic expressions, make it the RHS.
|
||||
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
|
||||
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
|
||||
if (lhsConst && rhsConst)
|
||||
return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
|
||||
lhs.getContext());
|
||||
if (lhsConst && rhsConst) {
|
||||
int64_t product;
|
||||
if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
|
||||
return nullptr;
|
||||
}
|
||||
return getAffineConstantExpr(product, lhs.getContext());
|
||||
}
|
||||
|
||||
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
|
||||
return nullptr;
|
||||
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (!rhsConst || rhsConst.getValue() < 1)
|
||||
return nullptr;
|
||||
|
||||
if (lhsConst)
|
||||
if (lhsConst) {
|
||||
// divideFloorSigned can only overflow in this case:
|
||||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
|
||||
rhsConst.getValue() == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
return getAffineConstantExpr(
|
||||
divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
|
||||
lhs.getContext());
|
||||
}
|
||||
|
||||
// Fold floordiv of a multiply with a constant that is a multiple of the
|
||||
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
|
||||
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (!rhsConst || rhsConst.getValue() < 1)
|
||||
return nullptr;
|
||||
|
||||
if (lhsConst)
|
||||
if (lhsConst) {
|
||||
// divideCeilSigned can only overflow in this case:
|
||||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
|
||||
rhsConst.getValue() == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
return getAffineConstantExpr(
|
||||
divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
|
||||
lhs.getContext());
|
||||
}
|
||||
|
||||
// Fold ceildiv of a multiply with a constant that is a multiple of the
|
||||
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
|
||||
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (!rhsConst || rhsConst.getValue() < 1)
|
||||
return nullptr;
|
||||
|
||||
if (lhsConst)
|
||||
if (lhsConst) {
|
||||
// mod never overflows.
|
||||
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
|
||||
lhs.getContext());
|
||||
}
|
||||
|
||||
// Fold modulo of an expression that is known to be a multiple of a constant
|
||||
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "gtest/gtest.h"
|
||||
@@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
|
||||
ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
|
||||
ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
|
||||
}
|
||||
|
||||
TEST(AffineExprTest, constantFolding) {
|
||||
MLIRContext ctx;
|
||||
OpBuilder b(&ctx);
|
||||
auto cn1 = b.getAffineConstantExpr(-1);
|
||||
auto c0 = b.getAffineConstantExpr(0);
|
||||
auto c1 = b.getAffineConstantExpr(1);
|
||||
auto c2 = b.getAffineConstantExpr(2);
|
||||
auto c3 = b.getAffineConstantExpr(3);
|
||||
auto c6 = b.getAffineConstantExpr(6);
|
||||
auto cmax = b.getAffineConstantExpr(std::numeric_limits<int64_t>::max());
|
||||
auto cmin = b.getAffineConstantExpr(std::numeric_limits<int64_t>::min());
|
||||
|
||||
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
|
||||
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
|
||||
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
|
||||
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
|
||||
|
||||
// Test division by zero:
|
||||
auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0);
|
||||
ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
|
||||
|
||||
auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0);
|
||||
ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
|
||||
|
||||
auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0);
|
||||
ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
|
||||
|
||||
// Test overflow:
|
||||
auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1);
|
||||
ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
|
||||
|
||||
auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2);
|
||||
ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
|
||||
|
||||
auto cminceildivcn1 =
|
||||
getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1);
|
||||
ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
|
||||
|
||||
auto cminfloordivcn1 =
|
||||
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
|
||||
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user