This is a fixed copy of #98145 (necessary after it got reverted). @sogartar @yaochengji This PR adds the following to #98145: - `UpdateHaloOp` accepts a `memref` (instead of a tensor) and not returning a result to clarify its inplace-semantics - `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per tensor/memref-axis (similar to `mesh.sharding`) - The implementation of `Shardinginterface` for tensor operation (`tensor.empty` for now) moved from the tensor library to the mesh interface library. `spmdize` uses features from `mesh` dialect. @rengolin agreed that `tensor` should not depend on `mesh` so this functionality cannot live in a `tensor`s lib. The unfulfilled dependency caused the issues leading to reverting #98145. Such cases are generally possible and might lead to re-considering the current structure (like for tosa ops). - rebased onto latest main -------------------------- Replacing `#mesh.sharding` attribute with operation `mesh.sharding` - extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes` - internally a sharding is represented as a non-IR class `mesh::MeshSharding` What previously was ```mlir %sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32> %sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32> ``` is now ```mlir %sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> ``` and allows additional annotations to control the shard sizes: ```mlir mesh.mesh @mesh0 (shape = 4) %sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding %0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32> %sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding %1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32> ``` - `mesh.shard` op accepts additional optional attribute `force`, useful for halo updates - Some initial spmdization support for the new semantics - Support for `tensor.empty` reacting on `sharded_dims_sizes` and `halo_sizes` in the sharding - New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes` --------- Co-authored-by: frank.schlimbach <fschlimb@smtp.igk.intel.com> Co-authored-by: Jie Fu <jiefu@tencent.com>
633 lines
25 KiB
MLIR
633 lines
25 KiB
MLIR
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
|
|
|
// CHECK: mesh.mesh @mesh0
|
|
mesh.mesh @mesh0(shape = 2x2x4)
|
|
|
|
// CHECK: mesh.mesh @mesh1(shape = 4x?)
|
|
mesh.mesh @mesh1(shape = 4x?)
|
|
|
|
// CHECK: mesh.mesh @mesh2(shape = ?x4)
|
|
mesh.mesh @mesh2(shape = ?x4)
|
|
|
|
// CHECK: mesh.mesh @mesh3(shape = ?x?)
|
|
mesh.mesh @mesh3(shape = ?x?)
|
|
|
|
mesh.mesh @mesh4(shape = 3)
|
|
|
|
// CHECK: mesh.mesh @mesh5(shape = ?)
|
|
mesh.mesh @mesh5(shape = ?)
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_fully_replicated
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
|
|
%s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_1st_dim
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
|
|
%s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
|
|
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_2nd_dim
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
|
|
%s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
|
|
func.func @mesh_shard_op_1st_and_3rd_dim(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
|
|
%arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
|
|
return %0 : tensor<4x8x16xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_partial_max
|
|
func.func @mesh_shard_op_partial_max(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = max [1] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0]] partial = max[1] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_partial_min
|
|
func.func @mesh_shard_op_partial_min(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = min [1] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0]] partial = min[1] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_partial_generic
|
|
func.func @mesh_shard_op_partial_generic(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = generic [1] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0]] partial = generic[1] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_partial_sum
|
|
func.func @mesh_shard_op_partial_sum(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
|
|
func.func @mesh_shard_op_partial_sum_multi_axes(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
|
|
%s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1, 2] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
|
|
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
|
|
return %0 : tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_op_two_users
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
|
|
func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
|
|
(tensor<4x8xf32>, tensor<4x8xf32>) {
|
|
// CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
|
|
%s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
|
|
%0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
|
|
// CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
|
|
%s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
|
|
%1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
|
|
// CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
|
|
%s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
|
|
%2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
|
|
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_halo_sizes
|
|
func.func @mesh_shard_halo_sizes() -> () {
|
|
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
|
|
%c3 = arith.constant 3 : i64
|
|
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
|
|
%sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
|
|
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
|
|
%sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
|
|
return
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_dims_sizes
|
|
func.func @mesh_shard_dims_sizes() -> () {
|
|
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
|
|
%c3 = arith.constant 3 : i64
|
|
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
|
|
%sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
|
|
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
|
|
%sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
|
|
return
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shard_shape
|
|
func.func @mesh_shard_shape() {
|
|
// CHECK: %[[C3:.*]] = arith.constant 3 : index
|
|
%c3 = arith.constant 3 : index
|
|
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
|
|
%s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
|
|
// CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
|
|
%shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
|
|
// CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
|
|
%shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
|
|
return
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shape
|
|
func.func @mesh_shape() -> (index, index) {
|
|
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
|
|
%0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
|
|
return %0#0, %0#1 : index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shape_default_axes
|
|
func.func @mesh_shape_default_axes() -> (index, index, index) {
|
|
// CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
|
|
%0:3 = mesh.mesh_shape @mesh0 : index, index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
|
return %0#0, %0#1, %0#2 : index, index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @mesh_shape_empty_axes
|
|
func.func @mesh_shape_empty_axes() -> (index, index, index) {
|
|
// CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
|
|
%0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
|
return %0#0, %0#1, %0#2 : index, index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @process_multi_index
|
|
func.func @process_multi_index() -> (index, index) {
|
|
// CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
|
|
%0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
|
|
return %0#0, %0#1 : index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @process_multi_index_default_axes
|
|
func.func @process_multi_index_default_axes() -> (index, index, index) {
|
|
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
|
%0:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
|
return %0#0, %0#1, %0#2 : index, index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @process_multi_index_empty_axes
|
|
func.func @process_multi_index_empty_axes() -> (index, index, index) {
|
|
// CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
|
|
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
|
|
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
|
|
return %0#0, %0#1, %0#2 : index, index, index
|
|
}
|
|
|
|
// CHECK-LABEL: func @process_linear_index
|
|
func.func @process_linear_index() -> index {
|
|
// CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
|
|
%0 = mesh.process_linear_index on @mesh0 : index
|
|
// CHECK: return %[[RES]] : index
|
|
return %0 : index
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_reduce
|
|
func.func @all_reduce(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
|
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
|
|
// CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
|
|
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
|
|
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
|
|
: tensor<3x4xf32> -> tensor<3x4xf64>
|
|
return %0 : tensor<3x4xf64>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_gather
|
|
func.func @all_gather(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
|
%arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
|
|
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
|
|
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
|
|
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
|
|
: tensor<3x4xf32> -> tensor<3x16xf32>
|
|
return %0 : tensor<3x16xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
|
|
func.func @all_gather_dynamic_dims_in_tensor(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
|
|
%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
|
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
|
|
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
|
|
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
|
|
: tensor<?x?xf32> -> tensor<?x?xf32>
|
|
return %0 : tensor<?x?xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
|
|
func.func @all_gather_dynamic_dims_in_mesh(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
|
|
%arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
|
|
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
|
|
// CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
|
|
%0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
|
|
: tensor<5x6xf32> -> tensor<5x?xf32>
|
|
return %0 : tensor<5x?xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_slice_static_dimensions
|
|
func.func @all_slice_static_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
|
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
|
|
// CHECK-NEXT: mesh.all_slice %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
|
|
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
|
|
%0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
|
|
: tensor<3x4xf32> -> tensor<3x1xf32>
|
|
return %0 : tensor<3x1xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_slice_dynamic_dimensions
|
|
func.func @all_slice_dynamic_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
|
|
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
|
|
// CHECK-NEXT: mesh.all_slice %[[ARG]]
|
|
// CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
|
|
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
|
|
%0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
|
|
: tensor<?xf32> -> tensor<?xf32>
|
|
return %0 : tensor<?xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_to_all
|
|
func.func @all_to_all(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
|
|
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
|
|
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
|
|
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
|
|
%0 = mesh.all_to_all %arg0 on @mesh4
|
|
split_axis = 1 concat_axis = 0
|
|
: tensor<3x6xi8> -> tensor<3x6xi8>
|
|
return %0 : tensor<3x6xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
|
|
func.func @all_to_all_dynamic_dims_in_result(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
|
|
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
|
|
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
|
|
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
|
|
%0 = mesh.all_to_all %arg0 on @mesh4
|
|
split_axis = 1 concat_axis = 0
|
|
: tensor<3x6xi8> -> tensor<3x?xi8>
|
|
return %0 : tensor<3x?xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
|
|
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
|
|
%arg0 : tensor<3xi8>) -> tensor<3xi8> {
|
|
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
|
|
// CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
|
|
// CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
|
|
%0 = mesh.all_to_all %arg0 on @mesh4
|
|
split_axis = 0 concat_axis = 0
|
|
: tensor<3xi8> -> tensor<3xi8>
|
|
return %0 : tensor<3xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
|
|
func.func @all_to_all_non_divisible_split_axis_size(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
|
|
%arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
|
|
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
|
|
// CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
|
|
// CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
|
|
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
|
|
split_axis = 0 concat_axis = 1
|
|
: tensor<2x3xi8> -> tensor<?x12xi8>
|
|
return %0 : tensor<?x12xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @broadcast_static_root
|
|
func.func @broadcast_static_root(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
|
|
// CHECK-NEXT: mesh.broadcast %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: root = [0, 1]
|
|
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
|
|
%0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
root = [0, 1]
|
|
: (tensor<3x6xi8>) -> tensor<3x6xi8>
|
|
return %0 : tensor<3x6xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @broadcast_dynamic_root
|
|
func.func @broadcast_dynamic_root(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<3x6xi8> {
|
|
// CHECK-NEXT: mesh.broadcast %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: root = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
|
|
%0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
root = [1, %arg1]
|
|
: (tensor<3x6xi8>, index) -> tensor<3x6xi8>
|
|
return %0 : tensor<3x6xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @gather_static_root
|
|
func.func @gather_static_root(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
|
|
// CHECK-NEXT: mesh.gather %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: gather_axis = 0
|
|
// CHECK-SAME: root = [0, 1]
|
|
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
|
|
%0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
gather_axis = 0
|
|
root = [0, 1]
|
|
: (tensor<3x6xi8>) -> tensor<24x6xi8>
|
|
return %0 : tensor<24x6xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @gather_dynamic_root
|
|
func.func @gather_dynamic_root(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
|
|
%arg0 : tensor<3x6xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<24x6xi8> {
|
|
// CHECK-NEXT: mesh.gather %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: gather_axis = 0
|
|
// CHECK-SAME: root = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
|
|
%0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
gather_axis = 0
|
|
root = [1, %arg1]
|
|
: (tensor<3x6xi8>, index) -> tensor<24x6xi8>
|
|
return %0 : tensor<24x6xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @receive_static_source
|
|
func.func @receive_static_source(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.recv %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: source = [0, 1]
|
|
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
|
|
%0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
source = [0, 1]
|
|
: (tensor<2xi8>) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @receive_dynamic_source
|
|
func.func @receive_dynamic_source(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.recv %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: source = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
|
|
%0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
source = [1, %arg1]
|
|
: (tensor<2xi8>, index) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @receive_no_source
|
|
func.func @receive_no_source(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.recv %[[ARG]]
|
|
// CHECK-NOT: source
|
|
%0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
: (tensor<2xi8>) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @reduce_static_root
|
|
func.func @reduce_static_root(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.reduce %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: root = [0, 1]
|
|
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
|
|
%0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
root = [0, 1]
|
|
: (tensor<2xi8>) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @reduce_dynamic_root
|
|
func.func @reduce_dynamic_root(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.reduce %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: root = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
|
|
%0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
root = [1, %arg1]
|
|
: (tensor<2xi8>, index) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @reduce_different_return_element_type
|
|
func.func @reduce_different_return_element_type(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
|
|
// CHECK-NEXT: mesh.reduce %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: root = [0, 1]
|
|
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
|
|
%0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
root = [0, 1]
|
|
: (tensor<2xi8>) -> tensor<2xi16>
|
|
return %0 : tensor<2xi16>
|
|
}
|
|
|
|
// CHECK-LABEL: func @reduce_scatter_static_dimensions
|
|
func.func @reduce_scatter_static_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
|
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
|
|
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
|
|
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
|
|
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
|
|
reduction = max scatter_axis = 1
|
|
: tensor<3x4xf32> -> tensor<3x1xf64>
|
|
return %0 : tensor<3x1xf64>
|
|
}
|
|
|
|
// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
|
|
func.func @reduce_scatter_dynamic_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
|
|
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
|
|
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
|
|
// CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
|
|
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
|
|
%0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
|
|
: tensor<?xf32> -> tensor<?xf64>
|
|
return %0 : tensor<?xf64>
|
|
}
|
|
|
|
// CHECK-LABEL: func @scatter_static_dimensions
|
|
func.func @scatter_static_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
|
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
|
|
// CHECK-NEXT: mesh.scatter %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [2]
|
|
// CHECK-SAME: scatter_axis = 1 root = [1]
|
|
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
|
|
%0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
|
|
scatter_axis = 1 root = [1]
|
|
: (tensor<3x4xf32>) -> tensor<3x1xf32>
|
|
return %0 : tensor<3x1xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @scatter_dynamic_dimensions
|
|
func.func @scatter_dynamic_dimensions(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
|
|
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
|
|
// CHECK-NEXT: mesh.scatter %[[ARG]]
|
|
// CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
|
|
// CHECK-SAME: scatter_axis = 0 root = [1, 2]
|
|
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
|
|
%0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
|
|
scatter_axis = 0 root = [1, 2]
|
|
: (tensor<?xf32>) -> tensor<?xf32>
|
|
return %0 : tensor<?xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: func @scatter_dynamic_root
|
|
func.func @scatter_dynamic_root(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
|
|
%arg0 : tensor<8xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<1xi8> {
|
|
// CHECK-NEXT: mesh.scatter %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: scatter_axis = 0
|
|
// CHECK-SAME: root = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
|
|
%0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
scatter_axis = 0
|
|
root = [1, %arg1]
|
|
: (tensor<8xi8>, index) -> tensor<1xi8>
|
|
return %0 : tensor<1xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @send_static_destination
|
|
func.func @send_static_destination(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.send %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: destination = [0, 1]
|
|
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
|
|
%0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
destination = [0, 1]
|
|
: (tensor<2xi8>) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @send_dynamic_destination
|
|
func.func @send_dynamic_destination(
|
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>,
|
|
// CHECK-SAME: %[[ARG1:.*]]: index
|
|
%arg1 : index
|
|
) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.send %[[ARG0]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: destination = [1, %[[ARG1]]]
|
|
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
|
|
%0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
destination = [1, %arg1]
|
|
: (tensor<2xi8>, index) -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @shift
|
|
func.func @shift(
|
|
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
|
|
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
|
|
// CHECK-NEXT: mesh.shift %[[ARG]]
|
|
// CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
|
|
// CHECK-SAME: shift_axis = 2 offset = -2 rotate
|
|
// CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
|
|
%0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
|
|
shift_axis = 2 offset = -2 rotate
|
|
: tensor<2xi8> -> tensor<2xi8>
|
|
return %0 : tensor<2xi8>
|
|
}
|
|
|
|
// CHECK-LABEL: func @update_halo
|
|
func.func @update_halo(
|
|
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
|
|
%arg0 : memref<12x12xi8>) {
|
|
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
|
|
// CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
|
|
// CHECK-SAME: split_axes = {{\[\[}}0]]
|
|
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
|
|
%c2 = arith.constant 2 : i64
|
|
mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
|
|
halo_sizes = [2, %c2] : memref<12x12xi8>
|
|
// CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
|
|
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
|
|
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
|
|
// CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
|
|
mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
|
|
halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
|
|
: memref<12x12xi8>
|
|
return
|
|
}
|