[MLIR][mesh] Mesh fixes (#124724)
A collection of fixes to the mesh dialect:
- allow constants in sharding propagation/spmdization
- fixes to tensor replication (e.g. 0d tensors)
- improved canonicalization
- fix sharding propagation incorrectly generating too many ShardOps

The new operation `mesh.get_sharding` (Mesh_GetShardingOp) enables exchanging sharding information, e.g. on function boundaries.
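A minimal usage sketch of the new op, adapted from the tests added in this commit (mesh, function, and value names are illustrative): `mesh.get_sharding` recovers the sharding attached to a tensor value so it can be handed across a function boundary.

    mesh.mesh @mesh4x4(shape = 4x4)

    func.func @query_sharding(%arg0: tensor<4x8xf32>) -> !mesh.sharding {
      %s = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
      %annotated = mesh.shard %arg0 to %s : tensor<4x8xf32>
      // Return the sharding of the annotated tensor as a !mesh.sharding value.
      %r = mesh.get_sharding %annotated : tensor<4x8xf32> -> !mesh.sharding
      return %r : !mesh.sharding
    }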
@@ -0,0 +1,23 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

namespace mlir {

class DialectRegistry;

namespace arith {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
@@ -51,7 +51,7 @@ private:
SmallVector<Value> dynamic_sharded_dims_offsets;

public:
MeshSharding() = default;
MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
MeshSharding(Value rhs);
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
@@ -62,7 +62,7 @@ public:
ArrayRef<Value> dynamic_halo_sizes_ = {},
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
@@ -201,10 +201,12 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

@@ -28,7 +28,7 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}

def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let summary = "Description of a device/process mesh.";
let description = [{
The mesh.mesh operation is a symbol operation that identifies a specific
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "llvm::StringRef":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
This operation returns the sharding of the given tensor as a MeshSharding.
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
Mesh_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}

def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
let summary = "Get the shard shape of a given process/device.";
let description = [{
@@ -460,6 +481,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
}];
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//

@@ -36,7 +36,9 @@ struct ShardingOption {
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
: shardingArray(std::move(shardingArray)), mesh(mesh) {
assert(this->mesh);
}
static ShardingOption makeEmpty() {
auto res = ShardingOption();
res.empty = true;

@@ -23,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerShardingInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);

@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
ShardingInterfaceImpl.cpp
UnsignedWhenEquivalent.cpp

ADDITIONAL_HEADER_DIRS
@@ -26,7 +27,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils

mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp (new file, 105 lines)
@@ -0,0 +1,105 @@
//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::mesh;

namespace {

// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
// Scalar constants are always replicated and need no sharding annotation.

struct ConstantShardingInterface
: public ShardingInterface::ExternalModel<ConstantShardingInterface,
ConstantOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = 0;
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ndims = type.getRank();
}
return SmallVector<utils::IteratorType>(ndims,
utils::IteratorType::parallel);
}

SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
type.getRank(), op->getContext())});
}
return {};
}

// Indicate failure if no result sharding exists.
// Otherwise mirror result sharding if it is a tensor constant.
// Otherwise return replication option.
FailureOr<ShardingOption>
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
if (!resultSharding) {
return failure();
}
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
ShardingArray axesArray(resultSharding.getSplitAxes().size());
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
return ShardingOption(axesArray, resultSharding.getMeshAttr());
}
return ShardingOption({}, resultSharding.getMeshAttr());
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
// Currently non-splat constants are not supported.
return failure();
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
spmdizationMap.map(op->getResult(0), newOp.getResult());
spmdizationMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
(void)builder.clone(*op, spmdizationMap);
}
return success();
}
};
} // namespace

void mlir::arith::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {

registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
});
}
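For illustration, the effect of this constant spmdization, following the mesh-spmdize.mlir test added in this commit (value names are illustrative): a splat constant annotated with a sharding is rewritten to a splat of the local shard shape (1024 split over the 4 processes of mesh axis 0 gives 256).

    // Before spmdization:
    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
    %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
    %annotated = mesh.shard %cst to %sharding : tensor<1024x1024xf32>
    // After spmdization, each process holds the resized splat:
    // %cst = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>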
@@ -194,6 +194,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
// 0d tensors cannot be sharded and must get replicated
if (inShape.empty()) {
assert(outShape.empty());
return;
}

std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));

@@ -271,7 +277,8 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) {
OpBuilder &builder,
ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
@@ -279,14 +286,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
// No need for anything if the correct sharding is already set.
if (!newShardOp) {
newShardOp = shardOp;
}
return;
}

auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
auto newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
if (!newShardOp) {
auto shardingOp =
builder.create<ShardingOp>(operandValue.getLoc(), sharding);
newShardOp =
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
@@ -297,17 +310,19 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
return;
}

auto newShardOp2 =
builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
/*annotate_for_users*/ true);
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
return;
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
maybeInsertTargetShardingAnnotation(sharding, use, builder);
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}

@@ -316,9 +331,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
Operation *operandSrcOp = operandValue.getDefiningOp();
bool isBlockArg = !operandSrcOp;
{
auto opType = dyn_cast<mlir::RankedTensorType>(operandValue.getType());
assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
}
if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
return;
}

Operation *operandOp = operand.getOwner();
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

if (shardOp && sharding == shardOp.getSharding() &&

@@ -432,16 +456,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
ArrayRef<int64_t> static_sharded_dims_offsets) {
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(),
static_sharded_dims_offsets),
{});
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -453,6 +475,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
{}, {}, {}, {});
}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}

void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,

@@ -579,9 +613,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
namespace {
// Sharding annotations "halo sizes" and "sharded dims offsets"
// are a mix of attributes and dynamic values. This canonicalization moves
// constant values to the respective attribute lists and so minimizes the number
// constant values to the respective attribute lists, minimizing the number
// of values.
class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
// It also removes sharded_dims_sizes and halos if they are effectively "empty".
class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
public:
using OpRewritePattern<ShardingOp>::OpRewritePattern;

@@ -593,18 +628,48 @@ public:
op.getDynamicShardedDimsOffsets(), b);

// No constant operands were folded, just return;
if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
succeeded(foldDynamicIndexList(mixedOffs, true));

auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);

if (dynamicHalos.empty() && !staticHalos.empty()) {
if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
staticHalos.clear();
modified = true;
}
}

// Remove sharded dims offsets if they are effectively the default values,
// e.g. if they define equi-distance between all neighboring shards.
// Requires static-only offsets. Compares the first distance as the
// difference between the first two offsets. Only if all consecutive
// distances are the same, the offsets are removed.
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
bool all_same = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
all_same = false;
break;
}
}
if (all_same) {
staticOffs.clear();
modified = true;
}
}

if (!modified) {
return failure();
}

auto halos = decomposeMixedValues(mixedHalos);
auto offs = decomposeMixedValues(mixedOffs);

op.setStaticHaloSizes(halos.first);
op.getDynamicHaloSizesMutable().assign(halos.second);
op.setStaticShardedDimsOffsets(offs.first);
op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
op.setStaticHaloSizes(staticHalos);
op.getDynamicHaloSizesMutable().assign(dynamicHalos);
op.setStaticShardedDimsOffsets(staticOffs);
op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);

return success();
}
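To illustrate the NormalizeSharding pattern above, adapted from the canonicalization test updated in this commit: constant dynamic values are folded into the static sharded_dims_offsets attribute.

    // Before canonicalization: a constant feeds the dynamic offsets list.
    %c2_i64 = arith.constant 2 : i64
    %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
    // After canonicalization, the constants are moved into the attribute list:
    // mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding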

@@ -613,7 +678,7 @@ public:

void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) {
results.add<FoldDynamicLists>(context);
results.add<NormalizeSharding>(context);
}

//===----------------------------------------------------------------------===//
@@ -707,11 +772,19 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
return !(*this == rhs);
}

MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}

MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
*this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
// If splitAxes and partialAxes are empty, use "empty" constructor.
if (splitAxes.empty() && partialAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
*this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
shardingOp.getPartialType().value_or(ReductionKind::Sum),
shardingOp.getStaticHaloSizes(),
shardingOp.getStaticShardedDimsOffsets(),
@@ -727,8 +800,11 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res;
res.mesh = mesh_;
MeshSharding res(mesh_);
if (split_axes_.empty() && partial_axes_.empty()) {
return res;
}

res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
@@ -771,6 +847,53 @@ void ShardOp::getAsmResultNames(
setNameFn(getResult(), "sharding_annotated");
}

namespace {
// Determine if the given ShardOp is a duplicate of another ShardOp
// on the same value. This can happen if constant values are sharded.
class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
public:
using OpRewritePattern<ShardOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
// Get the use-list of the value being sharded and check if it has more than
// one use.
Value value = op.getSrc();
if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
return failure();
}

// Iterate through the uses of the value to find a duplicate ShardOp.
for (auto &use : value.getUses()) {
if (use.getOwner() != op.getOperation()) {
auto otherOp = dyn_cast<ShardOp>(use.getOwner());
if (!otherOp || !otherOp->isBeforeInBlock(op)) {
return failure();
}
// Create a MeshSharding object for the current and the other ShardOp
// If the two are equal replace current op with the other op.
MeshSharding currentSharding(op.getSharding());
MeshSharding otherSharding(otherOp.getSharding());
if (currentSharding == otherSharding) {
b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
b.eraseOp(op.getOperation());
} else {
// use the other sharding as input for op
op.getSrcMutable().assign(otherOp.getResult());
}
return success();
}
}

return failure();
}
};
} // namespace

void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
mlir::MLIRContext *context) {
results.add<FoldDuplicateShardOp>(context);
}
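To illustrate FoldDuplicateShardOp, adapted from the test_duplicate_shardops test added in this commit (value names are illustrative): two mesh.shard ops that annotate the same value with equal shardings collapse into one.

    // Before canonicalization: the same constant is annotated twice with an
    // identical sharding.
    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
    %s1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
    %a = mesh.shard %cst to %s1 : tensor<1024x1024xf32>
    %s2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
    %b = mesh.shard %cst to %s2 : tensor<1024x1024xf32>
    // After canonicalization, all uses of %b are replaced by the single
    // remaining mesh.shard op %a.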

//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//

@@ -168,17 +168,12 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {

// check operands and results type
for (Type type : op->getOperandTypes())
if (!llvm::isa<RankedTensorType>(type))
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();
for (Type type : op->getResultTypes())
if (!llvm::isa<RankedTensorType>(type))
if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
return failure();

// check loop types
SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
if (loopTypes.empty())
return failure();

// check maps
SmallVector<AffineMap> maps = getIndexingMaps();
if (maps.empty())
@@ -286,18 +281,22 @@ mesh::detail::defaultGetShardingOption(Operation *op,
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
anyShardingInResultsOrOperands = true;
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
axes, index)))
return failure();
if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
shardingOption.mesh = shardAttr.getMeshAttr();
} else {
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption,
shardAttr.getMeshAttr(), axes, index)))
return failure();
}
}

// Handle the partial axes: at this stage, the exact loop index/indices
@@ -323,7 +322,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
if (!shardAttr)
continue;

anyShardingInResultsOrOperands = true;
anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
AffineMap map = maps[shardingIt.index()];
unsigned numDims = map.getNumDims();

@@ -448,7 +447,16 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = cast<RankedTensorType>(operandValue.getType());
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
if (!operandType) {
if (operandValue.getType().isIntOrIndexOrFloat())
return MeshSharding();
return failure();
}
// 0d tensors cannot be sharded and must get replicated
if (operandType.getRank() == 0) {
return MeshSharding(shardingOption.mesh);
}
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -579,7 +587,7 @@ static bool
isValueCompatibleWithFullReplicationSharding(Value value,
MeshSharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
return sharding && isFullReplication(sharding);
return isFullReplication(sharding);
}

return !sharding;

@@ -282,11 +282,12 @@ static FailureOr<ShardingOption> selectShardingOption(
// a `mesh.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
return success();

ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingOp) {
op->emitOpError() << "sharding interface is not implemented.";
return failure();

@@ -561,7 +561,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding) {
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
return sourceShard;
}

@@ -636,14 +637,6 @@ shardedBlockArgumentTypes(Block &block,
return res;
}

void spmdizeTriviallyShardableOperation(Operation &op,
ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder);

static LogicalResult spmdizeOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
@@ -703,8 +696,9 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
if (!rankedTensor) {
return MeshSharding();
}

assert(result.hasOneUse());
if (!result.hasOneUse()) {
return MeshSharding();
}
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
return MeshSharding(shardOp.getSharding());
@@ -744,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
if (isa<ShardingOp>(op)) {
return success();
}
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
if (!shardOp) {
return op.emitError("expected a shard op as source of get_sharding");
}
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
return success();
}

ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
@@ -765,6 +768,7 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {

SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
[](BlockArgument arg) { return arg.getLoc(); });
@@ -796,8 +800,12 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
// Snapshot the original blocks to not mess up the iteration when adding new
// blocks.
SmallVector<Block *> originalBlocks;
llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
[](Block &b) { return &b; });
for (Block &b : op.getBlocks()) {
if (llvm::any_of(b.getOperations(),
[](Operation &op) { return isa<ShardOp>(op); })) {
originalBlocks.push_back(&b);
}
}

for (Block *block : originalBlocks) {
if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
@@ -823,10 +831,11 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
break;
}
}
assert(returnOp);
op.setType(FunctionType::get(op->getContext(),
op.getFunctionBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
if (returnOp) {
op.setType(FunctionType::get(
op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
}

return success();
}

@@ -22,10 +22,11 @@ using namespace mlir::mesh;

namespace {

// Sharding of tensor.empty
struct EmptyOpShardingInterface
: public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
tensor::EmptyOp> {
// Sharding of tensor.empty/tensor.splat
template <typename OpTy>
struct CreatorOpShardingInterface
: public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct EmptyOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
return SmallVector<AffineMap>(
op->getNumOperands() + op->getNumResults(),
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct EmptyOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
newOp =
builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
@@ -100,6 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {

registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
*ctx);
SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
*ctx);
});
}

mlir/test/Dialect/Arith/mesh-spmdize.mlir (new file, 17 lines)
@@ -0,0 +1,17 @@
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
// RUN: %s | FileCheck %s

mesh.mesh @mesh4x4(shape = 4x4)

// CHECK-LABEL: func @test_spmdize_constant
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 434 : i32
return %sharding_annotated_1 : tensor<1024x1024xf32>
}
mlir/test/Dialect/Arith/sharding-propagation.mlir (new file, 54 lines)
@@ -0,0 +1,54 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s

mesh.mesh @mesh4x4(shape = 4x4)

// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %res : tensor<1024x1024xf32>
}

// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
return %sharding_annotated_1 : tensor<1024x1024xf32>
}
@@ -207,4 +207,42 @@ func.func @test_shard_offs() -> !mesh.sharding {
// CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
return %sharding : !mesh.sharding
}
}

// CHECK-LABEL: func @test_duplicate_shardops
func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
%sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
}

// CHECK-LABEL: func @test_duplicate_shardops_diff
func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
%sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
%sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
%sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
%sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
// CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
%sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
}

@@ -164,6 +164,14 @@ func.func @mesh_shard_shape() {
return
}

// CHECK-LABEL: func @mesh_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
return %0 : !mesh.sharding
}

// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index

@@ -4,6 +4,20 @@

mesh.mesh @mesh_1d(shape = 2)

// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
// CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
%arg0: tensor<2xf32>
// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
) -> (tensor<2xf32>, !mesh.sharding) {
%ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
%sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
%r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
// CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
}

// CHECK-LABEL: func @full_replication
func.func @full_replication(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>