Lower allreduce (#144716)
Adding lowering mesh.allreduce to mpi.allreduce. Minor restructuring to increase code reuse.
This commit is contained in:
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
|
||||
shard/partition sizes depend on the rank.
|
||||
}];
|
||||
let dependentDialects = [
|
||||
"affine::AffineDialect",
|
||||
"arith::ArithDialect",
|
||||
"memref::MemRefDialect",
|
||||
"mpi::MPIDialect",
|
||||
"scf::SCFDialect",
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MPIDialect
|
||||
|
||||
@@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
|
||||
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
|
||||
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
|
||||
|
||||
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
|
||||
def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [
|
||||
MPI_OpNull,
|
||||
MPI_OpMax,
|
||||
MPI_OpMin,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
include "mlir/Dialect/MPI/IR/MPI.td"
|
||||
include "mlir/Dialect/MPI/IR/MPITypes.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
class MPI_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<MPI_Dialect, mnemonic, traits>;
|
||||
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
|
||||
// CommWorldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
|
||||
def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
|
||||
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
|
||||
let description = [{
|
||||
This operation returns the predefined MPI_COMM_WORLD communicator.
|
||||
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
|
||||
// CommRankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
|
||||
def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
|
||||
let summary = "Get the current rank, equivalent to "
|
||||
"`MPI_Comm_rank(comm, &rank)`";
|
||||
let description = [{
|
||||
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
|
||||
);
|
||||
|
||||
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CommSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
|
||||
def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
|
||||
let summary = "Get the size of the group associated to the communicator, "
|
||||
"equivalent to `MPI_Comm_size(comm, &size)`";
|
||||
let description = [{
|
||||
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
|
||||
// CommSplitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
|
||||
def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
|
||||
let summary = "Partition the group associated with the given communicator into "
|
||||
"disjoint subgroups";
|
||||
let description = [{
|
||||
@@ -281,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
|
||||
let arguments = (
|
||||
ins AnyMemRef : $sendbuf,
|
||||
AnyMemRef : $recvbuf,
|
||||
MPI_OpClassEnum : $op,
|
||||
MPI_ReductionOpEnum : $op,
|
||||
MPI_Comm : $comm
|
||||
);
|
||||
|
||||
|
||||
@@ -212,6 +212,11 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
|
||||
OpOperand &operand,
|
||||
OpBuilder &builder);
|
||||
|
||||
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
|
||||
/// provided type.
|
||||
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
|
||||
llvm::ArrayRef<int64_t> statics,
|
||||
ValueRange dynamics, Type type = Type());
|
||||
} // namespace mesh
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyRankedTensor:$input,
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
|
||||
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
|
||||
|
||||
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
|
||||
auto isEndomorphismOp = [reduction](Operation *op,
|
||||
std::optional<Operation *> referenceOp) {
|
||||
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
|
||||
if (!allReduceOp ||
|
||||
allReduceOp.getInput().getType().getElementType() !=
|
||||
allReduceOp.getResult().getType().getElementType() ||
|
||||
auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
|
||||
auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
|
||||
if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
|
||||
allReduceOp.getReduction() != reduction) {
|
||||
return false;
|
||||
}
|
||||
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
|
||||
}
|
||||
|
||||
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
|
||||
auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
|
||||
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
|
||||
allReduceOp.getInput().getType().getElementType() ==
|
||||
refAllReduceOp.getInput().getType().getElementType();
|
||||
inType.getElementType() == refType.getElementType();
|
||||
};
|
||||
auto isAlgebraicOp = [](Operation *op) {
|
||||
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
|
||||
|
||||
@@ -42,6 +42,11 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
|
||||
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
|
||||
ArrayRef<MeshAxis> meshAxes,
|
||||
ImplicitLocOpBuilder &builder);
|
||||
// Get process linear index from a multi-index along the given mesh axes .
|
||||
TypedValue<IndexType>
|
||||
createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
|
||||
ArrayRef<MeshAxis> meshAxes,
|
||||
ImplicitLocOpBuilder &builder);
|
||||
|
||||
} // namespace mesh
|
||||
} // namespace mlir
|
||||
|
||||
@@ -116,7 +116,7 @@ public:
|
||||
/// enum value.
|
||||
virtual Value getMPIOp(const Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
mpi::MPI_OpClassEnum opAttr) = 0;
|
||||
mpi::MPI_ReductionOpEnum opAttr) = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -199,49 +199,49 @@ public:
|
||||
}
|
||||
|
||||
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
|
||||
mpi::MPI_OpClassEnum opAttr) override {
|
||||
mpi::MPI_ReductionOpEnum opAttr) override {
|
||||
int32_t op = MPI_NO_OP;
|
||||
switch (opAttr) {
|
||||
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
|
||||
op = MPI_NO_OP;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MAX:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MAX:
|
||||
op = MPI_MAX;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MIN:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MIN:
|
||||
op = MPI_MIN;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_SUM:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_SUM:
|
||||
op = MPI_SUM;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_PROD:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_PROD:
|
||||
op = MPI_PROD;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LAND:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LAND:
|
||||
op = MPI_LAND;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BAND:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BAND:
|
||||
op = MPI_BAND;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LOR:
|
||||
op = MPI_LOR;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BOR:
|
||||
op = MPI_BOR;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LXOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
|
||||
op = MPI_LXOR;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BXOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
|
||||
op = MPI_BXOR;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MINLOC:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
|
||||
op = MPI_MINLOC;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
|
||||
op = MPI_MAXLOC;
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_REPLACE:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
|
||||
op = MPI_REPLACE;
|
||||
break;
|
||||
}
|
||||
@@ -336,49 +336,49 @@ public:
|
||||
}
|
||||
|
||||
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
|
||||
mpi::MPI_OpClassEnum opAttr) override {
|
||||
mpi::MPI_ReductionOpEnum opAttr) override {
|
||||
StringRef op;
|
||||
switch (opAttr) {
|
||||
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
|
||||
op = "ompi_mpi_no_op";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MAX:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MAX:
|
||||
op = "ompi_mpi_max";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MIN:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MIN:
|
||||
op = "ompi_mpi_min";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_SUM:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_SUM:
|
||||
op = "ompi_mpi_sum";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_PROD:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_PROD:
|
||||
op = "ompi_mpi_prod";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LAND:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LAND:
|
||||
op = "ompi_mpi_land";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BAND:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BAND:
|
||||
op = "ompi_mpi_band";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LOR:
|
||||
op = "ompi_mpi_lor";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BOR:
|
||||
op = "ompi_mpi_bor";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_LXOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
|
||||
op = "ompi_mpi_lxor";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_BXOR:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
|
||||
op = "ompi_mpi_bxor";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MINLOC:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
|
||||
op = "ompi_mpi_minloc";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
|
||||
op = "ompi_mpi_maxloc";
|
||||
break;
|
||||
case mpi::MPI_OpClassEnum::MPI_REPLACE:
|
||||
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
|
||||
op = "ompi_mpi_replace";
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@
|
||||
|
||||
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
@@ -22,6 +22,8 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
||||
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
|
||||
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
|
||||
|
||||
class ConvertProcessLinearIndexOp
|
||||
: public OpConversionPattern<ProcessLinearIndexOp> {
|
||||
int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
|
||||
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
// Constructor accepting worldRank
|
||||
ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
|
||||
MLIRContext *context, int64_t worldRank = -1)
|
||||
: OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
// Create mpi::CommRankOp
|
||||
Location loc = op.getLoc();
|
||||
if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise call create mpi::CommRankOp
|
||||
auto ctx = op.getContext();
|
||||
Value commWorld =
|
||||
rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
|
||||
@@ -529,6 +519,129 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
|
||||
}
|
||||
};
|
||||
|
||||
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
|
||||
auto ctx = kind.getContext();
|
||||
auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
|
||||
return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
|
||||
};
|
||||
|
||||
switch (kind.getValue()) {
|
||||
case ReductionKind::Sum:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
|
||||
case ReductionKind::Product:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
|
||||
case ReductionKind::Min:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
|
||||
case ReductionKind::Max:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
|
||||
case ReductionKind::BitwiseAnd:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
|
||||
case ReductionKind::BitwiseOr:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
|
||||
case ReductionKind::BitwiseXor:
|
||||
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
|
||||
default:
|
||||
assert(false && "Unknown/unsupported reduction kind");
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
auto mesh = adaptor.getMesh();
|
||||
mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
|
||||
if (!meshOp)
|
||||
return op->emitError() << "No mesh found for AllReduceOp";
|
||||
if (ShapedType::isDynamicShape(meshOp.getShape()))
|
||||
return op->emitError()
|
||||
<< "Dynamic mesh shape not supported in AllReduceOp";
|
||||
|
||||
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
|
||||
Value input = adaptor.getInput();
|
||||
auto inputShape = cast<ShapedType>(input.getType()).getShape();
|
||||
|
||||
// If the source is a memref, cast it to a tensor.
|
||||
if (isa<RankedTensorType>(input.getType())) {
|
||||
auto memrefType = MemRefType::get(
|
||||
inputShape, cast<ShapedType>(input.getType()).getElementType());
|
||||
input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
|
||||
}
|
||||
MemRefType inType = cast<MemRefType>(input.getType());
|
||||
|
||||
// Get the actual shape to allocate the buffer.
|
||||
SmallVector<OpFoldResult> shape(inType.getRank());
|
||||
for (auto i = 0; i < inType.getRank(); ++i) {
|
||||
auto s = inputShape[i];
|
||||
if (ShapedType::isDynamic(s))
|
||||
shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult();
|
||||
else
|
||||
shape[i] = iBuilder.getIndexAttr(s);
|
||||
}
|
||||
|
||||
// Allocate buffer and copy input to buffer.
|
||||
Value buffer = iBuilder.create<memref::AllocOp>(
|
||||
shape, cast<ShapedType>(op.getType()).getElementType());
|
||||
iBuilder.create<linalg::CopyOp>(input, buffer);
|
||||
|
||||
// Get an MPI_Comm_split for the AllReduce operation.
|
||||
// The color is the linear index of the process in the mesh along the
|
||||
// non-reduced axes. The key is the linear index of the process in the mesh
|
||||
// along the reduced axes.
|
||||
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
|
||||
iBuilder.getIndexType());
|
||||
SmallVector<Value> myMultiIndex =
|
||||
iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
|
||||
.getResult();
|
||||
Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
|
||||
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
|
||||
|
||||
auto redAxes = adaptor.getMeshAxes();
|
||||
for (auto axis : redAxes) {
|
||||
multiKey[axis] = myMultiIndex[axis];
|
||||
myMultiIndex[axis] = zero;
|
||||
}
|
||||
|
||||
Value color =
|
||||
createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
|
||||
color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
|
||||
Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
|
||||
key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
|
||||
|
||||
// Finally split the communicator
|
||||
auto commType = mpi::CommType::get(op->getContext());
|
||||
Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
|
||||
auto comm =
|
||||
iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
|
||||
.getNewcomm();
|
||||
|
||||
Value buffer1d = buffer;
|
||||
// Collapse shape to 1d if needed
|
||||
if (inType.getRank() > 1) {
|
||||
ReassociationIndices reassociation(inType.getRank());
|
||||
std::iota(reassociation.begin(), reassociation.end(), 0);
|
||||
buffer1d = iBuilder.create<memref::CollapseShapeOp>(
|
||||
buffer, ArrayRef<ReassociationIndices>(reassociation));
|
||||
}
|
||||
|
||||
// Create the MPI AllReduce operation.
|
||||
iBuilder.create<mpi::AllReduceOp>(
|
||||
TypeRange(), buffer1d, buffer1d,
|
||||
getMPIReductionOp(adaptor.getReductionAttr()), comm);
|
||||
|
||||
// If the destination is a memref, cast it to a tensor
|
||||
if (isa<RankedTensorType>(op.getType()))
|
||||
buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
|
||||
true);
|
||||
|
||||
rewriter.replaceOp(op, buffer);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -573,10 +686,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
|
||||
Value array = dest;
|
||||
if (isa<RankedTensorType>(array.getType())) {
|
||||
// If the destination is a memref, we need to cast it to a tensor
|
||||
auto tensorType = MemRefType::get(
|
||||
auto mmemrefType = MemRefType::get(
|
||||
dstShape, cast<ShapedType>(array.getType()).getElementType());
|
||||
array =
|
||||
rewriter.create<bufferization::ToBufferOp>(loc, tensorType, array);
|
||||
rewriter.create<bufferization::ToBufferOp>(loc, mmemrefType, array);
|
||||
}
|
||||
auto rank = cast<ShapedType>(array.getType()).getRank();
|
||||
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
|
||||
@@ -753,22 +866,6 @@ struct ConvertMeshToMPIPass
|
||||
|
||||
/// Run the dialect converter on the module.
|
||||
void runOnOperation() override {
|
||||
uint64_t worldRank = -1;
|
||||
// Try to get DLTI attribute for MPI:comm_world_rank
|
||||
// If found, set worldRank to the value of the attribute.
|
||||
{
|
||||
auto dltiAttr =
|
||||
dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
|
||||
if (succeeded(dltiAttr)) {
|
||||
if (!isa<IntegerAttr>(dltiAttr.value())) {
|
||||
getOperation()->emitError()
|
||||
<< "Expected an integer attribute for MPI:comm_world_rank";
|
||||
return signalPassFailure();
|
||||
}
|
||||
worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
|
||||
}
|
||||
}
|
||||
|
||||
auto *ctxt = &getContext();
|
||||
RewritePatternSet patterns(ctxt);
|
||||
ConversionTarget target(getContext());
|
||||
@@ -816,13 +913,13 @@ struct ConvertMeshToMPIPass
|
||||
|
||||
// No mesh dialect should left after conversion...
|
||||
target.addIllegalDialect<mesh::MeshDialect>();
|
||||
// ...except the global MeshOp
|
||||
target.addLegalOp<mesh::MeshOp>();
|
||||
// ...except the global MeshOp. MeshShapeOp which will get folded later.
|
||||
target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
|
||||
// Allow all the stuff that our patterns will convert to
|
||||
target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
|
||||
arith::ArithDialect, tensor::TensorDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
linalg::LinalgDialect, memref::MemRefDialect>();
|
||||
target.addLegalDialect<
|
||||
BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
|
||||
tensor::TensorDialect, bufferization::BufferizationDialect,
|
||||
linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
|
||||
// Make sure the function signature, calls etc. are legal
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getFunctionType());
|
||||
@@ -832,9 +929,8 @@ struct ConvertMeshToMPIPass
|
||||
|
||||
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
|
||||
ConvertProcessMultiIndexOp, ConvertGetShardingOp,
|
||||
ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
|
||||
// ConvertProcessLinearIndexOp accepts an optional worldRank
|
||||
patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
|
||||
ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
|
||||
ConvertProcessLinearIndexOp>(typeConverter, ctxt);
|
||||
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
@@ -842,6 +938,12 @@ struct ConvertMeshToMPIPass
|
||||
populateReturnOpTypeConversionPattern(patterns, typeConverter);
|
||||
|
||||
(void)applyPartialConversion(getOperation(), target, std::move(patterns));
|
||||
|
||||
// Folding patterns cannot be mixed with conversion patterns -> extra pass.
|
||||
patterns.clear();
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/MPI/IR/MPI.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
@@ -41,6 +42,34 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
|
||||
using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
|
||||
mlir::PatternRewriter &b) const override {
|
||||
auto comm = op.getComm();
|
||||
if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
|
||||
return mlir::failure();
|
||||
|
||||
// Try to get DLTI attribute for MPI:comm_world_rank
|
||||
// If found, set worldRank to the value of the attribute.
|
||||
auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
|
||||
if (failed(dltiAttr))
|
||||
return mlir::failure();
|
||||
if (!isa<IntegerAttr>(dltiAttr.value()))
|
||||
return op->emitError()
|
||||
<< "Expected an integer attribute for MPI:comm_world_rank";
|
||||
Value res = b.create<arith::ConstantIndexOp>(
|
||||
op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
|
||||
if (Value retVal = op.getRetval())
|
||||
b.replaceOp(op, {retVal, res});
|
||||
else
|
||||
b.replaceOp(op, res);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::mpi::SendOp::getCanonicalizationPatterns(
|
||||
@@ -63,6 +92,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
|
||||
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
|
||||
}
|
||||
|
||||
void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
||||
results.add<FoldRank>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -75,6 +75,29 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
|
||||
return lhs.value() * rhs.value();
|
||||
}
|
||||
|
||||
SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
|
||||
const Location &loc,
|
||||
llvm::ArrayRef<int64_t> statics,
|
||||
ValueRange dynamics,
|
||||
Type type) {
|
||||
SmallVector<Value> values;
|
||||
auto dyn = dynamics.begin();
|
||||
Type i64 = b.getI64Type();
|
||||
if (!type)
|
||||
type = i64;
|
||||
assert((i64 == type || b.getIndexType() == type) &&
|
||||
"expected an i64 or an intex type");
|
||||
for (auto s : statics) {
|
||||
if (s == ShapedType::kDynamic) {
|
||||
values.emplace_back(*(dyn++));
|
||||
} else {
|
||||
TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
|
||||
values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
|
||||
}
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Inliner
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -207,17 +207,27 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
|
||||
builder.getIndexType()));
|
||||
}
|
||||
|
||||
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
|
||||
ArrayRef<MeshAxis> meshAxes,
|
||||
ImplicitLocOpBuilder &builder) {
|
||||
ResultRange processInGroupMultiIndex =
|
||||
builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
|
||||
TypedValue<IndexType>
|
||||
createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
|
||||
ArrayRef<MeshAxis> meshAxes,
|
||||
ImplicitLocOpBuilder &builder) {
|
||||
Operation::result_range processGroupShape =
|
||||
builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
|
||||
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
|
||||
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
|
||||
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
|
||||
return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
|
||||
auto res = dyn_cast<Value>(processInGroupLinearIndex);
|
||||
if (!res)
|
||||
res = builder.create<arith::ConstantIndexOp>(
|
||||
cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
|
||||
return cast<TypedValue<IndexType>>(res);
|
||||
}
|
||||
|
||||
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
|
||||
ArrayRef<MeshAxis> meshAxes,
|
||||
ImplicitLocOpBuilder &builder) {
|
||||
return createProcessLinearIndex(
|
||||
mesh, builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(),
|
||||
meshAxes, builder);
|
||||
}
|
||||
} // namespace mlir::mesh
|
||||
|
||||
@@ -80,6 +80,63 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
|
||||
mesh.mesh @mesh0(shape = 3x4x5)
|
||||
// CHECK-LABEL: func.func @allreduce_tensor(
|
||||
func.func @allreduce_tensor(
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
|
||||
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
|
||||
// CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
|
||||
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
|
||||
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
|
||||
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
|
||||
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
|
||||
// CHECK: return [[v2]] : tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @allreduce_memref(
|
||||
func.func @allreduce_memref(
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
|
||||
%arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
|
||||
// CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
|
||||
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
|
||||
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
|
||||
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
|
||||
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
|
||||
// CHECK: return [[valloc]] : memref<3x4xf32>
|
||||
return %0 : memref<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @allreduce_new_type(
|
||||
func.func @allreduce_new_type(
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
|
||||
%arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
|
||||
// CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
|
||||
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
|
||||
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
|
||||
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
|
||||
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
|
||||
// CHECK: return [[valloc]] : memref<3x4xf64>
|
||||
return %0 : memref<3x4xf64>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
mesh.mesh @mesh0(shape = 3x4x5)
|
||||
// CHECK-LABEL: func @update_halo_1d_first
|
||||
@@ -91,13 +148,13 @@ func.func @update_halo_1d_first(
|
||||
// CHECK-SAME: : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK: mpi.recv(
|
||||
// CHECK-SAME: : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
|
||||
// CHECK: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
|
||||
// CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
|
||||
// CHECK: mpi.send(
|
||||
// CHECK-SAME: : memref<3x120x120xi8>, i32, i32
|
||||
// CHECK: mpi.recv(
|
||||
// CHECK-SAME: : memref<3x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
|
||||
// CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
|
||||
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
|
||||
// CHECK: return [[res:%.*]] : memref<120x120x120xi8>
|
||||
return %res : memref<120x120x120xi8>
|
||||
@@ -110,18 +167,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
|
||||
func.func @update_halo_1d_with_zero (
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
|
||||
%arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
|
||||
// CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
|
||||
// CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
|
||||
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
|
||||
// CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK-DAG: [[vc0_i32:%.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
|
||||
// CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
|
||||
// CHECK: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
|
||||
// CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8>
|
||||
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
|
||||
// CHECK: return [[varg0]] : memref<120x120x120xi8>
|
||||
return %res : memref<120x120x120xi8>
|
||||
@@ -135,50 +192,50 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
|
||||
func.func @update_halo_3d(
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
|
||||
%arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
|
||||
// CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
|
||||
// CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
|
||||
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
|
||||
// CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
|
||||
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
|
||||
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
|
||||
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
|
||||
// CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
|
||||
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
|
||||
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
|
||||
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
|
||||
// CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
|
||||
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
|
||||
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
|
||||
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
|
||||
// CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32
|
||||
// CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32
|
||||
// CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
|
||||
// CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
|
||||
// CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
|
||||
// CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8>
|
||||
// CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
|
||||
// CHECK: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
|
||||
// CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
|
||||
// CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
|
||||
// CHECK: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
|
||||
// CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
|
||||
// CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
|
||||
// CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
|
||||
// CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
|
||||
// CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
|
||||
// CHECK: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
|
||||
// CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
|
||||
// CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
|
||||
// CHECK: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
|
||||
// CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
|
||||
// CHECK: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
|
||||
// CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
|
||||
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
|
||||
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
|
||||
// CHECK: return [[varg0]] : memref<120x120x120xi8>
|
||||
return %res : memref<120x120x120xi8>
|
||||
@@ -188,54 +245,54 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
|
||||
func.func @update_halo_3d_tensor(
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
|
||||
%arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
|
||||
// CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
|
||||
// CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
|
||||
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
|
||||
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
|
||||
// CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
|
||||
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
|
||||
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
|
||||
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
|
||||
// CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
|
||||
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
|
||||
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
|
||||
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
|
||||
// CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
|
||||
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
|
||||
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
|
||||
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
|
||||
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
|
||||
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
|
||||
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
|
||||
// CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
|
||||
// CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32
|
||||
// CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32
|
||||
// CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32
|
||||
// CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
|
||||
// CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
|
||||
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
|
||||
// CHECK: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
|
||||
// CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
|
||||
// CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
|
||||
// CHECK: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
|
||||
// CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
// CHECK: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
// CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
// CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
// CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
// CHECK: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
// CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
// CHECK: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
// CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
// CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
// CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
// CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
// CHECK: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
// CHECK: [[v3:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
// CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
// CHECK: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
// CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
// CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
// CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8>
// CHECK: return [[v4]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
}
@@ -246,19 +303,19 @@ mesh.mesh @mesh0(shape = 2x2x4)
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
%sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
// CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
}

@@ -266,19 +323,19 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
%sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
// CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
// CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
// CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
}

@@ -286,24 +343,24 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
%sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
// CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
// CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
// CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
// CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
// CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
// CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
// CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
// CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
// CHECK: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
// CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
// CHECK: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
// CHECK: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
// CHECK: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
// CHECK: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
// CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
// CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
}