From 6f2ba4712f17d7c82228a5b705570571e13a3832 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Wed, 11 Jun 2025 14:34:02 -0700 Subject: [PATCH] [mlir] Fix ComposeExpandOfCollapseOp for dynamic case (#142663) Changes `findCollapsingReassociation` to return nullopt in all cases where source shape has `>=2` dynamic dims. `expand(collapse)` can reshape into any valid output shape, but a collapse can only collapse contiguous dimensions. When there are `>=2` dynamic dimensions it is impossible to determine if it can be simplified to a collapse or if it is performing a more advanced reassociation. This problem was uncovered by https://github.com/llvm/llvm-project/pull/137963 --------- Signed-off-by: Ian Wood --- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 9 ++++++--- mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index af575e10acc8..61c2a50e514c 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -387,11 +387,14 @@ private: auto resultSubShape = resultShape.slice(resultIndices.front(), resultIndices.size()); + if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 && + llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2) + return std::nullopt; + if (srcSubShape.size() == resultSubShape.size()) { - if (srcSubShape != resultSubShape || - llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) { + if (srcSubShape != resultSubShape) return std::nullopt; - } + for (auto index : llvm::seq(0, srcSubShape.size())) { composedReassociation.emplace_back(1, srcIndices.front() + index); } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 65c5b3e8602e..67b03b0a3485 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ 
b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1272,6 +1272,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, % // ----- +func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor, %arg1: index) -> tensor { + %collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor into tensor + %expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor into tensor + return %expanded_19 : tensor +} +// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: index +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] +// CHECK: return %[[EXPAND]] + +// ----- + // CHECK-LABEL: func @zero_rank_reshape_multi func.func @zero_rank_reshape_multi(%arg0: tensor) -> tensor { // CHECK: return %arg0