[mlir][mpi] Mandatory Communicator (#133280)

This replaces #125361
- communicator is mandatory
- new mpi.comm_world
- new mpi.comm_split
- lowering and tests

---------

Co-authored-by: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es>
Author: Frank Schlimbach
Date: 2025-04-01 08:58:55 +02:00 (committed by GitHub)
Parent: aa889ed129
Commit: 49f080afc4
7 changed files with 388 additions and 174 deletions


@@ -37,26 +37,41 @@ def MPI_InitOp : MPI_Op<"init", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}
//===----------------------------------------------------------------------===//
// CommWorldOp
//===----------------------------------------------------------------------===//
def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
let description = [{
This operation returns the predefined MPI_COMM_WORLD communicator.
}];
let results = (outs MPI_Comm : $comm);
let assemblyFormat = "attr-dict `:` type(results)";
}
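Round-tripping the new op in textual IR looks like this (SSA names are illustrative; see also the ops test at the end of this diff):

  %comm = mpi.comm_world : !mpi.comm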
//===----------------------------------------------------------------------===//
// CommRankOp
//===----------------------------------------------------------------------===//
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
let arguments = (ins MPI_Comm : $comm);
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);
let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
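With the communicator now mandatory, usage reads as follows; the !mpi.retval result stays optional (names illustrative):

  %comm = mpi.comm_world : !mpi.comm
  %rank = mpi.comm_rank(%comm) : i32
  %retval, %rank2 = mpi.comm_rank(%comm) : !mpi.retval, i32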
//===----------------------------------------------------------------------===//
@@ -65,20 +80,48 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
let arguments = (ins MPI_Comm : $comm);
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);
let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
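Analogous to comm_rank, usage after this change (names illustrative):

  %size = mpi.comm_size(%comm) : i32
  %retval, %size2 = mpi.comm_size(%comm) : !mpi.retval, i32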
//===----------------------------------------------------------------------===//
// CommSplitOp
//===----------------------------------------------------------------------===//
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
This operation splits the communicator into multiple sub-communicators.
The color value determines the group of processes that will be part of the
new communicator. The key value determines the rank of the calling process
in the new communicator.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
let results = (
outs Optional<MPI_Retval> : $retval,
MPI_Comm : $newcomm
);
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
"type(results)";
}
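Example usage of the new op, assuming %comm is an !mpi.comm and %color/%key are i32 values:

  %color = arith.constant 10 : i32
  %key = arith.constant 22 : i32
  %newcomm = mpi.comm_split(%comm, %color, %key) : !mpi.comm
  %retval, %newcomm2 = mpi.comm_split(%comm, %color, %key) : !mpi.retval, !mpi.comm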
//===----------------------------------------------------------------------===//
@@ -87,14 +130,12 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
def MPI_SendOp : MPI_Op<"send", []> {
let summary =
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -102,12 +143,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $dest
I32 : $dest,
MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
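The updated send syntax, assuming %buf is a memref<100xf32> and %tag/%dest are i32:

  mpi.send(%buf, %tag, %dest, %comm) : memref<100xf32>, i32, i32
  %err = mpi.send(%buf, %tag, %dest, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval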
@@ -119,15 +161,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -135,7 +175,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
I32 : $dest,
MPI_Comm : $comm
);
let results = (
@@ -143,8 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
MPI_Request : $req
);
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($dest) "
"`->` type(results)";
let hasCanonicalizer = 1;
}
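Updated isend syntax (same assumptions as for send above):

  %req = mpi.isend(%buf, %tag, %dest, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
  %err, %req2 = mpi.isend(%buf, %tag, %dest, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request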
@@ -155,14 +196,13 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
def MPI_RecvOp : MPI_Op<"recv", []> {
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
"comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.
@@ -172,13 +212,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $source
I32 : $tag, I32 : $source,
MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($source)"
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
" `:` type($ref) `,` type($tag) `,` type($source) "
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
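Updated recv syntax, with %source an i32 rank:

  mpi.recv(%buf, %tag, %source, %comm) : memref<100xf32>, i32, i32
  %err = mpi.recv(%buf, %tag, %source, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval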
@@ -188,16 +229,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
//===----------------------------------------------------------------------===//
def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
"MPI_COMM_WORLD, &req)`";
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
"comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -205,7 +244,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
I32 : $source,
MPI_Comm : $comm
);
let results = (
@@ -213,9 +253,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
MPI_Request : $req
);
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
"type($ref) `,` type($tag) `,` type($rank) `->`"
"type(results)";
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($source)"
"`->` type(results)";
let hasCanonicalizer = 1;
}
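Updated irecv syntax:

  %req = mpi.irecv(%buf, %tag, %source, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
  %err, %req2 = mpi.irecv(%buf, %tag, %source, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request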
@@ -224,8 +264,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
//===----------------------------------------------------------------------===//
def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
"MPI_COMM_WORLD)`";
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
@@ -235,8 +274,6 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -244,13 +281,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassEnum : $op
MPI_OpClassEnum : $op,
MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
"type($sendbuf) `,` type($recvbuf)"
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
"(`->` type($retval)^)?";
}
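Updated allreduce syntax:

  mpi.allreduce(%buf, %buf, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
  %err = mpi.allreduce(%buf, %buf, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval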
@@ -259,20 +297,23 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
//===----------------------------------------------------------------------===//
def MPI_Barrier : MPI_Op<"barrier", []> {
let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
let arguments = (ins MPI_Comm : $comm);
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
let assemblyFormat = [{
`(` $comm `)` attr-dict
(`->` type($retval)^)?
}];
}
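Updated barrier syntax:

  mpi.barrier(%comm)
  %err = mpi.barrier(%comm) -> !mpi.retval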
//===----------------------------------------------------------------------===//
@@ -295,8 +336,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
"(`->` type($retval) ^)?";
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
}
//===----------------------------------------------------------------------===//


@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
}];
}
//===----------------------------------------------------------------------===//
// mpi::CommType
//===----------------------------------------------------------------------===//
def MPI_Comm : MPI_Type<"Comm", "comm"> {
let summary = "MPI communicator handler";
let description = [{
This type represents a handler for the MPI communicator.
}];
}
//===----------------------------------------------------------------------===//
// mpi::RequestType
//===----------------------------------------------------------------------===//


@@ -83,9 +83,17 @@ public:
ModuleOp &getModuleOp() { return moduleOp; }
/// Gets or creates MPI_COMM_WORLD as a Value.
/// Different MPI implementations have different communicator types.
/// Using i64 as a portable, intermediate type.
/// Appropriate cast needs to take place before calling MPI functions.
virtual Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) = 0;
/// Type converter provides i64 type for communicator type.
/// Converts to native type, which might be ptr or int or whatever.
virtual Value castComm(const Location loc,
ConversionPatternRewriter &rewriter, Value comm) = 0;
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;
@@ -139,10 +147,15 @@ public:
Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) override {
static constexpr int MPI_COMM_WORLD = 0x44000000;
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
MPI_COMM_WORLD);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
}
intptr_t getStatusIgnore() override { return 1; }
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
@@ -256,9 +269,16 @@ public:
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
auto comm = rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, name));
return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
return rewriter.create<LLVM::IntToPtrOp>(
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
}
intptr_t getStatusIgnore() override { return 0; }
@@ -440,6 +460,78 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
}
};
//===----------------------------------------------------------------------===//
// CommWorldOpLowering
//===----------------------------------------------------------------------===//
struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
// get MPI_COMM_WORLD
rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
return success();
}
};
//===----------------------------------------------------------------------===//
// CommSplitOpLowering
//===----------------------------------------------------------------------===//
struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
Type i32 = rewriter.getI32Type();
Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
Location loc = op.getLoc();
// get communicator
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto outPtr =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
auto funcType =
LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Comm_split", funcType);
auto callOp = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
outPtr.getRes()});
// load the communicator into a register
auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
// if retval is checked, replace uses of retval with the results from the
// call op
SmallVector<Value> replacements;
if (op.getRetval())
replacements.push_back(callOp.getResult());
// replace op
replacements.push_back(res.getRes());
rewriter.replaceOp(op, replacements);
return success();
}
};
//===----------------------------------------------------------------------===//
// CommRankOpLowering
//===----------------------------------------------------------------------===//
@@ -462,21 +554,21 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
// get MPI_COMM_WORLD
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
// get communicator
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType =
LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
// replace init with function call
// replace with function call
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto callOp = rewriter.create<LLVM::CallOp>(
loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
loc, initDecl, ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
@@ -523,12 +615,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
// tag, comm)`
auto funcType = LLVM::LLVMFunctionType::get(
i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -537,7 +629,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
commWorld});
comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -575,7 +667,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
loc, i64, mpiTraits->getStatusIgnore());
statusIgnore =
@@ -585,7 +677,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
// tag, comm)`
auto funcType =
LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
i32, commWorld.getType(), ptrType});
i32, comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
@@ -594,7 +686,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
adaptor.getTag(), commWorld, statusIgnore});
adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -629,7 +721,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
auto funcType = LLVM::LLVMFunctionType::get(
@@ -676,8 +769,15 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
// Using i64 as a portable, intermediate type for !mpi.comm.
// It would be nicer to somehow get the right type directly, but TLDI is not
// available here.
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
FinalizeOpLowering, InitOpLowering, SendOpLowering,
RecvOpLowering, AllReduceOpLowering>(converter);
}
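For reference, a sketch of what the communicator handling lowers to under the two ABIs, mirroring the CHECK lines in the MPIToLLVM test below (value names are illustrative):

  // MPICH: MPI_COMM_WORLD is the magic constant 0x44000000, carried as i64
  %c = llvm.mlir.constant(1140850688 : i64) : i64
  // castComm truncates it back to the native i32 handle before each MPI_* call
  %comm32 = llvm.trunc %c : i64 to i32

  // OpenMPI: MPI_COMM_WORLD is the address of a global symbol, carried as i64
  %addr = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
  %c2 = llvm.ptrtoint %addr : !llvm.ptr to i64
  // castComm converts it back to the native pointer handle before each MPI_* call
  %commptr = llvm.inttoptr %c2 : i64 to !llvm.ptr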
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {


@@ -310,11 +310,16 @@ public:
}
// Otherwise call create mpi::CommRankOp
auto rank = rewriter
.create<mpi::CommRankOp>(
loc, TypeRange{mpi::RetvalType::get(op->getContext()),
rewriter.getI32Type()})
.getRank();
auto ctx = op.getContext();
Value commWorld =
rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
auto rank =
rewriter
.create<mpi::CommRankOp>(
loc,
TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
commWorld)
.getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -652,6 +657,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto upperSendOffset = rewriter.create<arith::SubIOp>(
loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
Value commWorld = rewriter.create<mpi::CommWorldOp>(
loc, mpi::CommType::get(op->getContext()));
// Make sure we send/recv in a way that does not lead to a dead-lock.
// The current approach is by far not optimal, this should be at least
// be a red-black pattern or using MPI_sendrecv.
@@ -680,7 +688,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto subview = builder.create<memref::SubViewOp>(
loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, subview, buffer);
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
commWorld);
builder.create<scf::YieldOp>(loc);
});
// if has neighbor: receive halo data into buffer and copy to array
@@ -688,7 +697,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
: OpFoldResult(lowerRecvOffset);
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
commWorld);
auto subview = builder.create<memref::SubViewOp>(
loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, buffer, subview);


@@ -3,6 +3,7 @@
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
@@ -22,11 +23,14 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -35,9 +39,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -45,9 +49,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -55,11 +59,11 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -67,27 +71,38 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
// CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32
// CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
@@ -101,6 +116,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// COM: Test OpenMPI ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -122,11 +138,14 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
// CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -135,9 +154,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -145,9 +164,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -155,11 +174,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -167,11 +186,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -185,11 +204,22 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
// CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
// CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
return


@@ -4,7 +4,7 @@
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank : !mpi.retval, i32
// CHECK: mpi.comm_rank
// CHECK-DAG: %[[v4:.*]] = arith.remsi
// CHECK-DAG: %[[v0:.*]] = arith.remsi
// CHECK-DAG: %[[v1:.*]] = arith.remsi
@@ -15,7 +15,7 @@ func.func @process_multi_index() -> (index, index, index) {
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
// CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
%0 = mesh.process_linear_index on @mesh0 : index
// CHECK: return %[[cast]] : index
@@ -113,17 +113,17 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
// CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
// CHECK-SAME: to memref<2x120x120xi8>
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
// CHECK: return [[res:%.*]] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
}
@@ -140,41 +140,44 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
// CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
// CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
// CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
@@ -191,45 +194,48 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
// CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
// CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
// CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
// CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
// CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
// CHECK: return [[v1]] : tensor<120x120x120xi8>
// CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
}


@@ -1,66 +1,83 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK-LABEL: func.func @mpi_test(
// CHECK-SAME: [[varg0:%.*]]: memref<100xf32>) {
func.func @mpi_test(%ref : memref<100xf32>) -> () {
// Note: the !mpi.retval result is optional on all operations except mpi.error_class
// CHECK: %0 = mpi.init : !mpi.retval
// CHECK-NEXT: [[v0:%.*]] = mpi.init : !mpi.retval
%err = mpi.init : !mpi.retval
// CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
// CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
%comm = mpi.comm_world : !mpi.comm
// CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
%retval_0, %size = mpi.comm_size : !mpi.retval, i32
// CHECK-NEXT: [[vrank:%.*]] = mpi.comm_rank([[v1]]) : i32
%rank = mpi.comm_rank(%comm) : i32
// CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK-NEXT: [[vretval:%.*]], [[vrank_0:%.*]] = mpi.comm_rank([[v1]]) : !mpi.retval, i32
%retval, %rank_1 = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK-NEXT: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
%size = mpi.comm_size(%comm) : i32
// CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
%retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
// CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.comm
%new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
// CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
%req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.retval, !mpi.comm
%retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
// CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
// CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
%req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: [[v2:%.*]] = mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
%retval_2 = mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
// CHECK-NEXT: mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK-NEXT: mpi.wait(%req) : !mpi.request
mpi.wait(%req) : !mpi.request
// CHECK-NEXT: [[v3:%.*]] = mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
%retval_3 = mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
// CHECK-NEXT: [[vretval_5:%.*]], [[vreq:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err4, %req2 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
// CHECK-NEXT: [[vreq_6:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
%req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: [[vreq_7:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
%req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: [[vretval_8:%.*]], [[vreq_9:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err5, %req4 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
// CHECK-NEXT: mpi.wait([[vreq_9]]) : !mpi.request
mpi.wait(%req4) : !mpi.request
// CHECK-NEXT: [[v4:%.*]] = mpi.wait([[vreq]]) : !mpi.request -> !mpi.retval
%err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
// CHECK-NEXT: mpi.barrier : !mpi.retval
mpi.barrier : !mpi.retval
// CHECK-NEXT: mpi.barrier([[v1]])
mpi.barrier(%comm)
// CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
%err7 = mpi.barrier : !mpi.retval
// CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
%err7 = mpi.barrier(%comm) -> !mpi.retval
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
// CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval
// CHECK-NEXT: %8 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
// CHECK-NEXT: [[v8:%.*]] = mpi.retval_check [[vretval:%.*]] = <MPI_SUCCESS> : i1
%res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
// CHECK-NEXT: %9 = mpi.error_class %0 : !mpi.retval
// CHECK-NEXT: [[v9:%.*]] = mpi.error_class [[v0]] : !mpi.retval
%errclass = mpi.error_class %err : !mpi.retval
// CHECK-NEXT: return