[mlir][affine] Fix a crash when cast incompatible type (#145162)
This PR fixes a crash in `getSemiAffineExprFromFlatForm` when localExpr is not `AffineBinaryOpExpr`. Fixes #144091.
This commit is contained in:
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
|
||||
// the indices in `coefficients` map, and affine expression corresponding to
|
||||
// in indices in `indexToExprMap` map.
|
||||
for (const auto &it : llvm::enumerate(localExprs)) {
|
||||
AffineExpr expr = it.value();
|
||||
if (flatExprs[numDims + numSymbols + it.index()] == 0)
|
||||
continue;
|
||||
AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
|
||||
AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
|
||||
AffineExpr expr = it.value();
|
||||
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
|
||||
if (!binaryExpr)
|
||||
continue;
|
||||
|
||||
AffineExpr lhs = binaryExpr.getLHS();
|
||||
AffineExpr rhs = binaryExpr.getRHS();
|
||||
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
|
||||
(isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
|
||||
isa<AffineConstantExpr>(rhs)))) {
|
||||
|
||||
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
|
||||
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
|
||||
return %a : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
|
||||
func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c13 = arith.constant 13 : index
|
||||
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
|
||||
%a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
|
||||
%b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
|
||||
// CHECK: %[[C6:.*]] = arith.constant 6 : index
|
||||
// CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
|
||||
// CHECK-NEXT: return %[[C6]], %[[C7]]
|
||||
return %a, %b : index, index
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user