Files
clang-p2996/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Andrzej Warzyński 22f96ab6fb [mlir][vector] Refine vector.transfer_read hoisting/forwarding (#65770)
Make sure that when analysing a `vector.transfer_read` that's a
candidate for either hoisting or store-to-load forwarding,
`memref.collapse_shape` Ops are correctly included in the alias
analysis. This is done by either
* making sure that relevant users are taken into account, or
* source Ops are correctly identified.
2023-09-12 10:33:58 +01:00

239 lines
9.8 KiB
C++

//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using llvm::dbgs;
#define DEBUG_TYPE "linalg-hoisting"
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
/// Applies `hoistRedundantSubsetExtractInsert` to every `scf.for` op nested
/// under `func`, using a single rewriter created from the function's context.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
  IRRewriter rewriter(func->getContext());

  // Per-loop action; the value returned by the hoisting helper is not needed.
  auto hoistInLoop = [&rewriter](scf::ForOp forOp) {
    hoistRedundantSubsetExtractInsert(rewriter, forOp);
  };

  // TODO: walking in some reverse / inside-out order would be more efficient
  // and would capture more cases.
  func.walk(hoistInLoop);
}
/// Returns true when no operation inside `loop` may conflict with hoisting
/// `transferRead` out of it.
///
/// The source of `transferRead` is first traced back through any chain of
/// `memref.subview` / `memref.collapse_shape` ops, in whatever order they
/// appear, to the value the chain starts from. Every transitive user of that
/// value (looking through further subview / collapse_shape ops) is then
/// inspected. Users that are memory-effect free, that are
/// `vector.transfer_read` ops, or that are not nested in `loop` are ignored;
/// any other user makes the function return false.
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                LoopLikeOpInterface loop) {
  Value source = transferRead.getSource();

  // Peel subview and collapse_shape ops until neither matches. Both kinds are
  // handled in one loop so that interleaved chains (e.g. a subview of a
  // collapse_shape of a subview) are followed all the way up; stopping early
  // would leave users of the outer value unexamined.
  while (Operation *definingOp = source.getDefiningOp()) {
    if (auto subView = dyn_cast<memref::SubViewOp>(definingOp)) {
      source = subView.getSource();
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
      source = collapsed->getOperand(0);
      continue;
    }
    break;
  }

  // Worklist of users still to inspect, plus the set of users already seen so
  // that each operation is examined at most once.
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    // Subview and collapse_shape ops are not judged themselves; their own
    // users are queued for inspection instead.
    if (isa<memref::SubViewOp, memref::CollapseShapeOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    // Ops without memory effects and other transfer reads are not treated as
    // conflicts.
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    // Users that are not nested in `loop` are ignored.
    if (!loop->isAncestor(user))
      continue;
    // Any other user inside the loop is conservatively treated as aliasing.
    return false;
  }
  return true;
}
/// Hoists `vector.transfer_read` ops on memrefs, and matching
/// `vector.transfer_write` ops, out of their parent `scf.for` / `affine.for`
/// loop within `func`.
///
/// Two cases are handled for a read whose operands are all defined outside of
/// its parent loop:
///  * no write to the same source is found in the read's forward slice: the
///    read alone is moved out of the loop when `noAliasingUseInLoop` holds;
///  * a matching write is found: the read is moved before the loop, the write
///    after it, and the loop is replaced by a clone that carries the vector as
///    an additional iteration argument / yielded value.
/// The function is re-walked until a walk completes without a read/write pair
/// having been hoisted.
void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
// `changed` is only set when a read/write pair is hoisted (which erases the
// original loop and interrupts the walk), so the outer loop restarts the walk
// after each such rewrite.
bool changed = true;
while (changed) {
changed = false;
// First move loop invariant ops outside of their loop. This needs to be
// done before as we cannot move ops without interrupting the function walk.
func.walk(
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
func.walk([&](vector::TransferReadOp transferRead) {
// Only reads whose shaped operand is a memref are considered.
if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
<< *transferRead.getOperation() << "\n");
auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
<< "\n");
// Only direct scf.for / affine.for parents are supported. `loop` is null
// when the parent is not a loop-like op; that case is skipped here as well.
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
<< "\n");
SetVector<Operation *> forwardSlice;
getForwardSlice(transferRead.getOperation(), &forwardSlice);
// Look for the last TransferWriteOp in the forwardSlice of
// `transferRead` that operates on the same memref.
// NOTE(review): the loop below does not stop at the first match, so
// `transferWrite` ends up holding the match visited last while iterating the
// slice in reverse — confirm this is the intended "last" write.
vector::TransferWriteOp transferWrite;
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
if (!candidateWrite ||
candidateWrite.getSource() != transferRead.getSource())
continue;
transferWrite = candidateWrite;
}
// All operands of the TransferRead must be defined outside of the loop.
for (auto operand : transferRead.getOperands())
if (!loop.isDefinedOutsideOfLoop(operand))
return WalkResult::advance();
// Only hoist transfer_read / transfer_write pairs and singleton
// transfer_reads for now.
if (!transferWrite) {
// Make sure there are no other accesses to the memref before
// hoisting transfer_read.
// Note: this path neither sets `changed` nor interrupts the walk.
if (noAliasingUseInLoop(transferRead, loop))
loop.moveOutOfLoop(transferRead);
return WalkResult::advance();
}
LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
<< "\n");
// Approximate aliasing by checking that:
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
if (transferRead.getIndices() != transferWrite.getIndices() ||
transferRead.getVectorType() != transferWrite.getVectorType() ||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
// The read must properly dominate the write for the pair to be hoisted.
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
// Check every other use of the read's source value inside the loop: it must
// be a transfer_read / transfer_write that `isDisjointTransferSet` reports
// as disjoint from `transferWrite`, otherwise hoisting is abandoned.
for (auto &use : transferRead.getSource().getUses()) {
if (!loop->isAncestor(use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||
use.getOwner() == transferWrite.getOperation())
continue;
if (auto transferWriteUse =
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(transferWrite.getOperation()),
cast<VectorTransferOpInterface>(
transferWriteUse.getOperation())))
return WalkResult::advance();
} else if (auto transferReadUse =
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(transferWrite.getOperation()),
cast<VectorTransferOpInterface>(
transferReadUse.getOperation())))
return WalkResult::advance();
} else {
// Unknown use, we cannot prove that it doesn't alias with the
// transferRead/transferWrite operations.
return WalkResult::advance();
}
}
// Hoist read before.
loop.moveOutOfLoop(transferRead);
// Hoist write after.
transferWrite->moveAfter(loop);
// Rewrite `loop` with new yields by cloning and erase the original loop.
// The builder is anchored at `transferRead`, which has already been moved
// out of the loop by this point.
OpBuilder b(transferRead);
// The extra value yielded by the new loop is the vector that the hoisted
// write stores.
NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWrite.getVector()};
};
// Transfer write has been hoisted, need to update the written vector by
// the value yielded by the newForOp.
return TypeSwitch<Operation *, WalkResult>(loop)
.Case<scf::ForOp>([&](scf::ForOp scfForOp) {
// The hoisted read's result is passed as the additional init value of
// the new loop.
auto newForOp = replaceLoopWithNewYields(
b, scfForOp, transferRead.getVector(), yieldFn);
// The hoisted write now stores the last result of the new loop.
transferWrite.getVectorMutable().assign(
newForOp.getResults().back());
changed = true;
loop.erase();
// Need to interrupt and restart because erasing the loop messes up
// the walk.
return WalkResult::interrupt();
})
.Case<affine::AffineForOp>([&](affine::AffineForOp affineForOp) {
auto newForOp = replaceForOpWithNewYields(
b, affineForOp, transferRead.getVector(),
SmallVector<Value>{transferWrite.getVector()},
transferWrite.getVector());
// Replace all uses of the `transferRead` with the corresponding
// basic block argument.
// Only uses nested inside the new loop are redirected.
transferRead.getVector().replaceUsesWithIf(
newForOp.getLoopBody().getArguments().back(),
[&](OpOperand &use) {
Operation *user = use.getOwner();
return newForOp->isProperAncestor(user);
});
transferWrite.getVectorMutable().assign(
newForOp.getResults().back());
changed = true;
loop.erase();
// Need to interrupt and restart because erasing the loop messes up
// the walk.
return WalkResult::interrupt();
})
// Not expected to be reached: `loop` was checked above to be an scf.for
// or affine.for op.
.Default([](Operation *) { return WalkResult::interrupt(); });
});
}
}