[mlir] Fold ceil/floordiv with negative RHS. (#97031)
Currently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the `mod` simplifier.
This commit is contained in:
committed by
GitHub
parent
eed7c5e29c
commit
22dfa1aa2c
@@ -855,8 +855,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
|
||||
// mlir floordiv by zero or negative numbers is undefined and preserved as is.
|
||||
if (!rhsConst || rhsConst.getValue() < 1)
|
||||
if (!rhsConst || rhsConst.getValue() == 0)
|
||||
return nullptr;
|
||||
|
||||
if (lhsConst) {
|
||||
@@ -875,12 +874,12 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (rhsConst == 1)
|
||||
return lhs;
|
||||
|
||||
// Simplify (expr * const) floordiv divConst when expr is known to be a
|
||||
// multiple of divConst.
|
||||
// Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
|
||||
// multiple of `rhsConst`.
|
||||
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
||||
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
|
||||
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
|
||||
// rhsConst is known to be a positive constant.
|
||||
// `rhsConst` is known to be a nonzero constant.
|
||||
if (lrhs.getValue() % rhsConst.getValue() == 0)
|
||||
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
|
||||
}
|
||||
@@ -891,7 +890,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (lBin && lBin.getKind() == AffineExprKind::Add) {
|
||||
int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
|
||||
int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
|
||||
// rhsConst is known to be a positive constant.
|
||||
// rhsConst is known to be a nonzero constant.
|
||||
if (llhsDiv % rhsConst.getValue() == 0 ||
|
||||
lrhsDiv % rhsConst.getValue() == 0)
|
||||
return lBin.getLHS().floorDiv(rhsConst.getValue()) +
|
||||
@@ -918,7 +917,7 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
||||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
||||
|
||||
if (!rhsConst || rhsConst.getValue() < 1)
|
||||
if (!rhsConst || rhsConst.getValue() == 0)
|
||||
return nullptr;
|
||||
|
||||
if (lhsConst) {
|
||||
@@ -937,12 +936,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
|
||||
if (rhsConst.getValue() == 1)
|
||||
return lhs;
|
||||
|
||||
// Simplify (expr * const) ceildiv divConst when const is known to be a
|
||||
// multiple of divConst.
|
||||
// Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
|
||||
// multiple of `rhsConst`.
|
||||
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
|
||||
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
|
||||
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
|
||||
// rhsConst is known to be a positive constant.
|
||||
// `rhsConst` is known to be a nonzero constant.
|
||||
if (lrhs.getValue() % rhsConst.getValue() == 0)
|
||||
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
|
||||
}
|
||||
|
||||
@@ -76,3 +76,25 @@ TEST(AffineExprTest, constantFolding) {
|
||||
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
|
||||
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
|
||||
}
|
||||
|
||||
TEST(AffineExprTest, divisionSimplification) {
|
||||
MLIRContext ctx;
|
||||
OpBuilder b(&ctx);
|
||||
auto cn6 = b.getAffineConstantExpr(-6);
|
||||
auto c6 = b.getAffineConstantExpr(6);
|
||||
auto d0 = b.getAffineDimExpr(0);
|
||||
auto d1 = b.getAffineDimExpr(1);
|
||||
|
||||
ASSERT_EQ(c6.floorDiv(-1), cn6);
|
||||
ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3);
|
||||
ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
|
||||
ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3);
|
||||
ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2));
|
||||
ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2));
|
||||
ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
|
||||
|
||||
ASSERT_EQ(c6.ceilDiv(-1), cn6);
|
||||
ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3);
|
||||
ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv);
|
||||
ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user