This is a fixed copy of #98145 (necessary after it got reverted). @sogartar @yaochengji

This PR adds the following on top of #98145:

- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and no longer returns a result, clarifying its in-place semantics.
- `UpdateHaloOp` accepts `split_axis`, allowing multiple mesh axes per tensor/memref axis (similar to `mesh.sharding`).
- The implementation of `ShardingInterface` for tensor operations (`tensor.empty` for now) moved from the tensor library to the mesh interface library. `spmdize` uses features from the `mesh` dialect, and @rengolin agreed that `tensor` should not depend on `mesh`, so this functionality cannot live in a `tensor` library. This unfulfilled dependency caused the issues that led to reverting #98145. Such cases are generally possible and might lead to reconsidering the current structure (as for tosa ops).
- Rebased onto latest main.

---

The original change replaces the `#mesh.sharding` attribute with the operation `mesh.sharding`:

- The extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes`.
- Internally, a sharding is represented as the non-IR class `mesh::MeshSharding`.

What previously was

```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```

is now

```mlir
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```

and allows additional annotations to control the shard sizes:

```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```

Further changes:

- The `mesh.shard` op accepts an additional optional attribute `force`, useful for halo updates.
- Some initial spmdization support for the new semantics.
- Support for `tensor.empty` reacting to `sharded_dims_sizes` and `halo_sizes` in the sharding.
- New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes`.

---------

Co-authored-by: frank.schlimbach <fschlimb@smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu@tencent.com>
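For illustration, here is a minimal, hypothetical sketch of a halo-annotated sharding and the spmdized halo exchange it targets. The exact assembly format of `mesh.update_halo` is assumed here (in-place operation on the local `memref` shard, taking `split_axes` and `halo_sizes`, returning no result) and may differ in detail from the op as landed; the function names and buffer sizes are made up for the example:

```mlir
mesh.mesh @mesh0(shape = 4)

// Sharded (pre-spmdization) form: split axis 0 across the 4 devices, with
// 1 lower and 2 upper halo elements per local shard.
func.func @annotate(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> {
  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}

// Spmdized form: each device owns a 4x8 core shard plus halos, i.e. a 7x8
// local buffer; the collective refreshes the halo rows in place and returns
// no result (assumed syntax).
func.func @exchange(%local: memref<7x8xf32>) {
  mesh.update_halo %local on @mesh0 split_axes = [[0]] halo_sizes = [1, 2] : memref<7x8xf32>
  return
}
```

The regression test below exercises the new `mesh.sharding`/`mesh.shard` semantics through resharding spmdization (halo-free cases).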
// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s

mesh.mesh @mesh_1d(shape = 2)

mesh.mesh @mesh_1d_dynamic(shape = ?)

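// Both source and target shardings are fully replicated, so resharding is a
// no-op: the spmdized function returns its argument unchanged.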
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
  %arg0: tensor<2xf32>
) -> tensor<2xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
  // CHECK: return %[[ARG]]
  return %1 : tensor<2xf32>
}

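// As above, but producer and user annotations reuse the same sharding value;
// spmdization still folds to the identity.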
// CHECK-LABEL: func @identical_source_and_target_sharding
func.func @identical_source_and_target_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
  %arg0: tensor<2xf32>
) -> tensor<2xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
  %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
  // CHECK: return %[[ARG]]
  return %1 : tensor<2xf32>
}

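// Splitting a replicated tensor along tensor axis 1 over mesh axis 0 lowers
// to an all_slice; each of the 2 devices keeps a 3x7 shard.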
// CHECK-LABEL: func @split_replicated_tensor_axis
func.func @split_replicated_tensor_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
  %arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
  // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
  return %1 : tensor<3x14xf32>
}

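// Splitting a dynamic tensor axis over a dynamically sized mesh: the
// all_slice result stays dynamic, so no cast is needed.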
// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
func.func @split_replicated_tensor_axis_dynamic(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
  %arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
  // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
  return %1 : tensor<?x3x?xf32>
}

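// Moving the split from tensor axis 0 to tensor axis 1 lowers to an
// all_to_all (split_axis = 1, concat_axis = 0), turning 5x14 shards into
// 10x7 shards.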
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

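// The same move over a dynamically sized mesh: shard sizes are unknown, so
// the all_to_all result is fully dynamic and a tensor.cast recovers the
// now-unsplit static dimension.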
// CHECK-LABEL: func @move_split_axis_dynamic_mesh
func.func @move_split_axis_dynamic_mesh(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

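// Moving the split when the originally split tensor axis is dynamic; the
// static axis 1 still yields a known shard size of 7.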
// CHECK-LABEL: func @move_split_dynamic_axis
func.func @move_split_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[RES]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

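// Resharding from split to fully replicated lowers to an all_gather along
// the previously split axis.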
// CHECK-LABEL: func @unshard_static_axis
func.func @unshard_static_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

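// As above, but with the split on the last tensor axis (gather_axis = 1).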
// CHECK-LABEL: func @unshard_static_last_axis
func.func @unshard_static_last_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

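// Unsharding a dynamic tensor axis: shard and result shapes are both
// dynamic, so the all_gather consumes the argument directly.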
// CHECK-LABEL: func @unshard_dynamic_axis
func.func @unshard_dynamic_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
  %arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
  return %1 : tensor<?x14xf32>
}

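// Unsharding a static axis over a dynamically sized mesh axis: the gathered
// shape is dynamic, so a tensor.cast restores the static result type.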
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
func.func @unshard_static_axis_on_dynamic_mesh_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[RES]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

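// A partial-sum sharding reshards to full replication via an all_reduce
// over the partial mesh axes.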
// CHECK-LABEL: func @partial_axis_to_full_replication
func.func @partial_axis_to_full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}