Files
clang-p2996/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Boian Petkantchin 4b3446771f [mlir][mesh] Add endomorphism simplification for all-reduce (#73150)
Performs transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce reduction kind is max.

Added general rewrite patterns HomomorphismSimplification and
EndomorphismSimplification that encapsulate the general algorithm.
Made specializations for all-reduce with respect to
addf, addi, minsi, minui, maxsi, maxui, minimumf and maximumf
in the Arith dialect.
2023-12-12 10:21:52 -08:00

40 lines
1.5 KiB
C++

//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
namespace mlir {
namespace mesh {
void populateSimplificationPatterns(RewritePatternSet &patterns) {
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
patterns, Partial::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
patterns, Partial::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
patterns, Partial::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
patterns, Partial::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
patterns, Partial::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
patterns, Partial::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
patterns, Partial::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
patterns, Partial::Max);
// TODO: add simplifications for all-gather and other collectives.
}
} // namespace mesh
} // namespace mlir