//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::mesh;

namespace {

// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    if (i > 0) {
      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
    }
  }

  return multiIndex;
}

// Create operations converting a multi-dimensional index to a linear index.
static Value multiToLinearIndex(Location loc, OpBuilder b,
                                ValueRange multiIndex, ValueRange dimensions) {
  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }

  return linearIndex;
}

struct ConvertProcessMultiIndexOp
    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Currently converts its linear index to a multi-dimensional index.
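    // Illustrative sketch (not verbatim emitted IR): assuming a static 2x4
    // mesh and linear index r, the arith ops built by linearToMultiIndex
    // compute
    //   multiIndex[1] = r % 4        (arith.remsi)
    //   multiIndex[0] = (r / 4) % 2  (arith.divsi + arith.remsi)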
    SymbolTableCollection symbolTableCollection;
    auto loc = op.getLoc();
    auto meshOp = getMesh(op, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape())) {
      return mlir::failure();
    }

    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    auto rank =
        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // Optionally extract a subset of mesh axes.
    auto axes = op.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.push_back(mIdx[axis]);
      }
      mIdx = subIndex;
    }

    rewriter.replaceOp(op, mIdx);
    return mlir::success();
  }
};

struct ConvertProcessLinearIndexOp
    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // If a global named "static_mpi_rank" is found, its splat value is used.
    // Otherwise the op defaults to mpi.comm_rank.

    auto loc = op.getLoc();
    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
            op, rankOpName)) {
      if (auto initTnsr = globalOp.getInitialValueAttr()) {
        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
        rewriter.replaceOp(op,
                           rewriter.create<arith::ConstantIndexOp>(loc, val));
        return mlir::success();
      }
    }
    auto rank = rewriter
                    .create<mpi::CommRankOp>(
                        op.getLoc(),
                        TypeRange{mpi::RetvalType::get(op->getContext()),
                                  rewriter.getI32Type()})
                    .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        op, rewriter.getIndexType(), rank);
    return mlir::success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Computes the neighbor indices along a split axis by simply
    // adding/subtracting 1 from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
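    // E.g., on a 1-d mesh of size n with current index i, the result is
    // (down, up) = (i - 1, i + 1), where -1 is yielded at the borders
    // (i == 0 has no down neighbor, i == n - 1 has no up neighbor).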
    auto axes = op.getSplitAxes();
    // For now only single-axis sharding is supported.
    if (axes.size() != 1) {
      return mlir::failure();
    }

    auto loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = op.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    auto dimSz = dims[axes[0]];
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
    auto atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          // At the lower border: no down neighbor.
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          // At the upper border: no up neighbor.
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return mlir::success();
  }
};

struct ConvertUpdateHaloOp
    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d`, and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the
    // halos in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and
    // static shapes, OpFoldResults are used whenever possible.
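    // Illustrative 2-d sketch: the loop below first exchanges the halos of
    // dim 1 using subviews that exclude the dim-0 halos, then exchanges
    // dim 0 with subviews spanning the full (haloed) extent of dim 1, so
    // the corner regions receive valid data from the dim-0 transfers.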
    SymbolTableCollection symbolTableCollection;
    auto loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return rewriter.create<::mlir::arith::ConstantOp>(
          loc, rewriter.getIndexAttr(
                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = op.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array)
              .getResult();
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = op.getSplitAxes().getAxes();
    auto mesh = op.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
    auto haloSizes =
        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
    // Subviews need Index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz)) {
        sz = rewriter
                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                             value)
                 .getResult();
      }
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s)) {
        // Query the dynamic extent of dimension `i`.
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      } else {
        shape[i] = rewriter.getIndexAttr(s);
      }

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for split dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets for the highest dim's halo exchange.
        auto _haloSz = rewriter
                           .create<arith::AddIOp>(
                               loc, toValue(haloSizes[currHaloDim * 2]),
                               toValue(haloSizes[currHaloDim * 2 + 1]))
                           .getResult();
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v < 0
    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);

    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty()) {
        continue;
      }
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = rewriter
                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
                                                       splitAxes)
                     .getResults();
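      // Sketch of the offsets computed below, along the current split
      // dimension (C order) with h_dn/h_up denoting the down/up halo sizes:
      //   [0, h_dn)              : down halo  (lowerRecvOffset = 0)
      //   [h_dn, ...)            : core data  (lowerSendOffset = h_dn)
      //   [shape-h_up-h_dn, ...) : upper send block (upperSendOffset)
      //   [shape-h_up, shape)    : up halo    (upperRecvOffset = shape - h_up)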
      // MPI operates on i32...
      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[0]),
                               rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = rewriter.create<arith::SubIOp>(
          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use
      // a red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the mesh borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto hasFrom = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, to, zero);
        auto buffer = rewriter.create<memref::AllocOp>(
            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
        // If we have a neighbor: copy halo data from array to buffer and send.
        rewriter.create<scf::IfOp>(
            loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, subview, buffer);
              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
              builder.create<scf::YieldOp>(loc);
            });
        // If we have a neighbor: receive halo data into buffer and copy it
        // to array.
        rewriter.create<scf::IfOp>(
            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, buffer, subview);
              builder.create<scf::YieldOp>(loc);
            });
        rewriter.create<memref::DeallocOp>(loc, buffer);
        offsets[dim] = orgOffset;
      };

      genSendRecv(false);
      genSendRecv(true);

      // The shape for lower dims includes higher dims' halos...
      dimSizes[dim] = shape[dim];
      // ...and the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
                                 loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return mlir::success();
  }
};

struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    auto *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);

    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
        ctx);

    (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

// Create a pass that converts Mesh to MPI.
std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
  return std::make_unique<ConvertMeshToMPIPass>();
}
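// Usage sketch (the command-line flag is assumed here to be
// "convert-mesh-to-mpi", as suggested by GEN_PASS_DEF_CONVERTMESHTOMPIPASS;
// see the pass definition in Passes.td):
//
//   mlir-opt --convert-mesh-to-mpi input.mlir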