[MLIR][mesh] Mesh fixes (#124724)

A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fix tensor replication (e.g. for 0d tensors)
- improve canonicalization
- fix sharding propagation, which incorrectly generated too many ShardOps

The new operation `mesh.get_sharding` (`GetShardingOp`) enables exchanging
sharding information, e.g. on function boundaries.
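
A sketch of the new op at a function boundary, condensed from the spmdization
test added in this commit (value names simplified):

mesh.mesh @mesh_1d(shape = 2)

func.func @return_sharding(%arg0: tensor<2xf32>) -> (tensor<2xf32>, !mesh.sharding) {
  %s = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %annotated = mesh.shard %arg0 to %s : tensor<2xf32>
  // Recover the sharding of %annotated to pass it across the function boundary.
  %r = mesh.get_sharding %annotated : tensor<2xf32> -> !mesh.sharding
  return %annotated, %r : tensor<2xf32>, !mesh.sharding
}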
Frank Schlimbach
2025-02-12 12:44:48 +01:00
committed by GitHub
parent 0e779ad499
commit 0fd50ec9a3
17 changed files with 525 additions and 89 deletions


@@ -0,0 +1,23 @@
//===- ShardingInterfaceImpl.h --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
namespace mlir {
class DialectRegistry;
namespace arith {
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_


@@ -51,7 +51,7 @@ private:
SmallVector<Value> dynamic_sharded_dims_offsets;
public:
MeshSharding() = default;
MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ public:
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert a ShardOp if there is not already one with the same sharding.
// Use newShardOp if it is not null; otherwise create a new one.
// May insert resharding if required. Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,


@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "llvm::StringRef":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
This operation returns the sharding of the given tensor as a `!mesh.sharding` value.
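
Example (as exercised by the ops round-trip test in this commit):
```mlir
%sharding = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
```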
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
Mesh_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//


@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
: shardingArray(std::move(shardingArray)), mesh(mesh) {
assert(this->mesh);
}
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;


@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);


@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils


@@ -0,0 +1,105 @@
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::mesh;
namespace {
// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
// Scalar constants are always replicated and need no sharding annotation.
struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = 0;
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ndims = type.getRank();
}
return SmallVector<utils::IteratorType>(ndims,
utils::IteratorType::parallel);
}
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
type.getRank(), op->getContext())});
}
return {};
}
// Indicate failure if no result sharding exists.
// Otherwise, mirror the result sharding if it is a tensor constant.
// For scalar constants, return the replication option.
FailureOr<ShardingOption>
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
if (!resultSharding) {
return failure();
}
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ShardingArray axesArray(resultSharding.getSplitAxes().size());
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
return ShardingOption(axesArray, resultSharding.getMeshAttr());
}
return ShardingOption({}, resultSharding.getMeshAttr());
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
// Currently non-splat constants are not supported, and a result sharding is required.
return failure();
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
spmdizationMap.map(op->getResult(0), newOp.getResult());
spmdizationMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
(void)builder.clone(*op, spmdizationMap);
}
return success();
}
};
} // namespace
void mlir::arith::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
});
}
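
For illustration, the effect of this spmdization on a splat constant (shapes
as in the spmdization test added in this commit; sharded along axis 0 of a
4x4 mesh):

// Input:
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// After spmdization the splat is resized to the local shard type:
%cst = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>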


@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
// 0d tensors cannot be sharded and must get replicated
if (inShape.empty()) {
assert(outShape.empty());
return;
}
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
@@ -271,7 +277,8 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) {
OpBuilder &builder,
ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
@@ -279,14 +286,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
// No need for anything if the correct sharding is already set.
if (!newShardOp) {
newShardOp = shardOp;
}
return;
}
auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
auto newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
if (!newShardOp) {
auto shardingOp =
builder.create<ShardingOp>(operandValue.getLoc(), sharding);
newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -297,17 +310,19 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
return;
}
auto newShardOp2 =
builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
/*annotate_for_users*/ true);
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
return;
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
maybeInsertTargetShardingAnnotation(sharding, use, builder);
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}
@@ -316,9 +331,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
{
auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
}
if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
return;
}
Operation *operandOp = operand.getOwner();
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -432,16 +456,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
ArrayRef<int64_t> static_sharded_dims_offsets) {
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(),
static_sharded_dims_offsets),
{});
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +475,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
@@ -579,9 +613,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
// constant values to the respective attribute lists and so minimizes the number
// constant values to the respective attribute lists, minimizing the number
// of values.
class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
// It also removes sharded_dims_offsets and halos if they are effectively "empty".
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
using OpRewritePattern<ShardingOp>::OpRewritePattern;
@@ -593,18 +628,48 @@ public:
op.getDynamicShardedDimsOffsets(), b);
// Fold constant values from the dynamic lists into the static lists.
if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
bool modified = succeeded(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) ||
succeeded(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true));
auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
if (dynamicHalos.empty() && !staticHalos.empty()) {
if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
staticHalos.clear();
modified = true;
}
}
// Remove sharded dims offsets if they are effectively the default values,
// i.e. if they define the same distance between all neighboring shards.
// This requires static-only offsets. The first distance is computed as the
// difference between the first two offsets; the offsets are removed only if
// all consecutive distances equal it.
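// E.g. offsets [0, 1, 2, 3, 4] (constant distance 1) get removed, while
// [0, 2, 3, 4, 22] (varying distances) are kept; values as in the
// canonicalization test for sharded_dims_offsets.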
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
bool all_same = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
all_same = false;
break;
}
}
if (all_same) {
staticOffs.clear();
modified = true;
}
}
if (!modified) {
return failure();
}
auto halos = decomposeMixedValues(mixedHalos);
auto offs = decomposeMixedValues(mixedOffs);
op.setStaticHaloSizes(halos.first);
op.getDynamicHaloSizesMutable().assign(halos.second);
op.setStaticShardedDimsOffsets(offs.first);
op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
op.setStaticHaloSizes(staticHalos);
op.getDynamicHaloSizesMutable().assign(dynamicHalos);
op.setStaticShardedDimsOffsets(staticOffs);
op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
return success();
}
@@ -613,7 +678,7 @@ public:
void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) {
results.add<FoldDynamicLists>(context);
results.add<NormalizeSharding>(context);
}
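
A sketch of what NormalizeSharding folds away (hypothetical input; halo sizes
that are all zero count as effectively empty):

// Before canonicalization:
%s = mesh.sharding @mesh4x4 split_axes = [[0]] halo_sizes = [0, 0] : !mesh.sharding
// After canonicalization:
%s = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding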
//===----------------------------------------------------------------------===//
@@ -707,11 +772,19 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
return !(*this == rhs);
}
MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
*this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
// If splitAxes and partialAxes are empty, use "empty" constructor.
if (splitAxes.empty() && partialAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
*this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
shardingOp.getPartialType().value_or(ReductionKind::Sum),
shardingOp.getStaticHaloSizes(),
shardingOp.getStaticShardedDimsOffsets(),
@@ -727,8 +800,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res;
res.mesh = mesh_;
MeshSharding res(mesh_);
if (split_axes_.empty() && partial_axes_.empty()) {
return res;
}
res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
@@ -771,6 +847,53 @@ void ShardOp::getAsmResultNames(
setNameFn(getResult(), "sharding_annotated");
}
namespace {
// Determine if the given ShardOp is a duplicate of another ShardOp
// on the same value. This can happen if constant values are sharded.
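// E.g. (mirroring the canonicalization tests added in this commit):
//   %a = mesh.shard %cst to %s : tensor<1024x1024xf32>
//   %b = mesh.shard %cst to %s : tensor<1024x1024xf32>
// folds into a single ShardOp; if the two shardings differ, the later op is
// rewired to take the earlier op's result as its input instead.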
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
using OpRewritePattern<ShardOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
// Get the use-list of the value being sharded and check if it has more than
// one use.
Value value = op.getSrc();
if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
return failure();
}
// Iterate through the uses of the value to find a duplicate ShardOp.
for (auto &use : value.getUses()) {
if (use.getOwner() != op.getOperation()) {
auto otherOp = dyn_cast<ShardOp>(use.getOwner());
if (!otherOp || !otherOp->isBeforeInBlock(op)) {
return failure();
}
// Create a MeshSharding object for the current and the other ShardOp.
// If the two are equal, replace the current op with the other op.
MeshSharding currentSharding(op.getSharding());
MeshSharding otherSharding(otherOp.getSharding());
if (currentSharding == otherSharding) {
b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
b.eraseOp(op.getOperation());
} else {
// Use the other op's result as the input for this op.
op.getSrcMutable().assign(otherOp.getResult());
}
return success();
}
}
return failure();
}
};
} // namespace
void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) {
results.add<FoldDuplicateShardOp>(context);
}
//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//


@@ -168,17 +168,12 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
// check operands and results type
for (Type type : op->getOperandTypes())
if (!llvm::isa<RankedTensorType>(type))
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
for (Type type : op->getResultTypes())
if (!llvm::isa<RankedTensorType>(type))
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
// check loop types
SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
if (loopTypes.empty())
return failure();
// check maps
SmallVector<AffineMap> maps = getIndexingMaps();
if (maps.empty())
@@ -286,18 +281,22 @@ mesh::detail::defaultGetShardingOption(Operation *op,
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
anyShardingInResultsOrOperands = true;
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
axes, index)))
return failure();
if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
shardingOption.mesh = shardAttr.getMeshAttr();
} else {
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption,
shardAttr.getMeshAttr(), axes, index)))
return failure();
}
}
// Handle the partial axes: at this stage, the exact loop index/indices
@@ -323,7 +322,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
if (!shardAttr)
continue;
anyShardingInResultsOrOperands = true;
anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
AffineMap map = maps[shardingIt.index()];
unsigned numDims = map.getNumDims();
@@ -448,7 +447,16 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = cast<RankedTensorType>(operandValue.getType());
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
if (!operandType) {
if (operandValue.getType().isIntOrIndexOrFloat())
return MeshSharding();
return failure();
}
// 0d tensors cannot be sharded and must get replicated
if (operandType.getRank() == 0) {
return MeshSharding(shardingOption.mesh);
}
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -579,7 +587,7 @@ static bool
isValueCompatibleWithFullReplicationSharding(Value value,
MeshSharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
return sharding && isFullReplication(sharding);
return isFullReplication(sharding);
}
return !sharding;


@@ -282,11 +282,12 @@ static FailureOr<ShardingOption> selectShardingOption(
// a `mesh.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
return success();
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingOp) {
op->emitOpError() << "sharding interface is not implemented.";
return failure();


@@ -561,7 +561,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding) {
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
return sourceShard;
}
@@ -636,14 +637,6 @@ shardedBlockArgumentTypes(Block &block,
return res;
}
void spmdizeTriviallyShardableOperation(Operation &op,
ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder);
static LogicalResult spmdizeOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
@@ -703,8 +696,9 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
if (!rankedTensor) {
return MeshSharding();
}
assert(result.hasOneUse());
if (!result.hasOneUse()) {
return MeshSharding();
}
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
return MeshSharding(shardOp.getSharding());
@@ -744,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
if (isa<ShardingOp>(op)) {
return success();
}
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
if (!shardOp) {
return op.emitError("expected a shard op as source of get_sharding");
}
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
return success();
}
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
@@ -765,6 +768,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
[](BlockArgument arg) { return arg.getLoc(); });
@@ -796,8 +800,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
// Snapshot the original blocks to not mess up the iteration when adding new
// blocks. Only blocks that contain a ShardOp need to be processed.
SmallVector<Block *> originalBlocks;
llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
[](Block &b) { return &b; });
for (Block &b : op.getBlocks()) {
if (llvm::any_of(b.getOperations(),
[](Operation &op) { return isa<ShardOp>(op); })) {
originalBlocks.push_back(&b);
}
}
for (Block *block : originalBlocks) {
if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
@@ -823,10 +831,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
break;
}
}
assert(returnOp);
op.setType(FunctionType::get(op->getContext(),
op.getFunctionBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
if (returnOp) {
op.setType(FunctionType::get(
op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
}
return success();
}


@@ -22,10 +22,11 @@ using namespace mlir::mesh;
namespace {
// Sharding of tensor.empty
struct EmptyOpShardingInterface
: public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
tensor::EmptyOp> {
// Sharding of tensor.empty/tensor.splat
template <typename OpTy>
struct CreatorOpShardingInterface
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct EmptyOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
return SmallVector<AffineMap>(
op->getNumOperands() + op->getNumResults(),
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct EmptyOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
newOp =
builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
@@ -100,6 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
*ctx);
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
*ctx);
});
}
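
With the interface now attached to tensor.splat as well, spmdization rewrites
a splat to its shard shape. A hypothetical sketch (not from a test in this
commit), assuming split_axes = [[0]] on a 4x4 mesh:

// Input:  %t = tensor.splat %v : tensor<1024x1024xf32>
// Output: %t = tensor.splat %v : tensor<256x1024xf32>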


@@ -0,0 +1,17 @@
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
// RUN: %s | FileCheck %s
mesh.mesh @mesh4x4(shape = 4x4)
// CHECK-LABEL: func @test_spmdize_constant
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 434 : i32
return %sharding_annotated_1 : tensor<1024x1024xf32>
}


@@ -0,0 +1,54 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
mesh.mesh @mesh4x4(shape = 4x4)
// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %res : tensor<1024x1024xf32>
}
// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
return %sharding_annotated_1 : tensor<1024x1024xf32>
}


@@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding {
// CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
return %sharding : !mesh.sharding
}
}
// CHECK-LABEL: func @test_duplicate_shardops
func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
}
// CHECK-LABEL: func @test_duplicate_shardops_diff
func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
%sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
%sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
}


@@ -164,6 +164,14 @@ func.func @mesh_shard_shape() {
return
}
// CHECK-LABEL: func @mesh_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
return %0 : !mesh.sharding
}
// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index


@@ -4,6 +4,20 @@
mesh.mesh @mesh_1d(shape = 2)
// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
// CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
%arg0: tensor<2xf32>
// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
) -> (tensor<2xf32>, !mesh.sharding) {
%ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
%sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
%r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
// CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
}
// CHECK-LABEL: func @full_replication
func.func @full_replication(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>