[mlir] Fix ComposeExpandOfCollapseOp for dynamic case (#142663)
Changes `findCollapsingReassociation` to return nullopt in all cases where source shape has `>=2` dynamic dims. `expand(collapse)` can reshape to in any valid output shape but a collapse can only collapse contiguous dimensions. When there are `>=2` dynamic dimensions it is impossible to determine if it can be simplified to a collapse or if it is preforming a more advanced reassociation. This problem was uncovered by https://github.com/llvm/llvm-project/pull/137963 --------- Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
This commit is contained in:
@@ -387,11 +387,14 @@ private:
|
||||
auto resultSubShape =
|
||||
resultShape.slice(resultIndices.front(), resultIndices.size());
|
||||
|
||||
if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
|
||||
llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
|
||||
return std::nullopt;
|
||||
|
||||
if (srcSubShape.size() == resultSubShape.size()) {
|
||||
if (srcSubShape != resultSubShape ||
|
||||
llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
|
||||
if (srcSubShape != resultSubShape)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
|
||||
composedReassociation.emplace_back(1, srcIndices.front() + index);
|
||||
}
|
||||
|
||||
@@ -1272,6 +1272,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor<?x8x128x?xf16>, %arg1: index) -> tensor<?x128x?xf16> {
|
||||
%collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x8x128x?xf16> into tensor<?xf16>
|
||||
%expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor<?xf16> into tensor<?x128x?xf16>
|
||||
return %expanded_19 : tensor<?x128x?xf16>
|
||||
}
|
||||
// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor
|
||||
// CHECK-SAME: %[[ARG1:.+]]: index
|
||||
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
|
||||
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
|
||||
// CHECK: return %[[EXPAND]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @zero_rank_reshape_multi
|
||||
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: return %arg0
|
||||
|
||||
Reference in New Issue
Block a user