[mlir][mesh] adding option for traversal order in sharding propagation (#144079)
The traversal order in sharding propagation was hard-coded. This PR adds an option to the pass for selecting a suitable order:

- forward-only
- backward-only
- forward-backward
- backward-forward

The default is the previous behavior (backward-forward).
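The order is selected via the new traversal pass option, mirroring the RUN lines of the tests added below (input.mlir is a stand-in input file):

    mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" input.mlir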
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand, OpBuilder &builder,
-                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
@@ -19,6 +19,18 @@ class FuncOp;

 namespace mesh {

+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+  /// Forward traversal.
+  Forward,
+  /// Backward traversal.
+  Backward,
+  /// Forward then backward traversal.
+  ForwardBackward,
+  /// Backward then forward traversal.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+           [{::llvm::cl::values(
+             clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+                        "Forward only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+                        "backward only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+                        "forward-backward traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+                        "backward-forward traversal.")
+           )}]>,
+  ];
 let dependentDialects = [
   "mesh::MeshDialect"
 ];
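Besides the textual traversal=... syntax, options declared this way normally also surface through a generated options struct for programmatic construction. A minimal sketch, assuming the TableGen pass codegen produces the usual ShardingPropagationOptions struct and a matching createShardingPropagation overload (both names inferred from MLIR's standard pass codegen, not confirmed by this diff):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Dialect/Mesh/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Schedule sharding propagation with an explicit traversal order.
    void addShardingPropagation(mlir::PassManager &pm) {
      mlir::mesh::ShardingPropagationOptions options;   // assumed generated name
      options.traversal = mlir::mesh::TraversalOrder::ForwardBackward;
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::mesh::createShardingPropagation(options));
    }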
@@ -275,13 +275,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }

-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder,
-                                                     ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+                                                    Value &operandValue,
+                                                    Operation *operandOp,
+                                                    OpBuilder &builder,
+                                                    ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
-  Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -300,9 +299,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  IRRewriter rewriter(builder);
-  rewriter.replaceUsesWithIf(
-      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+  operandValue.replaceUsesWithIf(
+      newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
@@ -313,15 +311,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
     auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                                newShardOp.getSharding(),
                                                /*annotate_for_users*/ true);
-    rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+    newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
   }
 }

 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+  SmallVector<std::pair<Value, Operation *>> uses;
+  for (auto &use : result.getUses()) {
+    uses.emplace_back(use.get(), use.getOwner());
+  }
+  for (auto &[operandValue, operandOp] : uses) {
+    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+                                            builder, newShardOp);
   }
 }
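This loop is the apparent motivation for the Impl refactoring: the old version rewrote uses through the live OpOperand while walking result.getUses(), and inserting the trailing annotate_for_users shard op can itself rewrite the very use list being iterated. Snapshotting each use as a (Value, Operation*) pair first decouples iteration from mutation. A minimal sketch of the pattern, with rewriteUse as a hypothetical stand-in for the real work:

    // Snapshot the uses, then mutate; the use list can change freely
    // during the second loop without invalidating the first.
    SmallVector<std::pair<Value, Operation *>> uses;
    for (OpOperand &use : result.getUses())
      uses.emplace_back(use.get(), use.getOwner());
    for (auto &[value, owner] : uses)
      rewriteUse(value, owner); // hypothetical helper, not an upstream API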
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -382,18 +385,31 @@ struct ShardingPropagation
       shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
     });

-    // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    auto traverse = [&](auto &&range, OpBuilder &builder,
+                        const char *order) -> bool {
+      for (Operation &op : range) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return true;
+        }
+      }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return false;
+    };

-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    // 1. Propagate in reversed order.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward)
+      traverse(llvm::reverse(block), builder, "backward");

-    // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    // 2. Propagate in original order.
+    if (traversal != TraversalOrder::Backward)
+      traverse(block, builder, "forward");
+
+    // 3. Propagate in backward order if needed.
+    if (traversal == TraversalOrder::ForwardBackward)
+      traverse(llvm::reverse(block), builder, "backward");
   }
 };
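Net effect of the dispatch (backward-forward reproduces the previously hard-coded behavior):

    traversal=forward           -> forward walk only
    traversal=backward          -> backward walk only
    traversal=forward-backward  -> forward walk, then backward walk
    traversal=backward-forward  -> backward walk, then forward walk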
mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir (new file, 26 lines)
@@ -0,0 +1,26 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
  func.func @test_forward() -> tensor<6x6xi32> {
    %c1_i32 = arith.constant 1 : i32
    // CHECK: tensor.empty()
    %0 = tensor.empty() : tensor<6x6xi32>
    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
    // CHECK-COUNT-2: mesh.shard
    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
    // CHECK: tensor.empty()
    // CHECK-NOT: mesh.shard @
    %2 = tensor.empty() : tensor<6x6xi32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %9 = arith.addi %in, %in_2 : i32
      linalg.yield %9 : i32
    } -> tensor<6x6xi32>
    // CHECK: return
    return %3 : tensor<6x6xi32>
  }
}
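As the CHECKs read, this pins down the backward-only behavior: starting from the one existing annotation, the pass adds only the producer-side shard op (two mesh.shard ops in total) and leaves the downstream tensor.empty and linalg.generic unannotated.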
@@ -0,0 +1,27 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
  func.func @test_forward() -> tensor<6x6xi32> {
    %c1_i32 = arith.constant 1 : i32
    // CHECK: tensor.empty()
    %0 = tensor.empty() : tensor<6x6xi32>
    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
    %2 = tensor.empty() : tensor<6x6xi32>
    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %9 = arith.addi %in, %in_2 : i32
      linalg.yield %9 : i32
    } -> tensor<6x6xi32>
    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
    // CHECK: return
    return %annotated_col : tensor<6x6xi32>
  }
}
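As the CHECKs read, both seeds survive the combined walk: the row sharding (split_axes [[0]]) annotated on the fill's input occurs three times after the forward walk, and the column sharding (split_axes [[1]]) annotated on the result propagates backward through the linalg.generic, occurring four times.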
mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (new file, 49 lines)
@@ -0,0 +1,49 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
    %c1_i32 = arith.constant 1 : i32
    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
    %0 = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[v1:%.*]] = linalg.fill ins
    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
    %3 = tensor.empty() : tensor<6x6xi32>
    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %9 = arith.addi %in, %in_2 : i32
      linalg.yield %9 : i32
    } -> tensor<6x6xi32>
    %c0_i32 = arith.constant 0 : i32
    %6 = tensor.empty() : tensor<i32>
    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
      (%in: i32, %init: i32) {
        %9 = arith.addi %in, %init : i32
        linalg.yield %9 : i32
      }
    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
  }
}