diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index cc81f9d19aca..a9f0d786dfdf 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -781,6 +781,15 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) { return lhs % rlrhs; } + + // Try simplify lhs's last operand with rhs. e.g: + // (s0 * 64 + s1) + (s1 // c * -c) ---> + // s0 * 64 + (s1 + s1 // c * -c) --> + // s0 * 64 + s1 % c + if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) { + if (auto simplified = simplifyAdd(lBinOpExpr.getRHS(), rhs)) + return lBinOpExpr.getLHS() + simplified; + } return nullptr; } diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir index 6acdc436fe67..e5db5cd9181d 100644 --- a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir +++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir @@ -27,8 +27,7 @@ func.func @simple_test_1(%0: index, %1: index, %2: index, %lb: index, %ub: index // CHECK-DAG: #[[$c42:.*]] = affine_map<() -> (42)> // CHECK-DAG: #[[$id:.*]] = affine_map<()[s0] -> (s0)> // CHECK-DAG: #[[$add:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK-DAG: #[[$div32div4timesm4:.*]] = affine_map<()[s0] -> (((s0 floordiv 32) floordiv 4) * -4)> -// CHECK-DAG: #[[$div32:.*]] = affine_map<()[s0] -> (s0 floordiv 32)> +// CHECK-DAG: #[[$div32mod4:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 4)> // CHECK-LABEL: func.func @simple_test_2 // CHECK-SAME: %[[I0:[0-9a-zA-Z]+]]: index, @@ -45,10 +44,8 @@ func.func @simple_test_2(%0: index, %1: index, %2: index, %lb: index, %ub: index // CHECK: %[[R2:.*]] = affine.apply #[[$add]]()[%[[c42]], %[[R1]]] // CHECK: scf.for %[[j:.*]] = scf.for %j = %lb to %ub step %step { - // CHECK: %[[R3:.*]] = affine.apply #[[$div32div4timesm4]]()[%[[j]]] - // CHECK: %[[R4:.*]] = affine.apply #[[$add]]()[%[[R2]], %[[R3]]] - // CHECK: %[[R5:.*]] = affine.apply #[[$div32]]()[%[[j]]] - // CHECK: %[[a:.*]] = affine.apply #[[$add]]()[%[[R4]], %[[R5]]] + // CHECK: %[[R3:.*]] = affine.apply #[[$div32mod4]]()[%[[j]]] + // CHECK: %[[a:.*]] = affine.apply #[[$add]]()[%[[R2]], %[[R3]]] %a = affine.apply affine_map<(d0)[s0] -> ((d0 floordiv 32) mod 4 + s0 + 42)>(%j)[%i] // CHECK: "some_side_effecting_consumer"(%[[a]]) : (index) -> () @@ -67,8 +64,7 @@ func.func @simple_test_2(%0: index, %1: index, %2: index, %lb: index, %ub: index // CHECK-DAG: #[[$div4timesm32:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * -32)> // CHECK-DAG: #[[$times8:.*]] = affine_map<()[s0] -> (s0 * 8)> // CHECK-DAG: #[[$id:.*]] = affine_map<()[s0] -> (s0)> -// CHECK-DAG: #[[$div32div4timesm4:.*]] = affine_map<()[s0] -> (((s0 floordiv 32) floordiv 4) * -4)> -// CHECK-DAG: #[[$div32:.*]] = affine_map<()[s0] -> (s0 floordiv 32)> +// CHECK-DAG: #[[$div32mod4:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 4)> // CHECK-LABEL: func.func @larger_test // CHECK-SAME: %[[I0:[0-9a-zA-Z]+]]: index, @@ -126,10 +122,8 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index, // CHECK-NEXT: %[[e:.*]] = affine.apply #[[$add]]()[%[[c]], %[[idk]]] %e = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 8 - (s1 floordiv 4) * 32)>()[%k, %0] - // CHECK-NEXT: %[[R15:.*]] = affine.apply #[[$div32div4timesm4]]()[%[[k]]] - // CHECK-NEXT: %[[R16:.*]] = affine.apply #[[$add]]()[%[[idj]], %[[R15]]] - // CHECK-NEXT: %[[R17:.*]] = affine.apply #[[$div32]]()[%[[k]]] - // CHECK-NEXT: %[[f:.*]] = affine.apply #[[$add]]()[%[[R16]], %[[R17]]] + // CHECK-NEXT: %[[R15:.*]] = affine.apply #[[$div32mod4]]()[%[[k]]] + // CHECK-NEXT: %[[f:.*]] = affine.apply #[[$add]]()[%[[idj]], %[[R15]]] %f = affine.apply affine_map<(d0)[s0] -> ((d0 floordiv 32) mod 4 + s0)>(%k)[%j] // CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]