Files
clang-p2996/mlir/test/Dialect/Mesh/simplifications.mlir
Frank Schlimbach baabcb2898 [mlir][mesh] Shardingcontrol (#102598)
This is a fixed copy of #98145 (necessary after it got reverted).

@sogartar @yaochengji
This PR adds the following to #98145:
- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and does not
return a result, to clarify its in-place semantics
- `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per
tensor/memref-axis (similar to `mesh.sharding`)
- The implementation of `ShardingInterface` for tensor operations
(`tensor.empty` for now) moved from the tensor library to the mesh
interface library. `spmdize` uses features from `mesh` dialect.
@rengolin agreed that `tensor` should not depend on `mesh`, so this
functionality cannot live in a `tensor` library. The unfulfilled dependency
caused the issues leading to reverting #98145. Such cases are generally
possible and might lead to re-considering the current structure (like
for tosa ops).
- rebased onto latest main
--------------------------
Replacing `#mesh.sharding` attribute with operation `mesh.sharding`
- extended semantics now allow providing optional `halo_sizes` and
`sharded_dims_sizes`
- internally a sharding is represented as a non-IR class
`mesh::MeshSharding`

What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- `mesh.shard` op accepts additional optional attribute `force`, useful
for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting on `sharded_dims_sizes` and
`halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for
shardings with `halo_sizes`

---------

Co-authored-by: frank.schlimbach <fschlimb@smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu@tencent.com>
2024-08-12 12:20:58 +01:00

168 lines
8.1 KiB
MLIR

// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
// Two-axis mesh referenced by the tests in this file; the second axis is
// used by the test that reduces over differing mesh axes.
mesh.mesh @mesh0(shape = 4x2)
// Single-axis mesh, referenced by the test that reduces over two different
// meshes.
mesh.mesh @mesh1(shape = 4)
// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`. Both all-reduces here use @mesh0 and mesh axis 0.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
func.func @all_reduce_arith_addf_endomorphism(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf32> {
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // The add is expected to consume the function arguments directly.
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS]], %[[RHS]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: %[[REDUCED_SUM:[A-Za-z0-9_]*]] = mesh.all_reduce %[[SUM]]
  // CHECK: return %[[REDUCED_SUM]]
  return %sum : tensor<5xf32>
}
// The sum of the two all-reduces is returned twice. The expected output still
// has a single all_reduce of the un-reduced sum, feeding both return operands.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS]], %[[RHS]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: %[[REDUCED_SUM:[A-Za-z0-9_]*]] = mesh.all_reduce %[[SUM]]
  // CHECK: return %[[REDUCED_SUM]], %[[REDUCED_SUM]]
  return %sum, %sum : tensor<5xf32>, tensor<5xf32>
}
// Do not simplify if there is another use of one of the all-reduces.
// Here the first all-reduce result is also returned directly.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  // CHECK: %[[LHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[LHS]]
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[RHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[RHS]]
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS_REDUCED]], %[[RHS_REDUCED]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: return %[[LHS_REDUCED]], %[[SUM]]
  return %lhs_reduced, %sum : tensor<5xf32>, tensor<5xf32>
}
// Negative test: the two all-reduces run on different meshes (@mesh0 and
// @mesh1), so both all-reduces and the add of their results are expected to
// remain in place.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[LHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[LHS]] on @mesh0
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[RHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[RHS]] on @mesh1
  %rhs_reduced = mesh.all_reduce %rhs on @mesh1 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS_REDUCED]], %[[RHS_REDUCED]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: return %[[SUM]]
  return %sum : tensor<5xf32>
}
// Negative test: both all-reduces use @mesh0 but reduce over different mesh
// axes ([0] and [1]), so the IR is expected to stay as written.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[LHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[LHS]] on @mesh0 mesh_axes = [0]
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[RHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[RHS]] on @mesh0 mesh_axes = [1]
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [1]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS_REDUCED]], %[[RHS_REDUCED]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: return %[[SUM]]
  return %sum : tensor<5xf32>
}
// Negative test: the first all-reduce uses `reduction = max` while the second
// does not specify a reduction, and the results are combined with addf. The
// IR is expected to stay as written.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[LHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[LHS]] on @mesh0 mesh_axes = [0] reduction = max
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0] reduction = max
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[RHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[RHS]] on @mesh0 mesh_axes = [0]
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS_REDUCED]], %[[RHS_REDUCED]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: return %[[SUM]]
  return %sum : tensor<5xf32>
}
// Negative test: each all-reduce takes a tensor<5xf32> operand but produces a
// tensor<5xf64> result. With differing operand and result element types the
// IR is expected to stay as written.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf64> {
  // CHECK: %[[LHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[LHS]] on @mesh0 mesh_axes = [0]
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[RHS_REDUCED:[A-Za-z0-9_]*]] = mesh.all_reduce %[[RHS]] on @mesh0 mesh_axes = [0]
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[SUM:[A-Za-z0-9_]*]] = arith.addf %[[LHS_REDUCED]], %[[RHS_REDUCED]]
  %sum = arith.addf %lhs_reduced, %rhs_reduced : tensor<5xf64>
  // CHECK: return %[[SUM]]
  return %sum : tensor<5xf64>
}
// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
// `all_reduce(min(x, y))`. Both all-reduces use `reduction = min` and the
// combining op is arith.minimumf.
// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
func.func @all_reduce_arith_minimumf_endomorphism(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %lhs: tensor<5xf32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xf32>
    %rhs: tensor<5xf32>) -> tensor<5xf32> {
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[LHS]], %[[RHS]]
  %min = arith.minimumf %lhs_reduced, %rhs_reduced : tensor<5xf32>
  // CHECK: %[[REDUCED_MIN:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[REDUCED_MIN]]
  return %min : tensor<5xf32>
}
// Checks that `minsi(all_reduce(x), all_reduce(y))` gets transformed to
// `all_reduce(minsi(x, y))` for i32 tensors when both all-reduces use
// `reduction = min`.
// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
func.func @all_reduce_arith_minsi_endomorphism(
    // CHECK-SAME: %[[LHS:[A-Za-z0-9_]*]]: tensor<5xi32>
    %lhs: tensor<5xi32>,
    // CHECK-SAME: %[[RHS:[A-Za-z0-9_]*]]: tensor<5xi32>
    %rhs: tensor<5xi32>) -> tensor<5xi32> {
  %lhs_reduced = mesh.all_reduce %lhs on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  %rhs_reduced = mesh.all_reduce %rhs on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[LHS]], %[[RHS]]
  %min = arith.minsi %lhs_reduced, %rhs_reduced : tensor<5xi32>
  // CHECK: %[[REDUCED_MIN:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[REDUCED_MIN]]
  return %min : tensor<5xi32>
}