289 lines
11 KiB
C++
289 lines
11 KiB
C++
//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion on parallel loops.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
|
|
|
#include "mlir/Analysis/AliasAnalysis.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
|
|
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::scf;
|
|
|
|
/// Verify there are no nested ParallelOps.
|
|
static bool hasNestedParallelOp(ParallelOp ploop) {
|
|
auto walkResult =
|
|
ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
|
|
return walkResult.wasInterrupted();
|
|
}
|
|
|
|
/// Verify equal iteration spaces.
|
|
static bool equalIterationSpaces(ParallelOp firstPloop,
|
|
ParallelOp secondPloop) {
|
|
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
|
|
return false;
|
|
|
|
auto matchOperands = [&](const OperandRange &lhs,
|
|
const OperandRange &rhs) -> bool {
|
|
// TODO: Extend this to support aliases and equal constants.
|
|
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
|
|
};
|
|
return matchOperands(firstPloop.getLowerBound(),
|
|
secondPloop.getLowerBound()) &&
|
|
matchOperands(firstPloop.getUpperBound(),
|
|
secondPloop.getUpperBound()) &&
|
|
matchOperands(firstPloop.getStep(), secondPloop.getStep());
|
|
}
|
|
|
|
/// Checks if the parallel loops have mixed access to the same buffers. Returns
|
|
/// `true` if the first parallel loop writes to the same indices that the second
|
|
/// loop reads.
|
|
static bool haveNoReadsAfterWriteExceptSameIndex(
|
|
ParallelOp firstPloop, ParallelOp secondPloop,
|
|
const IRMapping &firstToSecondPloopIndices,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
|
|
SmallVector<Value> bufferStoresVec;
|
|
firstPloop.getBody()->walk([&](memref::StoreOp store) {
|
|
bufferStores[store.getMemRef()].push_back(store.getIndices());
|
|
bufferStoresVec.emplace_back(store.getMemRef());
|
|
});
|
|
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
|
|
Value loadMem = load.getMemRef();
|
|
// Stop if the memref is defined in secondPloop body. Careful alias analysis
|
|
// is needed.
|
|
auto *memrefDef = loadMem.getDefiningOp();
|
|
if (memrefDef && memrefDef->getBlock() == load->getBlock())
|
|
return WalkResult::interrupt();
|
|
|
|
for (Value store : bufferStoresVec)
|
|
if (store != loadMem && mayAlias(store, loadMem))
|
|
return WalkResult::interrupt();
|
|
|
|
auto write = bufferStores.find(loadMem);
|
|
if (write == bufferStores.end())
|
|
return WalkResult::advance();
|
|
|
|
// Check that at last one store was retrieved
|
|
if (write->second.empty())
|
|
return WalkResult::interrupt();
|
|
|
|
auto storeIndices = write->second.front();
|
|
|
|
// Multiple writes to the same memref are allowed only on the same indices
|
|
for (const auto &othStoreIndices : write->second) {
|
|
if (othStoreIndices != storeIndices)
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
// Check that the load indices of secondPloop coincide with store indices of
|
|
// firstPloop for the same memrefs.
|
|
auto loadIndices = load.getIndices();
|
|
if (storeIndices.size() != loadIndices.size())
|
|
return WalkResult::interrupt();
|
|
for (int i = 0, e = storeIndices.size(); i < e; ++i) {
|
|
if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
|
|
loadIndices[i]) {
|
|
auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
|
|
auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
|
|
if (storeIndexDefOp && loadIndexDefOp) {
|
|
if (!isMemoryEffectFree(storeIndexDefOp))
|
|
return WalkResult::interrupt();
|
|
if (!isMemoryEffectFree(loadIndexDefOp))
|
|
return WalkResult::interrupt();
|
|
if (!OperationEquivalence::isEquivalentTo(
|
|
storeIndexDefOp, loadIndexDefOp,
|
|
[&](Value storeIndex, Value loadIndex) {
|
|
if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
|
|
firstToSecondPloopIndices.lookupOrDefault(loadIndex))
|
|
return failure();
|
|
else
|
|
return success();
|
|
},
|
|
/*markEquivalent=*/nullptr,
|
|
OperationEquivalence::Flags::IgnoreLocations)) {
|
|
return WalkResult::interrupt();
|
|
}
|
|
} else {
|
|
return WalkResult::interrupt();
|
|
}
|
|
}
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
return !walkResult.wasInterrupted();
|
|
}
|
|
|
|
/// Analyzes dependencies in the most primitive way by checking simple read and
|
|
/// write patterns.
|
|
static LogicalResult
|
|
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
|
const IRMapping &firstToSecondPloopIndices,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
if (!haveNoReadsAfterWriteExceptSameIndex(
|
|
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
|
|
return failure();
|
|
|
|
IRMapping secondToFirstPloopIndices;
|
|
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
|
|
firstPloop.getBody()->getArguments());
|
|
return success(haveNoReadsAfterWriteExceptSameIndex(
|
|
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
|
|
}
|
|
|
|
/// Returns `true` if `firstPloop` and `secondPloop` may be fused: neither
/// contains a nested ParallelOp, their iteration spaces are equal, and the
/// dependence check between them succeeds.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                          const IRMapping &firstToSecondPloopIndices,
                          llvm::function_ref<bool(Value, Value)> mayAlias) {
  if (hasNestedParallelOp(firstPloop))
    return false;
  if (hasNestedParallelOp(secondPloop))
    return false;
  if (!equalIterationSpaces(firstPloop, secondPloop))
    return false;
  return succeeded(verifyDependencies(firstPloop, secondPloop,
                                      firstToSecondPloopIndices, mayAlias));
}
|
|
|
|
/// Prepends operations of firstPloop's body into secondPloop's body.
|
|
/// Updates secondPloop with new loop.
|
|
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
|
|
OpBuilder builder,
|
|
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
Block *block1 = firstPloop.getBody();
|
|
Block *block2 = secondPloop.getBody();
|
|
IRMapping firstToSecondPloopIndices;
|
|
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
|
|
|
|
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
|
|
mayAlias))
|
|
return;
|
|
|
|
DominanceInfo dom;
|
|
// We are fusing first loop into second, make sure there are no users of the
|
|
// first loop results between loops.
|
|
for (Operation *user : firstPloop->getUsers())
|
|
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
|
|
return;
|
|
|
|
ValueRange inits1 = firstPloop.getInitVals();
|
|
ValueRange inits2 = secondPloop.getInitVals();
|
|
|
|
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
|
|
newInitVars.append(inits2.begin(), inits2.end());
|
|
|
|
IRRewriter b(builder);
|
|
b.setInsertionPoint(secondPloop);
|
|
auto newSecondPloop = b.create<ParallelOp>(
|
|
secondPloop.getLoc(), secondPloop.getLowerBound(),
|
|
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
|
|
|
|
Block *newBlock = newSecondPloop.getBody();
|
|
auto term1 = cast<ReduceOp>(block1->getTerminator());
|
|
auto term2 = cast<ReduceOp>(block2->getTerminator());
|
|
|
|
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
|
|
newBlock->getArguments());
|
|
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
|
|
newBlock->getArguments());
|
|
|
|
ValueRange results = newSecondPloop.getResults();
|
|
if (!results.empty()) {
|
|
b.setInsertionPointToEnd(newBlock);
|
|
|
|
ValueRange reduceArgs1 = term1.getOperands();
|
|
ValueRange reduceArgs2 = term2.getOperands();
|
|
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
|
|
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
|
|
|
|
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
|
|
|
|
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
|
|
term1.getReductions(), term2.getReductions()))) {
|
|
Block &oldRedBlock = reg.front();
|
|
Block &newRedBlock = newReduceOp.getReductions()[i].front();
|
|
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
|
|
newRedBlock.getArguments());
|
|
}
|
|
|
|
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
|
|
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
|
|
}
|
|
term1->erase();
|
|
term2->erase();
|
|
firstPloop.erase();
|
|
secondPloop.erase();
|
|
secondPloop = newSecondPloop;
|
|
}
|
|
|
|
void mlir::scf::naivelyFuseParallelOps(
|
|
Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) {
|
|
OpBuilder b(region);
|
|
// Consider every single block and attempt to fuse adjacent loops.
|
|
SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
|
|
for (auto &block : region) {
|
|
ploopChains.clear();
|
|
ploopChains.push_back({});
|
|
|
|
// Not using `walk()` to traverse only top-level parallel loops and also
|
|
// make sure that there are no side-effecting ops between the parallel
|
|
// loops.
|
|
bool noSideEffects = true;
|
|
for (auto &op : block) {
|
|
if (auto ploop = dyn_cast<ParallelOp>(op)) {
|
|
if (noSideEffects) {
|
|
ploopChains.back().push_back(ploop);
|
|
} else {
|
|
ploopChains.push_back({ploop});
|
|
noSideEffects = true;
|
|
}
|
|
continue;
|
|
}
|
|
// TODO: Handle region side effects properly.
|
|
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
|
|
}
|
|
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
|
|
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
|
|
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
struct ParallelLoopFusion
|
|
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
|
|
void runOnOperation() override {
|
|
auto &AA = getAnalysis<AliasAnalysis>();
|
|
|
|
auto mayAlias = [&](Value val1, Value val2) -> bool {
|
|
return !AA.alias(val1, val2).isNo();
|
|
};
|
|
|
|
getOperation()->walk([&](Operation *child) {
|
|
for (Region ®ion : child->getRegions())
|
|
naivelyFuseParallelOps(region, mayAlias);
|
|
});
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
|
|
return std::make_unique<ParallelLoopFusion>();
|
|
}
|