llvm-project/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp:96:8:
error: unused variable 'resultElementType' [-Werror,-Wunused-variable]
Type resultElementType =
^
llvm-project/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp:122:1:
error: non-void function does not return a value in all control paths [-Werror,-Wreturn-type]
}
^
2 errors generated.
354 lines
15 KiB
C++
354 lines
15 KiB
C++
//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
|
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
|
|
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
|
|
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include <iterator>
|
|
#include <optional>
|
|
#include <utility>
|
|
|
|
namespace mlir::linalg {
|
|
|
|
using MeshAxis = mesh::MeshAxis;
|
|
using ReductionKind = mesh::ReductionKind;
|
|
using MeshShardingAttr = mesh::MeshShardingAttr;
|
|
using ShardingArray = mesh::ShardingArray;
|
|
using MeshOp = mesh::MeshOp;
|
|
|
|
// Returns the corresponding mesh reduction kind for the given arith op.
|
|
static ReductionKind getReductionKind(Operation *op) {
|
|
return llvm::TypeSwitch<Operation *, ReductionKind>(op)
|
|
// Floating-point operations.
|
|
.Case([](arith::AddFOp op) { return ReductionKind::Sum; })
|
|
.Case([](arith::MulFOp op) { return ReductionKind::Product; })
|
|
// TODO: handle maxnumf and minnumf.
|
|
.Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
|
|
.Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
|
|
// Integer operations.
|
|
.Case([](arith::AddIOp op) { return ReductionKind::Sum; })
|
|
.Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
|
|
.Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
|
|
.Case([](arith::AndIOp op) { return ReductionKind::Sum; })
|
|
// TODO: handle signless, signed and unsigned types properly.
|
|
// It is assumed that the element type of the collective operands and
|
|
// result drive the meaning of the reduction kind, whether it is signed
|
|
// or unsigned.
|
|
// The reduction op inside the linalg op may have different result type
|
|
// from the element type of the linalg op's result.
|
|
// Also signed and unsigned Arith dialect ops may accept signed, unsigned
|
|
// or signless operands.
|
|
// Maybe expand the reduction kinds.
|
|
.Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
|
|
.Case([](arith::MinUIOp op) { return ReductionKind::Min; })
|
|
.Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
|
|
.Case([](arith::MinSIOp op) { return ReductionKind::Min; })
|
|
.Case([](arith::MulIOp op) { return ReductionKind::Product; })
|
|
.Default([](Operation *op) { return ReductionKind::Generic; });
|
|
}
|
|
|
|
// Returns the combiner operation of the reduction performed by `op`.
//
// `matchReduction` is run on the region output arguments of `op`, for the
// output at position 0. The result is `std::nullopt` unless a reduced value is
// found and exactly one combiner operation is reported.
static std::optional<Operation *> getCombinerOp(LinalgOp op) {
  SmallVector<Operation *> combiners;
  Value reduced = matchReduction(op.getRegionOutputArgs(), 0, combiners);

  const bool hasSingleCombiner = reduced && combiners.size() == 1;
  if (hasSingleCombiner)
    return combiners.front();
  return std::nullopt;
}
|
|
|
|
// Returns the mesh reduction kind of the reduction performed by `op`.
//
// When no single combiner operation can be identified the kind is
// `ReductionKind::Generic`; otherwise it is derived from the combiner op.
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
  std::optional<Operation *> combiner = getCombinerOp(op);
  if (!combiner.has_value())
    return ReductionKind::Generic;

  Operation *combinerOp = combiner.value();

  // TODO: handle case when result type of the reduction op does not match the
  // element type of the result tensor.
  // Would it makes sense at all?
  // The element type is only consumed by the assertion below, hence the
  // `[[maybe_unused]]` for builds where assertions are compiled out.
  [[maybe_unused]] Type resultElementType =
      llvm::cast<RankedTensorType>(op->getResult(0).getType())
          .getElementType();
  assert(resultElementType == combinerOp->getResult(0).getType());

