* Split out `MeshDialect.h` from `MeshOps.h`; the new header defines the dialect class. This reduces include clutter if you care only about the dialect and not the ops.
* Expose functions `getMesh` and `collectiveProcessGroupSize`. These functions are useful for outside users of the dialect.
* Remove unused code.
* Remove examples and tests of the mesh.shard attribute in tensor encoding, per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment: https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
'mesh' Dialect
The mesh dialect contains a set of attributes, operations and interfaces that
are useful for representing sharding and communication on a device mesh
cluster.
[TOC]
Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.
Device groups
The operation attributes mesh and mesh_axes specify a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates on the axes outside of mesh_axes are in
the same group.
A group is described by its multi-index along the axes outside of mesh_axes.
For example, if we have a device mesh of size 2x3x4x5 and the partition mesh
axes list is [0, 1], then devices are partitioned into the groups
{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }.
The device groups are indexed by { (k, m) | 0<=k<4, 0<=m<5 }.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
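The grouping rule above can be checked with a small sketch. This is plain Python for illustration only, not part of the dialect; the helper names `group_index` and `device_groups` are hypothetical. It enumerates the 2x3x4x5 mesh and groups devices by their coordinates on the axes outside of mesh_axes.

```python
from itertools import product


def group_index(device, mesh_axes):
    """Group multi-index: the device's coordinates on axes outside mesh_axes."""
    return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)


def device_groups(mesh_shape, mesh_axes):
    """Map each group multi-index to the list of devices in that group."""
    groups = {}
    for device in product(*(range(n) for n in mesh_shape)):
        groups.setdefault(group_index(device, mesh_axes), []).append(device)
    return groups


mesh_shape = (2, 3, 4, 5)
mesh_axes = [0, 1]
groups = device_groups(mesh_shape, mesh_axes)

# 4 * 5 groups, each holding 2 * 3 devices.
print(len(groups), len(groups[(2, 3)]))
# Devices (1, 0, 2, 3) and (1, 1, 2, 3) share group (2, 3);
# device (1, 0, 2, 4) falls into group (2, 4).
print(group_index((1, 0, 2, 3), mesh_axes), group_index((1, 0, 2, 4), mesh_axes))
```

The size of each group is the product of the mesh dimensions along mesh_axes, here 2 * 3 = 6.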
Some collective operations like all-to-all and all-gather care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
mesh_axes.
The axes are ordered from outer to inner.
For example, with the axis list [3, 1], device (i, 1, k, 0) will precede
both devices (i, 0, k, 1) and (i, 2, k, 0).
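One way to read the ordering rule is as a lexicographic sort on the coordinates picked out by mesh_axes, taken in the order the axes are listed. The sketch below is illustrative Python under that reading; `in_group_order_key` is a hypothetical helper, and i and k are fixed to 0 so that all three devices belong to the same group.

```python
def in_group_order_key(device, mesh_axes):
    """Sort key: coordinates along mesh_axes, outermost axis first."""
    return tuple(device[axis] for axis in mesh_axes)


mesh_axes = [3, 1]
i, k = 0, 0  # coordinates outside mesh_axes are constant within a group
devices = [(i, 0, k, 1), (i, 2, k, 0), (i, 1, k, 0)]

ordered = sorted(devices, key=lambda d: in_group_order_key(d, mesh_axes))
print(ordered)  # [(0, 1, 0, 0), (0, 2, 0, 0), (0, 0, 0, 1)]
```

Axis 3 is the outer axis, so any device with coordinate 0 on axis 3 precedes every device with coordinate 1 on axis 3; ties are broken by axis 1.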
In-group Device
Some operations like broadcast, scatter and send specify devices in each
device-group.
These devices are represented with their multi-index over the mesh axes that
are not constant within a device group.
These are the axes specified by the mesh_axes attribute.
For example, on a 3D mesh an operation with mesh_axes = [0, 2] would specify
an in-group device with (i, j). Then for each group with index g on the
second axis, the in-group device would be (i, g, j).
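The mapping from an in-group multi-index to a full mesh coordinate can be sketched as follows. This is illustrative Python, not dialect code; `full_device_index` is a hypothetical helper, and it assumes the in-group index lists coordinates in the same order as mesh_axes.

```python
def full_device_index(in_group_index, group_index, mesh_axes, mesh_rank):
    """Combine an in-group multi-index with a group multi-index.

    in_group_index supplies coordinates on mesh_axes (in that order);
    group_index supplies coordinates on the remaining axes, in axis order.
    """
    device = [None] * mesh_rank
    for axis, coord in zip(mesh_axes, in_group_index):
        device[axis] = coord
    remaining = iter(group_index)
    for axis in range(mesh_rank):
        if device[axis] is None:
            device[axis] = next(remaining)
    return tuple(device)


# 3D mesh, mesh_axes = [0, 2], in-group device (i, j) = (1, 2).
# For the group with index g = 0 on the second axis this is device (1, 0, 2).
print(full_device_index((1, 2), (0,), [0, 2], 3))
```

Applying the same in-group index to every group selects one device per group, which is how operations such as broadcast can name a root in each device group.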
Purity
Collectives that involve the whole device group to perform a single operation
are pure. The exceptions are send and recv.
There is an assumption that the execution is SPMD. Not only does each process run the same program, but at the point of execution of a collective operation, all processes are in a coherent state. All compiler transformations must be consistent. Collective operations in the IR that may correspond to the same runtime collective operation must be transformed in a consistent manner. For example, if a collective operation is optimized out, then it must also not appear in any path of execution on any process.
Having the operations as Pure implies that if an interpreter is to execute
the IR containing the mesh collectives, all processes would execute the same
line when they reach a pure collective operation.
This requirement stems from the need to be compatible with general optimization
passes like dead code elimination and common sub-expression elimination.
Operations
[include "Dialects/MeshOps.md"]
Attributes
[include "Dialects/MeshAttrs.md"]