From 43e1a5a411d972fe06a1afb86ffd5ba21fd2a376 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach
Date: Wed, 18 Jun 2025 11:06:48 +0200
Subject: [PATCH] [mlir][mesh] adding option for traversal order in sharding
 propagation (#144079)

The traversal order in sharding propagation was hard-coded. This PR adds
an option to the pass for selecting a suitable order:
- forward-only
- backward-only
- forward-backward
- backward-forward

The default is the previous behavior (backward-forward).
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  3 --
 .../mlir/Dialect/Mesh/Transforms/Passes.h     | 12 +++++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    | 15 ++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 27 +++++-----
 .../Mesh/Transforms/ShardingPropagation.cpp   | 38 +++++++++-----
 .../Mesh/backward-sharding-propagation.mlir   | 26 ++++++++++
 ...forward-backward-sharding-propagation.mlir | 27 ++++++++++
 .../Mesh/forward-sharding-propagation.mlir    | 49 +++++++++++++++++++
 8 files changed, 171 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
 create mode 100644 mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
 create mode 100644 mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 32c2eca2cefa..3878505f8f93 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand, OpBuilder &builder,
-                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpResult result, OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaa..a2424d43a8ba 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
 
 namespace mesh {
 
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+  /// Forward traversal.
+  Forward,
+  /// Backward traversal.
+  Backward,
+  /// Forward then backward traversal.
+  ForwardBackward,
+  /// Backward then forward traversal.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d6..11ec7e78cd5e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
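+    The order in which the operations are visited is controlled by the
+    `traversal` option (forward, backward, forward-backward, or
+    backward-forward); the default, backward-forward, matches the previous
+    hard-coded behavior. On the command line, for example:
+      sharding-propagation{traversal=forward-backward}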
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+           [{::llvm::cl::values(
+             clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+                        "Forward-only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+                        "Backward-only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+                        "Forward-backward traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+                        "Backward-forward traversal.")
+           )}]>,
+  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a3508..a2c2d1a7470c 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -275,13 +275,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder,
-                                                     ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+                                                    Value &operandValue,
+                                                    Operation *operandOp,
+                                                    OpBuilder &builder,
+                                                    ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
-  Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -300,9 +299,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
     builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                             /*annotate_for_users*/ false);
   }
-  IRRewriter rewriter(builder);
-  rewriter.replaceUsesWithIf(
-      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+  operandValue.replaceUsesWithIf(
+      newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -313,15 +311,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
-  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+  SmallVector<std::pair<Value, Operation *>> uses;
+  for (auto &use : result.getUses()) {
+    uses.emplace_back(use.get(), use.getOwner());
+  }
+  for (auto &[operandValue, operandOp] : uses) {
+    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+                                            builder, newShardOp);
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9..6751fafaf177 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
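+  // Inherit the tablegen-generated constructors, including the overload that
+  // takes the generated pass options, so the `traversal` option is honored.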
+  using ShardingPropagationBase::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -382,18 +385,31 @@ struct ShardingPropagation
       shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
     });
 
-    // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    auto traverse = [&](auto &&range, OpBuilder &builder,
+                        const char *order) -> bool {
+      for (Operation &op : llvm::make_early_inc_range(range)) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return true;
+        }
+      }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return false;
+    };
 
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    // 1. Propagate in backward order first, if requested; stop on failure.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward)
+      if (traverse(llvm::reverse(block), builder, "backward"))
+        return;
 
-    // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    // 2. Propagate in forward order (all traversal orders except backward).
+    if (traversal != TraversalOrder::Backward)
+      if (traverse(block, builder, "forward"))
+        return;
+
+    // 3. For forward-backward, finish with a second backward sweep.
+    if (traversal == TraversalOrder::ForwardBackward)
+      traverse(llvm::reverse(block), builder, "backward");
   }
 };
 
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
new file mode 100644
index 000000000000..4223d01d6511
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    // CHECK-COUNT-2: mesh.shard
+    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+    // CHECK: tensor.empty()
+    // CHECK-NOT: mesh.shard @
+    %2 = tensor.empty() : tensor<6x6xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+      ^bb0(%in: i32, %in_2: i32, %out: i32):
+        %9 = arith.addi %in, %in_2 : i32
+        linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    // CHECK: return
+    return %3 : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
new file mode 100644
index 000000000000..dd2eee2f7def
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
+    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %2 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+      ^bb0(%in: i32, %in_2: i32, %out: i32):
+        %9 = arith.addi %in, %in_2 : i32
+        linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
+    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+    // CHECK: return
+    return %annotated_col : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 000000000000..98e9931b8de9
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+      ^bb0(%in: i32, %in_2: i32, %out: i32):
+        %9 = arith.addi %in, %in_2 : i32
+        linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}
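
A minimal usage sketch for selecting the traversal order when building a
pipeline in C++. This is illustrative only: `ShardingPropagationOptions` and
`createShardingPropagation(options)` are the option struct and factory that
MLIR's autogenerated pass declarations are expected to emit for this pass,
not names introduced by the patch itself; the equivalent textual pipeline is
`builtin.module(func.func(sharding-propagation{traversal=forward-backward}))`.

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/Mesh/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Schedule sharding propagation on every function, visiting operations
  // forward first and then backward.
  void buildShardingPipeline(mlir::PassManager &pm) {
    // Option struct and factory names are assumed from MLIR's autogenerated
    // pass declarations (GEN_PASS_DECL).
    mlir::mesh::ShardingPropagationOptions options;
    options.traversal = mlir::mesh::TraversalOrder::ForwardBackward;
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::mesh::createShardingPropagation(options));
  }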