  return getReductionKind(combinerOp);
}
|
|
|
|
// Resolves the mesh op referenced by the first non-null sharding.
//
// Operand shardings are searched before result shardings. The mesh symbol of
// the first non-null sharding is resolved through `mesh::getMesh` and that
// result is returned as is. Having no non-null sharding at all is treated as a
// programming error: it asserts, and yields a null `MeshOp` when assertions are
// disabled.
static MeshOp getMesh(Operation *op,
                      ArrayRef<MeshShardingAttr> operandShardings,
                      ArrayRef<MeshShardingAttr> resultShardings,
                      SymbolTableCollection &symbolTable) {
  auto isSet = [](MeshShardingAttr sharding) {
    return static_cast<bool>(sharding);
  };

  const MeshShardingAttr *anchor = llvm::find_if(operandShardings, isSet);
  if (anchor == operandShardings.end()) {
    // No operand carries a sharding; fall back to the results.
    anchor = llvm::find_if(resultShardings, isSet);
    if (anchor == resultShardings.end()) {
      assert(false);
      return nullptr;
    }
  }

  return mesh::getMesh(op, anchor->getMesh(), symbolTable);
}
|
|
|
|
// Choose the operand based on the current process index along the reduction
|
|
// mesh axes.
|
|
// We need to use the initial value only once to avoid including it in the
|
|
// reduction multiple times.
|
|
// In each process group only the leading process with linear index 0 would use
|
|
// the original operand.
|
|
// The other processes would use the reduction operation neutral tensor.
|
|
static Value createDestinationPassingStyleInitOperand(
|
|
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
|
|
MeshOp meshOp, ImplicitLocOpBuilder &builder) {
|
|
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
|
|
meshOp.getSymName(), reductionMeshAxes, builder);
|
|
Value zero = builder.create<arith::ConstantIndexOp>(0);
|
|
Value isLeadProcess = builder.create<arith::CmpIOp>(
|
|
builder.getI1Type(), arith::CmpIPredicate::eq,
|
|
processLinearIndexInReductionGroup, zero);
|
|
scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
|
|
isLeadProcess, true, true);
|
|
// Then block.
|
|
{
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
|
|
builder.create<scf::YieldOp>(spmdizedOperand);
|
|
}
|
|
|
|
// Else block.
|
|
{
|
|
OpBuilder::InsertionGuard insertionGuard(builder);
|
|
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
|
|
SmallVector<OpFoldResult> shape =
|
|
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
|
|
PartialReductionOpInterface partialReductionIface =
|
|
llvm::cast<PartialReductionOpInterface>(op.getOperation());
|
|
FailureOr<Operation *> reductionNeutralTensorOp =
|
|
partialReductionIface.generateInitialTensorForPartialReduction(
|
|
builder, builder.getLoc(), shape, {});
|
|
assert(succeeded(reductionNeutralTensorOp));
|
|
builder.create<scf::YieldOp>(
|
|
reductionNeutralTensorOp.value()->getResult(0));
|
|
}
|
|
return ifOp.getResult(0);
|
|
}
|
|
|
|
// Create the DPS init operands for the spmdized Linalg op.
|
|
// Return all the new spmdized operands.
|
|
static SmallVector<Value> createDestinationPassingStyleInitOperands(
|
|
LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
|
|
ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
|
|
ImplicitLocOpBuilder &builder) {
|
|
// TODO: add support for multiple destination passing style initial value
|
|
// operands.
|
|
// PartialReductionOpInterface::generateInitialTensorForPartialReduction
|
|
// needs to also support multiple DPS initial operands.
|
|
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
|
|
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
|
|
Value spmdizedInitOperand =
|
|
spmdizationMap.lookup(op->getOperands()[operandIdx]);
|
|
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
|
|
op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
|
|
return newOperands;
|
|
}
|
|
|
|
// Inserts a mesh all-reduce for one result of a Linalg op with sharded
// reduction loops.
//
// The all-reduce runs over those axes of `opReductionMeshAxes` that are not
// listed among the partial axes of `resultSharding`. When no such axis exists
// nothing is created. Otherwise the entry for `unshardedLinalgOpResult` in
// `spmdizationMap` is re-pointed to the all-reduced value.
static void createAllReduceForResultWithoutPartialSharding(
    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
    MeshShardingAttr resultSharding, ReductionKind reductionKind,
    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
  SmallVector<MeshAxis> axesToReduce;
  for (MeshAxis axis : opReductionMeshAxes) {
    // Axes recorded as partial in the result sharding are skipped.
    if (llvm::is_contained(resultSharding.getPartialAxes(), axis))
      continue;
    axesToReduce.push_back(axis);
  }
  if (axesToReduce.empty())
    return;

  Value spmdizedResult = spmdizationMap.lookup(unshardedLinalgOpResult);
  Value allReduced = builder.create<mesh::AllReduceOp>(
      spmdizedResult, resultSharding.getMesh().getValue(), axesToReduce,
      reductionKind);
  spmdizationMap.map(unshardedLinalgOpResult, allReduced);
}
|
|
|
|
// Applies createAllReduceForResultWithoutPartialSharding to every result of
// `unshardedOp`, pairing each result with its entry in `resultShardings`.
//
// The reduction kind is computed once from `unshardedOp` and shared by all
// results. `resultShardings` must have one entry per op result.
static void createAllReduceForResultsWithoutPartialShardings(
    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    ImplicitLocOpBuilder &builder) {
  const ReductionKind kind = getReductionKindOfLinalgOp(unshardedOp);

  auto resultsWithShardings =
      llvm::zip_equal(unshardedOp->getResults(), resultShardings);
  for (auto [result, sharding] : resultsWithShardings) {
    createAllReduceForResultWithoutPartialSharding(
        result, opReductionMeshAxes, sharding, kind, spmdizationMap, builder);
  }
}
|
|
|
|
// Spmdizes a Linalg op that has at least one sharded reduction loop.
//
// Steps, in order:
//   1. Replace the DPS init operand so that the initial value is used by a
//      single process per reduction group.
//   2. Spmdize the op as a trivially shardable operation, using a private
//      value mapping built from the adjusted operands.
//   3. Publish the mappings of the op results into `spmdizationMap`.
//   4. Insert all-reduces for the reduction mesh axes that are not kept as
//      partial axes in the result shardings.
static void spmdizeLinalgOpWithShardedReduction(
    LinalgOp op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
    ImplicitLocOpBuilder &builder) {
  MeshOp meshOp = getMesh(op, operandShardings, resultShardings, symbolTable);
  SmallVector<MeshAxis> reductionAxes = mesh::getReductionMeshAxes(
      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
  SmallVector<Value> adjustedOperands =
      createDestinationPassingStyleInitOperands(
          op, meshOp, spmdizedOperands, reductionAxes, spmdizationMap,
          builder);

  // We must not change the operand mappings of the original spmdizationMap as
  // they are the mappings for the whole spmdization blob and may be used by
  // others.
  IRMapping localMap;
  for (auto [unsharded, spmdized] :
       llvm::zip_equal(op->getOperands(), adjustedOperands)) {
    localMap.map(unsharded, spmdized);
  }

  spmdizeTriviallyShardableOperation(*op, adjustedOperands, operandShardings,
                                     resultShardings, localMap, symbolTable,
                                     builder);

  // Only the result mappings are exported to the caller's map.
  for (Value unshardedResult : op->getResults()) {
    Value spmdizedResult = localMap.lookup(unshardedResult);
    spmdizationMap.map(unshardedResult, spmdizedResult);
  }

  // Handle partial shardings.
  createAllReduceForResultsWithoutPartialShardings(
      op, reductionAxes, resultShardings, spmdizationMap, builder);
}
|
|
|
|
namespace {
|
|
|
|
// ShardingInterface for ops that implement LinalgStructuredInterface.
|
|
// The supported ops are only those where the indexing maps are projected
|
|
// permutations.
|
|
template <typename Op>
|
|
struct StructuredOpShardingInterface
|
|
: public mesh::ShardingInterface::ExternalModel<
|
|
StructuredOpShardingInterface<Op>, Op> {
|
|
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
|
|
return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
|
|
}
|
|
|
|
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
|
|
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
|
|
SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
|
|
|
|
// Results must have the same indexing as destination passing style initial
|
|
// operands.
|
|
for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
|
|
res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
|
|
ArrayRef<MeshShardingAttr> operandShardings,
|
|
ArrayRef<MeshShardingAttr> resultShardings,
|
|
IRMapping &spmdizationMap,
|
|
SymbolTableCollection &symbolTable,
|
|
OpBuilder &builder) const {
|
|
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
|
|
|
|
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
|
|
bool allIndexingMapsAreProjectedPermutation =
|
|
llvm::all_of(indexingMaps, [](AffineMap map) {
|
|
return map.isProjectedPermutation();
|
|
});
|
|
if (!allIndexingMapsAreProjectedPermutation) {
|
|
// TODO: handle non-projected permutations.
|
|
return op->emitOpError()
|
|
<< "supports indexing maps that are only projected permutation.";
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> loopIteratorTypes =
|
|
linalgOp.getIteratorTypesArray();
|
|
ShardingArray meshAxisAssignmentForLoopIterators =
|
|
getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
|
|
loopIteratorTypes, indexingMaps);
|
|
if (mesh::isAtLeastOneReductionIteratorSharded(
|
|
loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
|
|
ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
|
|
spmdizeLinalgOpWithShardedReduction(
|
|
linalgOp, spmdizedOperands, operandShardings, resultShardings,
|
|
loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
|
|
symbolTable, implicitLocBuilder);
|
|
} else {
|
|
spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
|
|
operandShardings, resultShardings,
|
|
spmdizationMap, symbolTable, builder);
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
template <typename OpType>
|
|
static void registerOne(MLIRContext *ctx) {
|
|
OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
|
|
}
|
|
|
|
/// Variadic helper function.
|
|
template <typename... OpTypes>
|
|
static void registerAll(MLIRContext *ctx) {
|
|
(registerOne<OpTypes>(ctx), ...);
|
|
}
|
|
|
|
// Registers the mesh sharding interface external models for Linalg ops.
//
// An extension for the Linalg dialect is added to `registry`. When the
// extension callback runs it
//   * appends and loads the Affine, Arith, SCF and Tensor dialects into the
//     context, and
//   * attaches StructuredOpShardingInterface to linalg.generic and to every op
//     listed in LinalgStructuredOps.cpp.inc.
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    // Local registry holding the dialects to load; named differently from the
    // outer parameter to avoid shadowing it.
    DialectRegistry dependentDialects;
    dependentDialects.insert<affine::AffineDialect, arith::ArithDialect,
                             scf::SCFDialect, tensor::TensorDialect>();
    ctx->appendDialectRegistry(dependentDialects);
    for (StringRef name : dependentDialects.getDialectNames())
      ctx->getOrLoadDialect(name);

    registerOne<linalg::GenericOp>(ctx);
    // GET_OP_LIST makes the generated file expand to the comma-separated list
    // of structured op classes, used here as the template argument list.
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
|
|
|
|
} // namespace mlir::linalg
|