diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c2a50e514c..704e39e90884 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,43 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else if (srcRank < resultRank) {
+      // Compute the dynamic output shape for the new expand_shape op.
+      Location loc = collapseOp.getLoc();
+      SmallVector<OpFoldResult> origOutputShape =
+          expandOp.getMixedOutputShape();
+      SmallVector<OpFoldResult> newOutputShape;
+      for (const ReassociationIndices &indices :
+           collapseOp.getReassociationIndices()) {
+        int64_t numStaticElems = 1;
+        SmallVector<Value> dynamicSizes;
+        for (int64_t idx : indices) {
+          OpFoldResult size = origOutputShape[idx];
+          if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
+            numStaticElems *= maybeCst.value();
+            continue;
+          }
+          dynamicSizes.push_back(cast<Value>(size));
+        }
+        if (dynamicSizes.empty()) {
+          newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+          continue;
+        }
+
+        // There is at least one dynamic size, so we can initialize `result` to
+        // the first dynamic size.
+        Value result = dynamicSizes[0];
+        for (Value v : llvm::drop_begin(dynamicSizes))
+          result = rewriter.create<arith::MulIOp>(loc, result, v);
+        if (numStaticElems != 1) {
+          result = rewriter.create<arith::MulIOp>(
+              loc, result,
+              rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+        }
+        newOutputShape.push_back(result);
+      }
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+          collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+          newOutputShape);
     } else {
       // Collapses/expansions that do not change the rank are not allowed. Use
       // a cast instead.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7a267ae8a2c9..a91e54a12610 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+  %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+  %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+  return %collapsed : memref<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
     %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
     -> memref<?xf32, strided<[?], offset: 0>> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bf..3f9236095138 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+  %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+  return %collapsed : tensor<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
     -> tensor<1x1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []