//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::linalg;
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//
/// Return the indexing map, in the coordinates of the fused operation, to use
/// for an operand of the `producer`, given the indexing map of the result of
/// the producer in the consumer.
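/// For example (an illustrative sketch; the maps below are chosen purely for
/// exposition): if the producer result map is (d0, d1) -> (d1, d0), the
/// producer arg map is (d0, d1) -> (d0, d1), and the consumer map for the
/// fused operand is (d0, d1, d2) -> (d1, d2), then
///   inverse of result map        : (d0, d1) -> (d1, d0)
///   t1 = argMap o inverse        : (d0, d1) -> (d1, d0)
///   fused arg map = t1 o consumer: (d0, d1, d2) -> (d2, d1)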
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
AffineMap fusedConsumerArgIndexMap) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
// index. The fused loop is same as the consumer loop. For each producer arg
// the indexing map to be computed is a map from consumer loop -> producer
// arg tensor index.
// producerResultIndexMap is a map from producer loop -> tensor index.
// Compute the inverse to get map from tensor index -> producer loop.
// The inverse is a map from producer result tensor index -> producer loop.
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
assert(invProducerResultIndexMap &&
"expected producer result indexing map to be invertible");
LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
// argMap is a map from producer loop -> producer arg tensor index.
AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
// Compose argMap with invProducerResultIndexMap to get a map from
// producer result tensor index -> producer arg tensor index.
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
// Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
// consumer loop/ fused loop -> producer arg tensor index.
return t1.compose(fusedConsumerArgIndexMap);
}
// Checks if the given operands can be dropped, i.e. whether the remaining
// operands of the fused producer & consumer can still compute the bounds
// of the op after the fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
GenericOp producer, GenericOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;
SmallVector<GenericOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
continue;
}
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
}
}
if (indexingMaps.empty()) {
// If there are no indexing maps, the operand can only be dropped
// if neither op has loops.
return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
}
// The concatenation of the remaining indexing maps must be invertible, so
// that the bounds of the op can still be computed after dropping the
// selected operands. inversePermutation returns an empty AffineMap in case
// the concatenated indexing maps are not invertible.
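// For example (illustrative): keeping maps (d0, d1) -> (d0) and
// (d0, d1) -> (d1) concatenates to (d0, d1) -> (d0, d1), which is
// invertible, so the ignored operand may be dropped; if only
// (d0, d1) -> (d0) remained, the bound of d1 could no longer be recovered.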
return inversePermutation(concatAffineMaps(
indexingMaps, producer.getContext())) != AffineMap();
}
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * Note: the actual fusion implementation may not agree with the result of
/// this method; this function gives a prediction based on an optimized
/// fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
// The fusedOperand will be removed during the fusion
opOperandsToIgnore.emplace_back(fusedOperand);
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
opOperandsToIgnore.emplace_back(outputOperand);
if (producer.payloadUsesValueFromOperand(outputOperand) ||
!isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
opOperandsToIgnore) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());
// In case the operand can't be dropped
(void)opOperandsToIgnore.pop_back_val();
}
}
return preservedProducerResults;
}
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;
auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
// Check producer and consumer are generic ops.
if (!producer || !consumer)
return false;
// Consumer can have mixed semantics, just check operand itself has tensor
// type. Producer must have full tensor semantics to avoid potential
// aliasing between producer and consumer memrefs.
if (!producer.hasPureTensorSemantics() ||
!isa<RankedTensorType>(fusedOperand->get().getType()))
return false;
// Verify that
// - the producer has all "parallel" iterator types.
if (producer.getNumParallelLoops() != producer.getNumLoops())
return false;
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
if (!consumer.isDpsInput(fusedOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
return false;
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
AffineMap producerResultIndexMap =
producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
if (!producerResultIndexMap.isPermutation())
return false;
// Ensure that the fusion does not remove size information required to
// get the loop bounds. For non-reduction generics, this is trivially the
// case due to the output operand. For reductions, we need to check that after
// the fusion, each loop dimension has at least one input that defines it.
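// For example (illustrative): if the consumer reduces d1 and the fused
// operand is the only operand whose indexing map references d1, then after
// fusion some producer operand must still index d1, or the bound of d1
// cannot be computed and the fusion is rejected.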
if ((consumer.getNumReductionLoops())) {
BitVector coveredDims(consumer.getNumLoops(), false);
auto addToCoveredDims = [&](AffineMap map) {
for (auto result : map.getResults())
if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
coveredDims[dimExpr.getPosition()] = true;
};
for (auto pair :
llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
Value operand = std::get<0>(pair);
if (operand == fusedOperand->get())
continue;
AffineMap operandMap = std::get<1>(pair);
addToCoveredDims(operandMap);
}
for (OpOperand *operand : producer.getDpsInputOperands()) {
AffineMap newIndexingMap =
getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
operand, producerResultIndexMap, consumerIndexMap);
addToCoveredDims(newIndexingMap);
}
if (!coveredDims.all())
return false;
}
return true;
}
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
RewriterBase &rewriter, GenericOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
IRMapping mapper;
// 2. Add an index operation for every fused loop dimension and use the
// `consumerToProducerLoopsMap` to map the producer indices.
if (producer.hasIndexSemantics()) {
// Add an index operation for every fused loop dimension.
unsigned numFusedOpLoops = fusedOp.getNumLoops();
SmallVector<Value> fusedIndices;
fusedIndices.reserve(numFusedOpLoops);
llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
std::back_inserter(fusedIndices), [&](uint64_t dim) {
return rewriter.create<IndexOp>(producer.getLoc(), dim);
});
for (IndexOp indexOp :
llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
Value newIndex = rewriter.create<affine::AffineApplyOp>(
producer.getLoc(),
consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
mapper.map(indexOp.getResult(), newIndex);
}
}
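// For example (illustrative): with consumerToProducerLoopsMap
// (d0, d1, d2) -> (d2, d0), a producer `linalg.index 1` is replaced by
// affine.apply affine_map<(d0, d1, d2) -> (d0)>(%i0, %i1, %i2), i.e. the
// producer's loop 1 corresponds to the fused loop 0.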
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
fusedOperand->getOperandNumber())) // input assumption.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// Replacing consumerIdx requires getting the cloned, yielded, value from
// the (cloned) producer block. This happens in step 9.
// 4. Splice in producer's input operands.
for (BlockArgument bbArg :
producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg :
consumerBlock.getArguments()
.take_front(consumer.getNumDpsInputs())
.drop_front(fusedOperand->getOperandNumber() + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 6. All of the producer's output operands
for (const auto &bbArg : llvm::enumerate(
producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
if (!preservedProducerResults.count(bbArg.index()))
continue;
mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
bbArg.value().getLoc()));
}
// 7. All of consumer's output operands.
for (BlockArgument bbArg :
consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 8. Clone all producer operations except for the yield and index operations
// to the fused operation.
for (auto &op : producerBlock.without_terminator()) {
if (!isa<IndexOp>(op))
rewriter.clone(op, mapper);
}
// 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
// forward the yield operand.
auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
unsigned producerResultNumber =
cast<OpResult>(fusedOperand->get()).getResultNumber();
Value replacement =
mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
// Sanity checks, if replacement is not already in the mapper then it must be
// produced outside.
if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
if (auto bb = dyn_cast<BlockArgument>(replacement))
assert(bb.getOwner() != &producerBlock &&
"yielded block argument must have been mapped");
else
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
"yielded value must have been mapped");
}
mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
replacement);
// 10. Clone operations from the consumer to the fused op.
for (auto &op : consumerBlock.without_terminator())
rewriter.clone(op, mapper);
// 11. Include the final yield (whose operands are the remapped yield values
// of the producer and consumer).
auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
SmallVector<Value> fusedYieldValues;
fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
consumerYieldOp.getNumOperands());
for (const auto &producerYieldVal :
llvm::enumerate(producerYieldOp.getOperands())) {
if (preservedProducerResults.count(producerYieldVal.index()))
fusedYieldValues.push_back(
mapper.lookupOrDefault(producerYieldVal.value()));
}
for (auto consumerYieldVal : consumerYieldOp.getOperands())
fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
"Ill-formed GenericOp region");
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand) {
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
/// Find the results of the producer that have uses outside of the consumer,
/// after the fusion.
llvm::SmallDenseSet<int> preservedProducerResults =
mlir::linalg::getPreservedProducerResults(producer, consumer,
fusedOperand);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
SmallVector<Type> fusedResultTypes;
SmallVector<AffineMap> fusedIndexMaps;
fusedInputOperands.reserve(producer.getNumDpsInputs() +
consumer.getNumDpsInputs());
fusedOutputOperands.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
fusedResultTypes.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
fusedIndexMaps.reserve(producer->getNumOperands() +
consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
auto consumerInputs = consumer.getDpsInputOperands();
auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
return operand == fusedOperand;
});
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
fusedInputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
}
// 4. Splice in producer's input operands/maps.
AffineMap producerResultIndexMap =
producer.getIndexingMapMatchingResult(producerResult);
for (OpOperand *opOperand : producer.getDpsInputOperands()) {
fusedInputOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
}
// 5. Remaining consumer's input operands/maps (drop past index
// `consumerIdx`).
for (OpOperand *opOperand :
llvm::make_range(std::next(it), consumerInputs.end())) {
fusedInputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
}
// 6. Collect all of the producer outputs.
for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
if (!preservedProducerResults.count(opOperand.index()))
continue;
fusedOutputOperands.push_back(opOperand.value().get());
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
&opOperand.value(), producerResultIndexMap,
consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
fusedResultTypes.push_back(opOperand.value().get().getType());
}
// 7. All of consumer's output operands (skip operands: added by the builder).
for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
fusedOutputOperands.push_back(opOperand.get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
Type resultType = opOperand.get().getType();
if (!isa<MemRefType>(resultType))
fusedResultTypes.push_back(resultType);
}
// Generate the fused op.
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
// So clean up and abort.
rewriter.eraseOp(fusedOp);
return rewriter.notifyMatchFailure(
fusedOp, "fused op failed loop bound computation check");
}
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
AffineMap consumerResultIndexMap =
consumer.getMatchingIndexingMap(fusedOperand);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
assert(invProducerResultIndexMap &&
"expected producer result indexig map to be invertible");
// consumer loop -> producer loop
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedElementwiseOpRegion(
rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
consumer.getNumLoops(), preservedProducerResults);
ElementwiseOpFusionResult result;
result.fusedOp = fusedOp;
int resultNum = 0;
for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
if (preservedProducerResults.count(index))
result.replacements[producerResult] = fusedOp->getResult(resultNum++);
for (auto consumerResult : consumer->getResults())
result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
return result;
}
namespace {
/// Pattern to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
if (!areElementwiseOpsFusable(&opOperand))
continue;
if (!controlFn(&opOperand))
continue;
// Find the producer of the operand.
Operation *producer = opOperand.get().getDefiningOp();
// Perform the fusion.
FailureOr<ElementwiseOpFusionResult> fusionResult =
fuseElementwiseOps(rewriter, &opOperand);
if (failed(fusionResult))
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
// Replace the original values with the fusion results.
for (auto [origVal, replacement] : fusionResult->replacements) {
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
// Only replace consumer uses.
return use.get().getDefiningOp() != producer;
});
}
rewriter.eraseOp(genericOp);
return success();
}
return failure();
}
private:
ControlFusionFn controlFn;
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//
/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
/// %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
/// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
/// into two.
/// - The loop used to access the second dimension of the fused tensor is kept
/// as is.
/// - The loop used to access the third dimension of the fused tensor is split
/// into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
/// d0 -> e0, e1
/// d1 -> e2, e3, e4
/// d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
/// %d = linalg.generic ins(%0, %1 : )
/// indexing_maps =
/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the domain of the linalg generic is now 6D, reshapes can be
/// introduced on the operands to make them consistent
///
/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
/// : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
return linalgOp.hasPureTensorSemantics() &&
llvm::all_of(linalgOp.getIndexingMaps().getValue(),
[](Attribute attr) {
return cast<AffineMapAttr>(attr)
.getValue()
.isProjectedPermutation();
}) &&
operandMap.getNumResults() > 0;
}
namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
// Computes the mapping from the original dimensions of the op to the
// dimensions of the expanded op, given the `indexingMap` of the fused
// operand/result of the generic op, the `reassociationMaps` of the reshape
// op, and the shape of the expanded op.
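// For example (an illustrative sketch): with a fused-operand indexing map
// (d0, d1) -> (d1, d0), reassociation [[0, 1], [2]], and expanded shape
// (4, ?, 8), d1 expands to two dims of shape (4, ?) and d0 to one dim of
// shape (8), giving reassociation = [[0], [1, 2]] and expandedOpNumDims = 3.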
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
/// dimension of the expanded operation.
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of each loop in the original operation.
SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(linalgOp);
originalLoopExtent = llvm::map_to_vector(
linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
[](Range r) { return r.size; });
reassociation.clear();
expandedShapeMap.clear();
// Compute the number of dimensions in the expanded op that correspond to
// each dimension of the original op.
SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
expandedShapeMap.resize(fusedIndexMap.getNumDims());
for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
// The remaining dimensions remain the same.
for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
if (expandedShapeMap[i].empty())
expandedShapeMap[i] = {originalLoopExtent[i]};
// Compute reassociation map from the original op to the expanded op.
unsigned sum = 0;
reassociation.reserve(fusedIndexMap.getNumDims());
for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
reassociation.emplace_back(seq.begin(), seq.end());
sum += numFoldedDim.value();
}
expandedOpNumDims = sum;
return success();
}
/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
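/// For example (illustrative): with the original map (d0, d1) -> (d1, d0)
/// and the expansion d0 -> (e0), d1 -> (e1, e2), the expanded map is
/// (e0, e1, e2) -> (e1, e2, e0).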
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<AffineExpr> newExprs;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned pos = cast<AffineDimExpr>(expr).getPosition();
SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
return builder.getAffineDimExpr(static_cast<unsigned>(v));
}));
newExprs.append(expandedExprs.begin(), expandedExprs.end());
}
return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
indexingMap.getNumSymbols(), newExprs,
builder.getContext());
}
/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
ArrayRef<OpFoldResult> dimExpansion =
expansionInfo.getExpandedShapeOfDim(dim);
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
SmallVector<int64_t> expandedStaticShape;
std::tie(expandedStaticShape, std::ignore) =
decomposeMixedValues(expandedShape);
return {expandedShape, RankedTensorType::get(expandedStaticShape,
originalType.getElementType())};
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
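/// For example (illustrative): with the indexing map (d0, d1) -> (d1, d0)
/// and the expansion d0 -> (e0), d1 -> (e1, e2), the operand reassociation
/// is [[0, 1], [2]]: the operand's first two expanded dimensions fold back
/// into d1 and the last one into d0.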
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<ReassociationIndices> reassociation;
unsigned numReshapeDims = 0;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
reassociation.emplace_back(std::move(indices));
numReshapeDims += numExpandedDims;
}
return reassociation;
}
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
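/// For example (illustrative): if loop d0 of the original op is expanded
/// into loops (e0, e1) and e1 has extent 4, then `linalg.index 0` of the
/// original op is recovered as `%i1 + %i0 * 4`, where %i0 and %i1 are
/// `linalg.index 0` and `linalg.index 1` of the expanded op.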
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
Location loc, Region &fusedRegion,
const ExpansionInfo &expansionInfo) {
// Replace the original indices by the linearization of the expanded indices.
for (IndexOp indexOp :
llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
ArrayRef<int64_t> expandedDims =
expansionInfo.getExpandedDims(indexOp.getDim());
assert(!expandedDims.empty() && "expected valid expansion info");
// Skip index operations that are not affected by the expansion.
if (expandedDims.size() == 1 &&
expandedDims.front() == (int64_t)indexOp.getDim())
continue;
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
OpFoldResult newIndex =
rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto [expandedShape, expandedIndex] :
llvm::zip(expandedDimsShape, expandedIndices)) {
AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
bindSymbols(rewriter.getContext(), shape);
newIndex = affine::makeComposedFoldedAffineApply(
rewriter, indexOp.getLoc(), idx + acc * shape,
ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
}
Value newIndexVal =
getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
rewriter.replaceOp(indexOp, newIndexVal);
}
}
// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse permute it and
// then flatten it. Then we inverse permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// the final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation = [4, 5, 0, 1, 2, 3]
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
TransposeOp transposeOp,
Value expandedInput, Value output,
ExpansionInfo &expansionInfo) {
SmallVector<int64_t> newPerm;
for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
auto reassoc = expansionInfo.getExpandedDims(perm);
for (int64_t dim : reassoc) {
newPerm.push_back(dim);
}
}
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
output, invertPermutationVector(newPerm));
}
// Create an expanded generic op.
static Operation *createExpandedGenericOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
// The iterator types of the expanded op are all parallel.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;
Operation *fused = rewriter.create<GenericOp>(
linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fused->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
// Update the index accesses after the expansion.
updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
expansionInfo);
return fused;
}
// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands,
ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps,
ExpansionInfo &expansionInfo) {
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp,
expandedOpOperands[0], outputs[0],
expansionInfo);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
llvm::to_vector(llvm::concat<Value>(
llvm::to_vector(expandedOpOperands),
llvm::to_vector(outputs))));
})
.Default([&](Operation *op) {
return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
expandedOpOperands, outputs,
expansionInfo, expandedOpIndexingMaps);
});
}
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
SmallVector<OpFoldResult> expandedShape;
SmallVector<AffineMap, 4> reassociationIndices;
Value src;
if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
// Try to move the dynamic dimensions in output shape before the `linalgOp`
// to maintain SSA validity
if (failed(moveValueDefinitions(
rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
return std::nullopt;
expandedShape = expandingReshapeOp.getMixedOutputShape();
reassociationIndices = expandingReshapeOp.getReassociationMaps();
src = expandingReshapeOp.getSrc();
} else {
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
if (!collapsingReshapeOp)
return std::nullopt;
expandedShape = tensor::getMixedSizes(
rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
reassociationIndices = collapsingReshapeOp.getReassociationMaps();
src = collapsingReshapeOp.getSrc();
}
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
reassociationIndices, expandedShape,
rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));
// Set insertion point to the generic op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(linalgOp);
SmallVector<Value> expandedOpOperands;
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
SmallVector<OpFoldResult> expandedOperandShape;
RankedTensorType expandedOperandType;
std::tie(expandedOperandShape, expandedOperandType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(linalgOp, msg);
},
opOperandType.getShape(), expandedOperandType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOperandType, opOperand->get(), reassociation,
expandedOperandShape));
continue;
}
}
expandedOpOperands.push_back(opOperand->get());
}
SmallVector<Value> outputs;
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
SmallVector<OpFoldResult> expandedOutputShape;
RankedTensorType expandedOutputType;
std::tie(expandedOutputShape, expandedOutputType) =
getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(linalgOp, msg);
},
opOperandType.getShape(), expandedOutputType.getShape(),
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
loc, expandedOutputType, opOperand.get(), reassociation,
expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}
}
TypeRange resultTypes = ValueRange(outputs).getTypes();
Operation *fusedOp =
createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
outputs, expandedOpIndexingMaps, expansionInfo);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
for (OpResult opResult : linalgOp->getOpResults()) {
int64_t resultNumber = opResult.getResultNumber();
if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
linalgOp.getMatchingIndexingMap(
linalgOp.getDpsInitOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
linalgOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(resultNumber));
}
}
return resultVals;
}
namespace {
/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
: public OpInterfaceRewritePattern<LinalgOp> {
public:
FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is collapsing dimensions.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
(!controlFoldingReshapes(opOperand)))
continue;
std::optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(linalgOp, *replacementValues);
return success();
}
return failure();
}
private:
ControlFusionFn controlFoldingReshapes;
};
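/// Pattern to fold a tensor.collapse_shape op with its consumer tensor.pad
/// op by bubbling the collapse past the pad: the pad is recreated on the
/// expanded source, and its result is collapsed back to the original padded
/// type. Only dimensions that are not collapsed may be padded.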
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::PadOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
tensor::CollapseShapeOp reshapeOp =
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
return failure();
if (!reshapeOp->hasOneUse())
return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
ArrayRef<int64_t> low = padOp.getStaticLow();
ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
if (reInd.size() != 1 && (l != 0 || h != 0))
return failure();
}
SmallVector<OpFoldResult> newLow, newHigh;
RankedTensorType expandedType = reshapeOp.getSrcType();
RankedTensorType paddedType = padOp.getResultType();
SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() == 1) {
expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
}
for (size_t i = 0; i < reInd.size(); ++i) {
newLow.push_back(padOp.getMixedLowPad()[idx]);
newHigh.push_back(padOp.getMixedHighPad()[idx]);
}
}
Location loc = padOp->getLoc();
RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
auto newPadOp = rewriter.create<tensor::PadOp>(
loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
return success();
}
private:
ControlFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(reshapeOp,
"source not produced by an operation");
}
auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
if (!producer) {
return rewriter.notifyMatchFailure(reshapeOp,
"producer not a generic op");
}
if (!isFusableWithReshapeByDimExpansion(
producer,
producer.getDpsInitOperand(producerResult.getResultNumber()))) {
return rewriter.notifyMatchFailure(
reshapeOp, "failed preconditions of fusion with producer generic op");
}
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
std::optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(
producer, reshapeOp,
producer.getDpsInitOperand(producerResult.getResultNumber()),
rewriter);
if (!replacementValues) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion by expansion failed");
}
// Find the replacement for the reshape op. Since the replacements have the
// same type as the returns of the original generic op, the consumer reshape
// op can be replaced by the source of the collapse_shape op that defines
// the replacement.
Value reshapeReplacement =
(*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
.getResultNumber()];
if (auto collapseOp =
reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
reshapeReplacement = collapseOp.getSrc();
}
rewriter.replaceOp(reshapeOp, reshapeReplacement);
rewriter.replaceOp(producer, *replacementValues);
return success();
}
private:
ControlFusionFn controlFoldingReshapes;
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//
/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
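/// For example (illustrative): with the indexing map
/// (d0, d1, d2) -> (d2, d0, d1) and the range reassociation [1, 2]
/// (i.e. results 1 and 2), the domain reassociation is [0, 1].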
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
ReassociationIndicesRef rangeReassociation) {
assert(indexingMap.isProjectedPermutation() &&
"expected projected permutation");
ReassociationIndices domainReassociation = llvm::to_vector<4>(
llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
}));
// The projected permutation semantics ensures that there is no repetition of
// the domain indices.
return domainReassociation;
}
/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence does not occur in the map at all, true is returned as
/// well.
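/// For example (illustrative): the sequence [1, 2] is preserved in
/// (d0, d1, d2) -> (d1, d2, d0), where d1 and d2 appear contiguously and in
/// order, but not in (d0, d1, d2) -> (d2, d1, d0).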
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
ReassociationIndicesRef dimSequence) {
assert(!dimSequence.empty() &&
"expected non-empty list for dimension sequence");
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
llvm::SmallDenseSet<unsigned, 4> sequenceElements;
sequenceElements.insert_range(dimSequence);
unsigned dimSequenceStart = dimSequence[0];
for (const auto &expr : enumerate(indexingMap.getResults())) {
unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
// 1. Check if this is the start of the sequence.
if (dimInMapStart == dimSequenceStart) {
if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
return false;
// 1a. Check if sequence is preserved.
for (const auto &dimInSequence : enumerate(dimSequence)) {
unsigned dimInMap =
cast<AffineDimExpr>(
indexingMap.getResult(expr.index() + dimInSequence.index()))
.getPosition();
if (dimInMap != dimInSequence.value())
return false;
}
// Found the sequence. Projected permutation
// enforces that all AffineDimExprs in the result are unique, so no
// further checks are needed.
return true;
}
// 2. If the position in the expr (which is of type AffineDimExpr) is part
// of the sequence but is not its start, return false. This implies the
// sequence does not appear contiguously in the indexing map.
if (sequenceElements.count(dimInMapStart))
return false;
}
// 3. No element of sequence found. Return true.
return true;
}
bool mlir::linalg::areDimSequencesPreserved(
ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
return llvm::all_of(maps, [&](AffineMap map) {
return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
return isDimSequencePreserved(map, dimSequence);
});
});
}
// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
// indexing_maps = [#map, #map],
// iterator_types = ["parallel" ,"parallel"]}
// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
// indexing_maps = [#map, #map],
// iterator_types = ["parallel"]}
// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %2 = linalg.generic {
// indexing_maps = [#map0, #map1],
// iterator_types = ["parallel" ,"parallel"]}
// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
ArrayRef<ReassociationIndices> reassociation) {
// Some basic checks for this fusion to be valid.
if (!genericOp.hasPureTensorSemantics())
return {};
if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
return map.isProjectedPermutation();
})) {
return {};
}
// Compute all the loops with the reduction iterator types.
SmallVector<unsigned> reductionDims;
genericOp.getReductionDims(reductionDims);
llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
auto iteratorTypes = genericOp.getIteratorTypesArray();
SmallVector<ReassociationIndices> iterationSpaceReassociation;
for (ReassociationIndicesRef foldedRangeDims : reassociation) {
assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
// Ignore dims that are not folded.
if (foldedRangeDims.size() == 1)
continue;
ReassociationIndices foldedIterationSpaceDims =
getDomainReassociation(indexingMap, foldedRangeDims);
// Check that the folded iteration dims do not contain already processed
// dims.
if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
return processedIterationDims.count(dim);
}))
continue;
// Check that the folded iterator types are all parallel or all reductions.
utils::IteratorType startIteratorType =
iteratorTypes[foldedIterationSpaceDims[0]];
if (!isParallelIterator(startIteratorType) &&
!isReductionIterator(startIteratorType))
continue;
if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
return iteratorTypes[dim] != startIteratorType;
}))
continue;
// If the folded dimensions correspond to a "reduction" iterator type,
// the folded dimensions need to be "in-order". Strictly speaking this is
// not necessary, for reductions that are associative and commutative, but
// using a more strict definition of reduction for now.
if (isReductionIterator(startIteratorType)) {
bool isContiguous = false;
for (const auto &startDim : llvm::enumerate(reductionDims)) {
// Move window in `reductionDims` to start of the folded iteration dims.
if (startDim.value() != foldedIterationSpaceDims[0])
continue;
// If the sizes don't match, it is trivially not contiguous. This condition
// should not be hit.
if (startDim.index() + foldedIterationSpaceDims.size() >
reductionDims.size())
break;
// Check that the contiguity is maintained.
isContiguous = true;
for (const auto &foldedDim :
llvm::enumerate(foldedIterationSpaceDims)) {
if (reductionDims[foldedDim.index() + startDim.index()] !=
foldedDim.value()) {
isContiguous = false;
break;
}
}
break;
}
if (!isContiguous)
continue;
}
// Check that the sequence is preserved in all indexing maps.
if (llvm::any_of(genericOp.getIndexingMapsArray(),
[&](AffineMap indexingMap) {
return !isDimSequencePreserved(indexingMap,
foldedIterationSpaceDims);
}))
continue;
processedIterationDims.insert_range(foldedIterationSpaceDims);
iterationSpaceReassociation.emplace_back(
std::move(foldedIterationSpaceDims));
}
return iterationSpaceReassociation;
}
/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
LogicalResult initialize(unsigned origNumLoops,
ArrayRef<ReassociationIndices> foldedIterationDims) {
llvm::SmallDenseSet<int64_t, 4> processedDims;
// Find all the dims that are folded.
for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
if (foldedIterationDim.empty())
continue;
// If the folded dims contain dims already folded, that's an illegal
// specification. Repetition within a list is also illegal.
for (auto dim : foldedIterationDim) {
if (dim >= origNumLoops)
return failure();
if (processedDims.count(dim))
return failure();
processedDims.insert(dim);
}
collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
foldedIterationDim.end());
}
if (processedDims.size() > origNumLoops)
return failure();
// Add all the preserved dims of the original op as single
// elements to `collapsedOpToOrigOpIterationDim`.
for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
if (processedDims.count(dim))
continue;
collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
}
llvm::sort(collapsedOpToOrigOpIterationDim,
[&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
return lhs[0] < rhs[0];
});
origOpToCollapsedOpIterationDim.resize(origNumLoops);
for (const auto &foldedDims :
llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
for (const auto &dim : enumerate(foldedDims.value()))
origOpToCollapsedOpIterationDim[dim.value()] =
std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
}
return success();
}
/// Return mapping from collapsed loop domain to original loop domain.
ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
return collapsedOpToOrigOpIterationDim;
}
/// Return mapping from original loop domain to collapsed loop domain. The
/// mapping is a pair. First value is the dimension in the collapsed loop that
/// the original loop is mapped to. Second is the relative position in folded
/// list of this domain. For example, if the original loop domain is 5D and
/// the collapsed loop domain folds all of it, i.e.
///
/// ```
/// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
/// ```
///
/// then
///
/// ```
/// origOpToCollapsedOpMapping[0] = {0, 0};
/// origOpToCollapsedOpMapping[1] = {0, 1};
/// origOpToCollapsedOpMapping[2] = {0, 2};
/// origOpToCollapsedOpMapping[3] = {1, 0};
/// origOpToCollapsedOpMapping[4] = {1, 1};
/// ```
///
ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
return origOpToCollapsedOpIterationDim;
}
/// Return the collapsed op iteration domain rank.
unsigned getCollapsedOpIterationRank() const {
return collapsedOpToOrigOpIterationDim.size();
}
private:
/// Map from the iteration domain index in collapsed op to the iteration
/// domain indices in the original op.
SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
/// Map from iteration domain index in the original op to the iteration domain
/// index in the collapsed op.
SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace
/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
const CollapsingInfo &collapsingInfo) {
SmallVector<utils::IteratorType> collapsedIteratorTypes;
for (ReassociationIndicesRef foldedIterDims :
collapsingInfo.getCollapsedOpToOrigOpMapping()) {
assert(!foldedIterDims.empty() &&
"reassociation indices expected to have non-empty sets");
// Just pick the iterator type of the first folded dim. Pre-condition checks
// are expected to have verified that the iterator types of all folded
// dimensions are the same.
collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
}
return collapsedIteratorTypes;
}
/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
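/// For example (an illustrative case, not tied to the surrounding code):
/// given `indexingMap = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>` and
/// folded iteration dims `[[0, 1]]`, the collapsed op has three loops
/// ([0, 1] -> 0, [2] -> 1, [3] -> 2) and the computed indexing map is
/// `affine_map<(d0, d1, d2) -> (d0, d2)>`.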
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
MLIRContext *context = indexingMap.getContext();
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
SmallVector<AffineExpr> resultExprs;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
// If the dim is not the first of its collapsed group, do nothing.
if (origOpToCollapsedOpMapping[dim].second != 0)
continue;
// The next n-dims are guaranteed to be collapsed. So just use the
// iteration dimension of the collapsed op.
resultExprs.push_back(
getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
}
return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
resultExprs, context);
}
/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
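/// For example (illustrative), with
/// `indexingMap = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>` and folded
/// iteration dims `[[0, 1]]`, the operand reassociation is `[[0, 1], [2]]`:
/// the operand's first two dimensions collapse together and its last
/// dimension is left untouched.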
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
unsigned counter = 0;
SmallVector<ReassociationIndices> operandReassociation;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
auto collapsedOpToOrigOpMapping =
collapsingInfo.getCollapsedOpToOrigOpMapping();
while (counter < indexingMap.getNumResults()) {
unsigned dim =
cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
// This is the start of a collapsed group of iteration dimensions that is
// guaranteed to be preserved in the indexing map. The number of folded
// dims is obtained from the collapsed op to original op mapping.
unsigned numFoldedDims =
collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
.size();
if (origOpToCollapsedOpMapping[dim].second == 0) {
auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
operandReassociation.emplace_back(range.begin(), range.end());
}
counter += numFoldedDims;
}
return operandReassociation;
}
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
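/// For example (illustrative), a `tensor<2x3x4xf32>` operand indexed by
/// `(d0, d1, d2) -> (d0, d1, d2)` with folded dims `[[0, 1]]` is replaced by
/// the result of a `tensor.collapse_shape` yielding `tensor<6x4xf32>`, while
/// operands whose reassociation is trivial are returned unchanged.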
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
// If the number of entries in the reassociation for the operand is the same
// as the number of results of the indexing map, then there is nothing to do
// for this operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
// Insert a reshape to collapse the dimensions.
if (isa<MemRefType>(operand.getType())) {
return builder
.create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}
return builder
.create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}
/// Rewrite the `linalg.index` operations in the original generic op to compute
/// their values in terms of the collapsed operation's iteration variables.
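/// For example (illustrative), if dims `[0, 1]` with sizes `(d0, d1)` are
/// folded, then `linalg.index 0` is replaced by `idx / d1` and
/// `linalg.index 1` by `idx % d1`, where `idx` is the `linalg.index` of the
/// folded dimension in the collapsed op.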
static void generateCollapsedIndexingRegion(
Location loc, Block *block, const CollapsingInfo &collapsingInfo,
ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(block);
// Collect all the original index ops.
auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
// For each folded dimension list resolve the original induction variable
// values in terms of the folded dimension induction variable.
// i_{folded} = (i0 * d1 + i1) * d2 + i2
// can be inverted to
// i2 = i_{folded} % d2
// i1 = (i_{folded} / d2) % d1
// i0 = i_{folded} / (d1 * d2)
llvm::DenseMap<unsigned, Value> indexReplacementVals;
for (auto foldedDims :
enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
ReassociationIndicesRef foldedDimsRef(foldedDims.value());
Value newIndexVal =
rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
Value loopDim =
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
indexReplacementVals[dim] =
rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
newIndexVal =
rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
}
indexReplacementVals[foldedDims.value().front()] = newIndexVal;
}
for (auto indexOp : indexOps) {
auto dim = indexOp.getDim();
rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
}
}
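/// Compute the collapsed input operands, output operands and result types for
/// `op` according to `collapsingInfo`. For ops with pure buffer semantics,
/// `resultTypes` is left empty.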
void collapseOperandsAndResults(LinalgOp op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter,
SmallVectorImpl<Value> &inputOperands,
SmallVectorImpl<Value> &outputOperands,
SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
});
// Get the output operands and result types.
resultTypes.reserve(op.getNumDpsInits());
outputOperands.reserve(op.getNumDpsInits());
for (OpOperand &output : op.getDpsInitsMutable()) {
Value newOutput =
getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!op.hasPureBufferSemantics())
resultTypes.push_back(newOutput.getType());
}
}
/// Clone a `LinalgOp` to a collapsed version of the same name.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
const CollapsingInfo &collapsingInfo) {
return nullptr;
}
/// Collapse any `LinalgOp` that does not require any specialization such as
/// indexing_maps, iterator_types, etc.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
const CollapsingInfo &collapsingInfo) {
SmallVector<Value> inputOperands, outputOperands;
SmallVector<Type> resultTypes;
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
outputOperands, resultTypes);
return clone(
rewriter, origOp, resultTypes,
llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
}
/// Collapse a `GenericOp`.
template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
GenericOp origOp,
const CollapsingInfo &collapsingInfo) {
SmallVector<Value> inputOperands, outputOperands;
SmallVector<Type> resultTypes;
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
outputOperands, resultTypes);
SmallVector<AffineMap> indexingMaps(
llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
origOp.getIteratorTypesArray(), collapsingInfo));
GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
Block *origOpBlock = &origOp->getRegion(0).front();
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
collapsedOpBlock->getArguments());
return collapsedOp;
}
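/// Create the collapsed version of `op`, dispatching to the `GenericOp`
/// specialization when applicable.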
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
} else {
return cloneToCollapsedOp(rewriter, op, collapsingInfo);
}
}
/// Implementation of fusion with a reshape operation by collapsing dimensions.
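/// For example (an illustrative sketch), collapsing dims `[0, 1]` of a
/// three-loop elementwise `linalg.generic` on `tensor<?x?x?xf32>` operands
/// yields a two-loop generic whose operands are collapsed via
/// `tensor.collapse_shape` with reassociation `[[0, 1], [2]]`; any result
/// whose rank changed is expanded back to the original type with a
/// `tensor.expand_shape` (or the `memref` equivalents for buffer semantics).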
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
// Bail on trivial no-op cases.
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();
bool hasPureBufferSemantics = op.hasPureBufferSemantics();
if (hasPureBufferSemantics &&
!llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
if (!memRefToCollapse)
return true;
return memref::CollapseShapeOp::isGuaranteedCollapsible(
memRefToCollapse, foldedIterationDims);
}))
return rewriter.notifyMatchFailure(op,
"memref is not guaranteed collapsible");
CollapsingInfo collapsingInfo;
if (failed(
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
op, "illegal to collapse specified dimensions");
}
// Bail on non-canonical ranges.
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
actual.getSExtValue() == value;
};
if (!llvm::all_of(loopRanges, [&](Range range) {
return opFoldIsConstantValue(range.offset, 0) &&
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
op, "expected all loop ranges to have zero start and unit stride");
}
LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
Location loc = op->getLoc();
SmallVector<OpFoldResult> loopBound =
llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
if (collapsedOp.hasIndexSemantics()) {
// Rewrite the `linalg.index` ops in terms of the collapsed loop bounds.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(collapsedOp);
generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}
// Insert an expanding reshape for each result to recover the original
// result type.
SmallVector<Value> results;
for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
assert(
indexingMap.isProjectedPermutation() &&
"Expected indexing map to be a projected permutation for collapsing");
SmallVector<OpFoldResult> resultShape =
applyPermutationMap(indexingMap, ArrayRef(loopBound));
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
MemRefType expandShapeResultType = MemRefType::get(
originalResultType.getShape(), originalResultType.getElementType());
result = rewriter.create<memref::ExpandShapeOp>(
loc, expandShapeResultType, collapsedOpResult, reassociation,
resultShape);
} else {
result = rewriter.create<tensor::ExpandShapeOp>(
loc, originalResultType, collapsedOpResult, reassociation,
resultShape);
}
results.push_back(result);
} else {
results.push_back(collapsedOpResult);
}
}
return CollapseResult{results, collapsedOp};
}
namespace {
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing dimensions of the iteration space.
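/// For example (illustrative), when a `tensor.expand_shape` expands
/// `tensor<6x4xf32>` to `tensor<2x3x4xf32>` and feeds a three-loop generic,
/// the reshape can be folded by collapsing the generic's first two loops so
/// that it consumes the `tensor<6x4xf32>` directly.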
class FoldWithProducerReshapeOpByCollapsing
: public OpRewritePattern<GenericOp> {
public:
// TODO: support fusion with all linalg ops, not just generic.
FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
for (OpOperand &opOperand : genericOp->getOpOperands()) {
tensor::ExpandShapeOp reshapeOp =
opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
continue;
SmallVector<ReassociationIndices> collapsableIterationDims =
getCollapsableIterationSpaceDims(genericOp, &opOperand,
reshapeOp.getReassociationIndices());
if (collapsableIterationDims.empty() ||
!controlFoldingReshapes(&opOperand)) {
continue;
}
std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
genericOp, collapsableIterationDims, rewriter);
if (!collapseResult) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
}
rewriter.replaceOp(genericOp, collapseResult->results);
return success();
}
return failure();
}
private:
ControlFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor.collapse_shape op with its producer generic op
/// by collapsing the dimensions of the producer op's iteration space.
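/// For example (illustrative), a `tensor.collapse_shape` that collapses a
/// producer generic's `tensor<2x3x4xf32>` result to `tensor<6x4xf32>` is
/// handled by collapsing the producer's first two loops; the compensating
/// `tensor.expand_shape` inserted on the producer's result can then cancel
/// with the original `tensor.collapse_shape` during canonicalization.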
struct FoldReshapeWithGenericOpByCollapsing
: public OpRewritePattern<tensor::CollapseShapeOp> {
FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by collapsing are
// met.
auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(reshapeOp,
"source not produced by an operation");
}
// TODO: support fusion with all linalg producers, not just generic.
auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
if (!producer) {
return rewriter.notifyMatchFailure(reshapeOp,
"producer not a generic op");
}
SmallVector<ReassociationIndices> collapsableIterationDims =
getCollapsableIterationSpaceDims(
producer,
producer.getDpsInitOperand(producerResult.getResultNumber()),
reshapeOp.getReassociationIndices());
if (collapsableIterationDims.empty()) {
return rewriter.notifyMatchFailure(
reshapeOp, "failed preconditions of fusion with producer generic op");
}
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
// Set the insertion point after `producer` because there could be uses
// of `producer` between it and the `tensor.collapse_shape` op.
rewriter.setInsertionPointAfter(producer);
std::optional<CollapseResult> collapseResult =
collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
if (!collapseResult) {
return rewriter.notifyMatchFailure(
producer, "failed to do the fusion by collapsing transformation");
}
rewriter.replaceOp(producer, collapseResult->results);
return success();
}
private:
ControlFusionFn controlFoldingReshapes;
};
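/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by bubbling the pad through the reshape, i.e. rewriting
/// `pad(expand_shape(x))` as `expand_shape(pad(x))`. This only applies when
/// every expanded dimension group with more than one element is unpadded,
/// since padding cannot be propagated through a collapsed dimension.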
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::PadOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
tensor::ExpandShapeOp reshapeOp =
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
return failure();
if (!reshapeOp->hasOneUse())
return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
ArrayRef<int64_t> low = padOp.getStaticLow();
ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
for (auto reInd : reassociations) {
if (reInd.size() == 1)
continue;
if (llvm::any_of(reInd, [&](int64_t ind) {
return low[ind] != 0 || high[ind] != 0;
})) {
return failure();
}
}
SmallVector<OpFoldResult> newLow, newHigh;
RankedTensorType collapsedType = reshapeOp.getSrcType();
RankedTensorType paddedType = padOp.getResultType();
SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
SmallVector<OpFoldResult> expandedPaddedSizes(
getMixedValues(reshapeOp.getStaticOutputShape(),
reshapeOp.getOutputShape(), rewriter));
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
if (reInd.size() == 1) {
collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
expandedPaddedSizes[reInd[0]] = paddedSize;
}
newLow.push_back(l);
newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
paddedType.clone(collapsedPaddedShape);
auto newPadOp = rewriter.create<tensor::PadOp>(
loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
expandedPaddedSizes);
return success();
}
private:
ControlFusionFn controlFoldingReshapes;
};
/// Pattern to collapse iteration dimensions of a linalg op, as selected by a
/// user-provided control function.
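/// For example (illustrative), a control function returning `[[0, 1]]` for a
/// three-loop op requests merging the two outermost loops, which is applied
/// only if all indexing maps preserve that dimension sequence.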
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
CollapseLinalgDimensions(MLIRContext *context,
GetCollapsableDimensionsFn collapseDimensions,
PatternBenefit benefit = 1)
: OpRewritePattern<LinalgType>(context, benefit),
controlCollapseDimension(std::move(collapseDimensions)) {}
LogicalResult matchAndRewrite(LinalgType op,
PatternRewriter &rewriter) const override {
SmallVector<ReassociationIndices> collapsableIterationDims =
controlCollapseDimension(op);
if (collapsableIterationDims.empty())
return failure();
// Check if the specified list of dimensions to collapse is a valid list.
if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
op, "specified dimensions cannot be collapsed");
}
std::optional<CollapseResult> collapseResult =
collapseOpIterationDims(op, collapsableIterationDims, rewriter);
if (!collapseResult) {
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
rewriter.replaceOp(op, collapseResult->results);
return success();
}
private:
GetCollapsableDimensionsFn controlCollapseDimension;
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//
namespace {
/// Pattern to fold a generic op with a splat constant or scalar constant
/// input. Does not handle cases where the constant is not single-valued.
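/// For example (illustrative), an input produced by
/// `arith.constant dense<1.0> : tensor<4x8xf32>` is dropped from the fused
/// op's operands and indexing maps, and the corresponding block argument is
/// replaced by the scalar constant `1.0 : f32` in the fused payload.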
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasPureTensorSemantics())
return failure();
for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
Operation *def = opOperand->get().getDefiningOp();
TypedAttr constantAttr;
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
{
DenseElementsAttr splatAttr;
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
splatAttr.isSplat() &&
splatAttr.getType().getElementType().isIntOrFloat()) {
constantAttr = splatAttr.getSplatValue<TypedAttr>();
return true;
}
}
{
IntegerAttr intAttr;
if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
constantAttr = intAttr;
return true;
}
}
{
FloatAttr floatAttr;
if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
constantAttr = floatAttr;
return true;
}
}
return false;
};
auto resultValue = dyn_cast<OpResult>(opOperand->get());
if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
continue;
// The operands and the indexing_maps of the fused operation are the same
// as the operands and indexing_maps of the generic operation, with the
// entries for the constant operand dropped.
SmallVector<AffineMap> fusedIndexMaps;
SmallVector<Value> fusedOperands;
SmallVector<Location> fusedLocs{genericOp.getLoc()};
fusedIndexMaps.reserve(genericOp->getNumOperands());
fusedOperands.reserve(genericOp.getNumDpsInputs());
fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
if (inputOperand == opOperand)
continue;
Value inputValue = inputOperand->get();
fusedIndexMaps.push_back(
genericOp.getMatchingIndexingMap(inputOperand));
fusedOperands.push_back(inputValue);
fusedLocs.push_back(inputValue.getLoc());
}
for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
fusedIndexMaps.push_back(
genericOp.getMatchingIndexingMap(&outputOperand));
// Check that the operation's shapes-to-loops map is computable.
if (!inversePermutation(
concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
return rewriter.notifyMatchFailure(
genericOp, "fused op loop bound computation failed");
}
// Create a constant scalar value from the splat constant.
Value scalarConstant =
rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputs=*/outputOperands,
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
genericOp.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
// Map the block argument corresponding to the replaced operand to the
// scalar constant.
Region &region = genericOp->getRegion(0);
Block &entryBlock = *region.begin();
IRMapping mapping;
mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
scalarConstant);
Region &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
mapping);
rewriter.replaceOp(genericOp, fusedOp->getResults());
return success();
}
return failure();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//
namespace {
/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
/// value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
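/// For example (illustrative), `outs(%0 : tensor<?x?xf32>)` whose
/// corresponding block argument is never read by the payload is rewritten to
/// `outs(%empty : tensor<?x?xf32>)` with `%empty = tensor.empty(...)`,
/// breaking the dependence on the previously computed value.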
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
if (!op.payloadUsesValueFromOperand(&opOperand)) {
Value operandVal = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
if (!operandType)
continue;
// If outs is sparse, leave it to the sparsifier.
if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
continue;
// If outs is already an `empty` operation, nothing to do.
auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
if (definingOp)
continue;
modifiedOutput = true;
SmallVector<OpFoldResult> mixedSizes =
tensor::getMixedSizes(rewriter, loc, operandVal);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, mixedSizes, operandType.getElementType());
op->setOperand(opOperand.getOperandNumber(), emptyTensor);
}
}
if (!modifiedOutput) {
rewriter.cancelOpModification(op);
return failure();
}
rewriter.finalizeOpModification(op);
return success();
}
};
/// Fold linalg.fill into linalg.generic.
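/// For example (illustrative), if an input of a generic op is defined by
/// `linalg.fill ins(%cst : f32)`, all payload uses of that input's block
/// argument are replaced by `%cst` (converted to the element type if needed),
/// after which the fill operand can be cleaned up as dead.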
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasPureTensorSemantics())
return failure();
bool fillFound = false;
Block &payload = genericOp.getRegion().front();
for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
if (!genericOp.payloadUsesValueFromOperand(opOperand))
continue;
FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
if (!fillOp)
continue;
fillFound = true;
Value fillVal = fillOp.value();
auto resultType =
cast<RankedTensorType>(fillOp.result().getType()).getElementType();
Value convertedVal =
convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
/*isUnsignedCast =*/false);
rewriter.replaceAllUsesWith(
payload.getArgument(opOperand->getOperandNumber()), convertedVal);
}
return success(fillFound);
}
};
} // namespace
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpsFusion) {
auto *context = patterns.getContext();
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
CollapseLinalgDimensions<linalg::CopyOp>>(
patterns.getContext(), controlCollapseDimensions);
}
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
namespace {
/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
: public impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass> {
using impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
// Default control function: fuse an operand only if its producer has a
// single use.
ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
return producer && producer->hasOneUse();
};
// Add elementwise op fusion and reshape folding (by expansion) patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);
// General canonicalization patterns.
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
patterns);
// Add constant folding patterns.
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
// Use TopDownTraversal for compile time reasons.
(void)applyPatternsGreedily(op, std::move(patterns),
GreedyRewriteConfig().setUseTopDownTraversal());
}
};
} // namespace