Files
clang-p2996/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Boian Petkantchin 79aa776267 [mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and
affine.delinearize_index.

This is useful for lowering mesh ops and preparing them for further lowering
to runtimes that may expose only the linear index of a device/process.
For example, in MPI a process is identified by its rank (a linear index) within a communicator.
2024-01-10 07:01:16 -08:00

85 lines
3.1 KiB
C++

//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
namespace mlir::mesh {
namespace {
/// Rewrite pattern that lowers `mesh.process_multi_index` into an expression
/// built from `mesh.process_linear_index`, `mesh.cluster_shape` and
/// `affine.delinearize_index`.
///
/// The linear process index is delinearized over the full cluster shape, and
/// the components for the axes listed on the op are then selected in the order
/// given. An empty axes list on the op selects every cluster axis in order.
///
/// The pattern keeps a reference to the `SymbolTableCollection` used to
/// resolve the op's mesh symbol; the collection must outlive the pattern.
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
  /// Stores `symbolTableCollection` and forwards all remaining arguments
  /// (e.g. the MLIRContext, benefit, ...) to the `OpRewritePattern`
  /// constructor.
  template <typename... OpRewritePatternArgs>
  ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
                              OpRewritePatternArgs &&...opRewritePatternArgs)
      // Each argument is forwarded with its own deduced type; expanding the
      // type pack inside a single std::forward<> would break for >1 argument.
      : OpRewritePattern(
            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
        symbolTableCollection(symbolTableCollection) {}

  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Resolve the cluster the op refers to; without it there is no shape to
    // delinearize over, so the pattern does not apply.
    ClusterOp mesh =
        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
            op.getOperation(), op.getMeshAttr());
    if (!mesh) {
      return failure();
    }

    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    builder.setInsertionPointAfter(op.getOperation());
    Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
    ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
    // Multi-index with one component per cluster axis.
    SmallVector<Value> completeMultiIndex =
        builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
            .getMultiIndex();

    // An empty axes list means "all axes": materialize 0, 1, ..., rank - 1.
    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
    SmallVector<MeshAxis> opAxesIota;
    if (opMeshAxes.empty()) {
      opAxesIota.resize(mesh.getRank());
      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
      opMeshAxes = opAxesIota;
    }

    // Pick the components for the requested axes, in the requested order.
    SmallVector<Value> multiIndex;
    multiIndex.reserve(opMeshAxes.size());
    llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
                    [&completeMultiIndex](MeshAxis meshAxis) {
                      return completeMultiIndex[meshAxis];
                    });

    // Replace the results and erase the original op in one step so the
    // rewrite is complete under any pattern driver, not only drivers that
    // later clean up dead ops.
    rewriter.replaceOp(op, multiIndex);
    return success();
  }

private:
  SymbolTableCollection &symbolTableCollection;
};
} // namespace
/// Appends the `mesh.process_multi_index` lowering pattern to `patterns`.
///
/// The added pattern holds a reference to `symbolTableCollection`, which it
/// uses to look up mesh symbols. Keep the collection alive for as long as the
/// pattern set may be applied.
void processMultiIndexOpLoweringPopulatePatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection, ctx);
}
/// Registers the dialects whose ops the `mesh.process_multi_index` lowering
/// creates: `affine` (for `affine.delinearize_index`) and `mesh` (for
/// `mesh.process_linear_index` and `mesh.cluster_shape`).
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect>();
  registry.insert<mesh::MeshDialect>();
}
} // namespace mlir::mesh