Files
clang-p2996/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
Frank Schlimbach baabcb2898 [mlir][mesh] Shardingcontrol (#102598)
This is a fixed copy of #98145 (necessary after it got reverted).

@sogartar @yaochengji
This PR adds the following to #98145:
- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and does not
return a result, to clarify its in-place semantics
- `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per
tensor/memref-axis (similar to `mesh.sharding`)
- The implementation of `ShardingInterface` for tensor operations
(`tensor.empty` for now) moved from the tensor library to the mesh
interface library, because `spmdize` uses features from the `mesh` dialect.
@rengolin agreed that `tensor` should not depend on `mesh`, so this
functionality cannot live in the `tensor` library. The unfulfilled dependency
caused the issues that led to reverting #98145. Such cases can arise in
general and might lead to reconsidering the current structure (as for
the tosa ops).
- rebased onto latest main
--------------------------
Replace the `#mesh.sharding` attribute with the `mesh.sharding` operation
- extended semantics now allow providing optional `halo_sizes` and
`sharded_dims_sizes`
- internally a sharding is represented as a non-IR class
`mesh::MeshSharding`

What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- `mesh.shard` op accepts additional optional attribute `force`, useful
for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting to `sharded_dims_sizes` and
`halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as the spmdization target for
shardings with `halo_sizes`

---------

Co-authored-by: frank.schlimbach <fschlimb@smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu@tencent.com>
2024-08-12 12:20:58 +01:00

125 lines
4.3 KiB
C++

//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::mesh;
namespace {
/// Test-only rewrite pattern that drives `mesh::reshard` directly.
///
/// It matches a `mesh.shard` op that is NOT `annotate_for_users` (the source
/// sharding). For every user of its result that is an `annotate_for_users`
/// `mesh.shard` whose sharding refers to the same mesh (a target sharding), it
/// materializes the resharding from source to target and replaces all uses of
/// the target's result with it. `UnrealizedConversionCastOp`s bridge between
/// the op result types and the shard type computed by `shardShapedType`.
struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShardOp op,
                                PatternRewriter &rewriter) const override {
    // Only a source sharding (not a user annotation) starts a resharding.
    if (op.getAnnotateForUsers())
      return failure();

    SymbolTableCollection symbolTable;
    mesh::MeshOp mesh = lookupMesh(symbolTable, op);
    // A sharding that is not defined by a `mesh.sharding` op, or a mesh
    // symbol that does not resolve, cannot be handled here. Decline the match
    // instead of asserting in `cast` or passing a null mesh to
    // `shardShapedType`/`reshard`.
    if (!mesh)
      return failure();

    // Collect the matching targets once, so the IR below can be changed
    // without walking a live use-list and without re-evaluating the predicate.
    SmallVector<ShardOp> targets;
    for (Operation *user : op->getUsers()) {
      auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
      if (targetShardOp && targetShardOp.getAnnotateForUsers() &&
          lookupMesh(symbolTable, targetShardOp) == mesh)
        targets.push_back(targetShardOp);
    }
    if (targets.empty())
      return failure();

    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

    // The shard-typed view of the source does not depend on the target, so
    // it is built once and shared by every resharding below.
    ShapedType sourceShardShape =
        shardShapedType(op.getResult().getType(), mesh, op.getSharding());
    TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
        builder
            .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
            ->getResult(0));

    for (ShardOp targetShardOp : targets) {
      TypedValue<ShapedType> targetShard =
          reshard(builder, mesh, op, targetShardOp, sourceShard);
      // Cast back to the target op's result type so existing users still
      // type-check after the replacement.
      Value newTargetUnsharded =
          builder
              .create<UnrealizedConversionCastOp>(
                  targetShardOp.getResult().getType(), targetShard)
              ->getResult(0);
      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
                                  newTargetUnsharded);
    }
    return success();
  }

private:
  /// Resolves the mesh referenced by `shardOp`'s sharding operand.
  ///
  /// Returns a null op if the sharding is not produced by a `mesh.sharding`
  /// op (e.g. it is a block argument) or if the mesh symbol is not found.
  static mesh::MeshOp lookupMesh(SymbolTableCollection &symbolTable,
                                 ShardOp shardOp) {
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return mesh::MeshOp();
    return symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
        shardOp, shardingOp.getMeshAttr());
  }
};
/// Module pass that greedily applies `TestMeshReshardingRewritePattern`,
/// exposed on the command line as `test-mesh-resharding-spmdization`.
struct TestMeshReshardingPass
    : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)

  StringRef getArgument() const final {
    return "test-mesh-resharding-spmdization";
  }

  StringRef getDescription() const final {
    return "Test Mesh dialect resharding spmdization.";
  }

  /// Declares the dialects this pass may create ops from: whatever the
  /// resharding helpers require, plus the builtin dialect for the
  /// `UnrealizedConversionCastOp`s emitted by the pattern.
  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<BuiltinDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.insert<TestMeshReshardingRewritePattern>(context);

    Operation *root = getOperation();
    LogicalResult rewriteResult =
        applyPatternsAndFoldGreedily(root, std::move(rewritePatterns));
    if (failed(rewriteResult))
      signalPassFailure();
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestMeshReshardingSpmdizationPass() {
PassRegistration<TestMeshReshardingPass>();
}
} // namespace test
} // namespace mlir