This is a fixed copy of #98145 (necessary after it got reverted).

@sogartar @yaochengji
This PR adds the following to #98145:
- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and no longer returns a result, to clarify its in-place semantics
- `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per tensor/memref-axis (similar to `mesh.sharding`)
- The implementation of `ShardingInterface` for tensor operations (`tensor.empty` for now) moved from the tensor library to the mesh interface library. `spmdize` uses features from the `mesh` dialect. @rengolin agreed that `tensor` should not depend on `mesh`, so this functionality cannot live in the `tensor` library. The unfulfilled dependency caused the issues leading to reverting #98145. Such cases are generally possible and might lead to re-considering the current structure (like for tosa ops).
- rebased onto latest main

--------------------------

Replacing the `#mesh.sharding` attribute with the operation `mesh.sharding`
- extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes`
- internally a sharding is represented as a non-IR class `mesh::MeshSharding`

What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- The `mesh.shard` op accepts an additional optional attribute `force`, useful for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting to `sharded_dims_sizes` and `halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes`

---------

Co-authored-by: frank.schlimbach <fschlimb@smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu@tencent.com>
125 lines
4.3 KiB
C++
125 lines
4.3 KiB
C++
//===- TestSimplification.cpp - Test simplification -----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
|
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::mesh;
|
|
|
|
namespace {
|
|
|
|
/// Test-only rewrite pattern that expands a resharding expressed as a pair of
/// `ShardOp`s into the ops produced by `reshard`.
///
/// The pattern anchors on a `ShardOp` that is *not* marked
/// `annotate_for_users` (the source sharding). Every user of that op that is
/// itself a `ShardOp` marked `annotate_for_users` and that refers to the same
/// mesh is treated as a resharding target: its result is replaced by the value
/// returned from `reshard`, wrapped in `UnrealizedConversionCastOp`s that
/// bridge between the unsharded tensor types and the per-shard types.
struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
  using OpRewritePattern<ShardOp>::OpRewritePattern;

  /// Rewrites all matching resharding targets of `op`.
  ///
  /// Returns failure (leaving the IR untouched) when `op` is annotated for
  /// users, when its mesh cannot be resolved, or when it has no matching
  /// target user; returns success once at least one target was rewritten.
  LogicalResult matchAndRewrite(ShardOp op,
                                PatternRewriter &rewriter) const override {
    // Only a sharding annotation on the producer side can act as the source.
    if (op.getAnnotateForUsers())
      return failure();

    SymbolTableCollection symbolTable;

    // Resolves the mesh symbol referenced by the sharding operand of
    // `shardOp`. Yields a null op when the sharding value is not produced by
    // a `ShardingOp` (e.g. it is a block argument) or when the symbol lookup
    // finds nothing, so callers can bail out instead of hitting a failed
    // `cast<>` assertion.
    auto lookupMesh = [&symbolTable](ShardOp shardOp) -> mesh::MeshOp {
      auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
      if (!shardingOp)
        return nullptr;
      return symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
          shardOp, shardingOp.getMeshAttr());
    };

    mesh::MeshOp mesh = lookupMesh(op);
    if (!mesh)
      return rewriter.notifyMatchFailure(
          op, "cannot resolve the mesh of the source sharding");

    // A single pass over the users both detects and rewrites the targets; if
    // no user qualifies nothing is created and the pattern reports failure.
    // NOTE(review): this iterates the users of `op` while creating new ops;
    // the ops built below consume `op.getSrc()` and the target's result, not
    // `op`'s result — confirm `reshard` does not add users of `op` either.
    bool rewroteAny = false;
    for (Operation *user : op->getUsers()) {
      auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
      if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
          lookupMesh(targetShardOp) != mesh)
        continue;

      ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

      // View the unsharded source operand as a value of the per-shard type
      // computed from the source sharding.
      ShapedType sourceShardShape =
          shardShapedType(op.getResult().getType(), mesh, op.getSharding());
      TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
          builder
              .create<UnrealizedConversionCastOp>(sourceShardShape,
                                                  op.getSrc())
              ->getResult(0));

      TypedValue<ShapedType> targetShard =
          reshard(builder, mesh, op, targetShardOp, sourceShard);

      // Cast the resharded value back to the type of the target's result so
      // that it can stand in for the target `ShardOp` at all of its uses.
      Value newTargetUnsharded =
          builder
              .create<UnrealizedConversionCastOp>(
                  targetShardOp.getResult().getType(), targetShard)
              ->getResult(0);
      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
                                  newTargetUnsharded);
      rewroteAny = true;
    }

    return success(rewroteAny);
  }
};
|
|
|
|
struct TestMeshReshardingPass
|
|
: public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
|
|
std::move(patterns)))) {
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
reshardingRegisterDependentDialects(registry);
|
|
registry.insert<BuiltinDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-mesh-resharding-spmdization";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test Mesh dialect resharding spmdization.";
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestMeshReshardingSpmdizationPass() {
|
|
PassRegistration<TestMeshReshardingPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|