Files
clang-p2996/mlir/lib/Analysis/SliceAnalysis.cpp
thomasraoux 5aa6038a40 [mlir] Make topologicalSort iterative and consider op regions
When doing topological sort we need to make sure an op is scheduled before any
of the ops within its regions.
Also change the algorithm to not be recursive in order to prevent potential
stack overflow.

Differential Revision: https://reviews.llvm.org/D113423
2021-11-10 10:05:01 -08:00

215 lines
7.5 KiB
C++

//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Analysis functions specific to slicing in Function.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
///
/// Implements Analysis functions specific to slicing in Function.
///
using namespace mlir;
/// Recursive worker behind the public getForwardSlice entry points.
///
/// Starting at `op`, visits every operation nested in the regions of `op` and
/// every user of the results of `op`, and appends each visited operation to
/// `forwardSlice` *after* everything reached from it (post-order). Callers
/// reverse the resulting sequence to obtain a topological order.
///
/// \param op           Operation to start from; a null pointer is ignored.
/// \param forwardSlice Accumulator. Operations already present are not
///                     recursed into again.
/// \param filter       Optional predicate. When set and it returns false for
///                     `op`, neither `op` nor anything reached through it is
///                     added by this call.
static void getForwardSliceImpl(Operation *op,
                                SetVector<Operation *> *forwardSlice,
                                TransitiveFilter filter) {
  // Bail out on a missing op, or on an op the filter rejects. The filter is
  // what lets callers implement scoping, i.e. restrict the transitive slice
  // to the current scope.
  if (!op || (filter && !filter(op)))
    return;

  // First descend into every operation nested under `op`.
  for (Region &nestedRegion : op->getRegions()) {
    for (Block &nestedBlock : nestedRegion) {
      for (Operation &nestedOp : nestedBlock) {
        if (forwardSlice->count(&nestedOp) != 0)
          continue;
        getForwardSliceImpl(&nestedOp, forwardSlice, filter);
      }
    }
  }

  // Then follow each result to the operations consuming it.
  for (Value produced : op->getResults()) {
    for (Operation *consumer : produced.getUsers()) {
      if (forwardSlice->count(consumer) != 0)
        continue;
      getForwardSliceImpl(consumer, forwardSlice, filter);
    }
  }

  // Post-order: `op` is recorded only once everything it reaches is in.
  forwardSlice->insert(op);
}
/// Computes the forward slice rooted at `op` into `forwardSlice`.
///
/// The root operation itself is excluded from the result. The recursive
/// helper emits operations in post-order, so the collected sequence is
/// reversed before returning to restore topological order.
///
/// \param op           Root operation of the query.
/// \param forwardSlice Output set; its whole content is re-inserted in
///                     reverse order at the end of this call.
/// \param filter       Optional predicate forwarded to the recursive helper.
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
                           TransitiveFilter filter) {
  getForwardSliceImpl(op, forwardSlice, filter);

  // The root was only the starting point of the query; drop it from the
  // result.
  forwardSlice->remove(op);

  // SetVector cannot be reversed in place with std::reverse, so move its
  // storage out and feed the elements back from last to first.
  std::vector<Operation *> postOrder(forwardSlice->takeVector());
  for (auto it = postOrder.rbegin(), end = postOrder.rend(); it != end; ++it)
    forwardSlice->insert(*it);
}
/// Computes the forward slice of the value `root` into `forwardSlice`.
///
/// Every user of `root` is used as a starting point for the recursive helper,
/// and (subject to `filter`) is itself part of the result. The helper emits
/// operations in post-order, so the collected sequence is reversed before
/// returning to restore topological order.
///
/// \param root         Value whose users seed the slice.
/// \param forwardSlice Output set; its whole content is re-inserted in
///                     reverse order at the end of this call.
/// \param filter       Optional predicate forwarded to the recursive helper.
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
                           TransitiveFilter filter) {
  for (Operation *consumer : root.getUsers())
    getForwardSliceImpl(consumer, forwardSlice, filter);

  // SetVector cannot be reversed in place with std::reverse, so move its
  // storage out and feed the elements back from last to first.
  std::vector<Operation *> postOrder(forwardSlice->takeVector());
  for (auto it = postOrder.rbegin(), end = postOrder.rend(); it != end; ++it)
    forwardSlice->insert(*it);
}
/// Recursive worker behind the public getBackwardSlice entry points.
///
/// Starting at `op`, walks backward through the producers of its operands and
/// appends each visited operation to `backwardSlice` *after* the operations
/// it depends on, so definitions precede their uses in the result.
///
/// For an operand that is a block argument, the walk continues with the
/// operation that owns the argument's block. Only parents with exactly one
/// region holding exactly one block are supported (asserted).
///
/// \param op            Operation to start from. Null pointers and operations
///                      with the IsIsolatedFromAbove trait are ignored.
/// \param backwardSlice Accumulator. Operations already present are not
///                      recursed into again.
/// \param filter        Optional predicate. When set and it returns false for
///                      `op`, neither `op` nor anything reached through it is
///                      added by this call.
static void getBackwardSliceImpl(Operation *op,
                                 SetVector<Operation *> *backwardSlice,
                                 TransitiveFilter filter) {
  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return;

  // Evaluate whether we should keep this def.
  // This is useful in particular to implement scoping; i.e. return the
  // transitive backwardSlice in the current scope.
  if (filter && !filter(op))
    return;

  for (auto operand : op->getOperands()) {
    if (auto *definingOp = operand.getDefiningOp()) {
      if (backwardSlice->count(definingOp) == 0)
        getBackwardSliceImpl(definingOp, backwardSlice, filter);
      continue;
    }

    if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
      Block *block = blockArg.getOwner();
      Operation *parentOp = block->getParentOp();
      // The owning block may have no parent operation (e.g. a block that is
      // not attached to any op). There is nothing further to walk into in
      // that case, and dereferencing the null parent below would crash.
      if (!parentOp)
        continue;
      // TODO: determine whether we want to recurse backward into the other
      // blocks of parentOp, which are not technically backward unless they flow
      // into us. For now, just bail.
      assert(parentOp->getNumRegions() == 1 &&
             parentOp->getRegion(0).getBlocks().size() == 1);
      if (backwardSlice->count(parentOp) == 0)
        getBackwardSliceImpl(parentOp, backwardSlice, filter);
      continue;
    }

    llvm_unreachable("No definingOp and not a block argument.");
  }

  // Post-order: `op` is recorded only once all of its producers are in.
  backwardSlice->insert(op);
}
/// Computes the backward slice rooted at `op` into `backwardSlice`.
///
/// The result holds the operations collected by the recursive helper while
/// walking backward from the operands of `op`; `op` itself is excluded.
///
/// \param op            Root operation of the query.
/// \param backwardSlice Output set, appended to by the recursive helper.
/// \param filter        Optional predicate forwarded to the recursive helper.
void mlir::getBackwardSlice(Operation *op,
                            SetVector<Operation *> *backwardSlice,
                            TransitiveFilter filter) {
  getBackwardSliceImpl(op, backwardSlice, filter);

  // The root was only the starting point of the query; it is not part of its
  // own backward slice.
  backwardSlice->remove(op);
}
/// Computes the backward slice of the value `root` into `backwardSlice`.
///
/// The query is delegated to the Operation overload, anchored at the
/// operation defining `root` or, when `root` has no defining operation, at
/// the operation that owns the block `root` is an argument of. The anchor
/// operation itself is not part of the result.
///
/// \param root          Value to slice backward from.
/// \param backwardSlice Output set.
/// \param filter        Optional predicate forwarded to the Operation
///                      overload.
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
                            TransitiveFilter filter) {
  Operation *anchor = root.getDefiningOp();
  // No defining op: `root` is cast to a block argument and the op owning its
  // block is used instead.
  if (!anchor)
    anchor = root.cast<BlockArgument>().getOwner()->getParentOp();
  getBackwardSlice(anchor, backwardSlice, filter);
}
/// Computes the combined backward/forward slice around `op`.
///
/// Starting from `op`, repeatedly takes the backward and the forward slice of
/// every operation collected so far until no new operation appears, then
/// returns the collected set in topological order.
///
/// \param op             Seed operation; it is part of the returned slice.
/// \param backwardFilter Optional predicate applied while walking backward.
/// \param forwardFilter  Optional predicate applied while walking forward.
/// \returns The topologically sorted set of collected operations.
SetVector<Operation *> mlir::getSlice(Operation *op,
                                      TransitiveFilter backwardFilter,
                                      TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);

  // Scratch sets, reused across iterations.
  SetVector<Operation *> defs;
  SetVector<Operation *> uses;

  // `slice` grows while it is being scanned, so iterate by index and
  // re-evaluate the size on every step instead of using iterators.
  for (unsigned idx = 0; idx != slice.size(); ++idx) {
    Operation *pivot = slice[idx];

    // Pull in everything `pivot` transitively depends on.
    defs.clear();
    getBackwardSlice(pivot, &defs, backwardFilter);
    slice.insert(defs.begin(), defs.end());

    // Pull in everything that transitively depends on `pivot`.
    uses.clear();
    getForwardSlice(pivot, &uses, forwardFilter);
    slice.insert(uses.begin(), uses.end());
  }

  return topologicalSort(slice);
}
namespace {
/// State shared by successive DFSPostorder invocations, so that a
/// topological sort can be assembled from several roots of a multi-root DAG.
/// The traversal may walk through operations that are not in `toSort`; only
/// those that are in `toSort` end up in `topologicalCounts`.
struct DFSState {
  DFSState(const SetVector<Operation *> &set) : toSort(set) {}

  /// The operations being sorted. Held by reference, so the referenced set
  /// must outlive this object.
  const SetVector<Operation *> &toSort;
  /// Operations of `toSort` in the order the traversal recorded them.
  SmallVector<Operation *, 16> topologicalCounts;
  /// Every operation already processed by the recording step, whether or not
  /// it belongs to `toSort`.
  DenseSet<Operation *> seen;
};
} // namespace
/// Iterative traversal from `root` that appends to
/// `state->topologicalCounts` the not-yet-seen operations of `state->toSort`
/// reachable from `root`.
///
/// Phase one walks from `root` through the users of each result and through
/// the operations nested in each region, using an explicit stack instead of
/// recursion. Operations are not de-duplicated during this walk, so an
/// operation reachable along several paths is listed once per path.
///
/// Phase two scans that list from last to first. The first time an operation
/// is met (across all calls sharing `state`) it is marked seen and, if it
/// belongs to `state->toSort`, recorded.
static void DFSPostorder(Operation *root, DFSState *state) {
  SmallVector<Operation *> worklist;
  worklist.push_back(root);

  // Visit order of phase one, duplicates included.
  std::vector<Operation *> visitOrder;
  while (!worklist.empty()) {
    Operation *visiting = worklist.pop_back_val();
    visitOrder.push_back(visiting);

    // Users of the results come first ...
    for (Value produced : visiting->getResults())
      for (Operation *consumer : produced.getUsers())
        worklist.push_back(consumer);

    // ... followed by the operations held in the regions of this op.
    for (Region &nested : visiting->getRegions())
      for (Operation &inner : nested.getOps())
        worklist.push_back(&inner);
  }

  for (auto it = visitOrder.rbegin(), end = visitOrder.rend(); it != end;
       ++it) {
    Operation *candidate = *it;
    // Always mark as seen, even when the op is not one we are sorting.
    const bool firstEncounter = state->seen.insert(candidate).second;
    if (firstEncounter && state->toSort.count(candidate) > 0)
      state->topologicalCounts.push_back(candidate);
  }
}
/// Returns the operations of `toSort` reordered topologically.
///
/// Runs DFSPostorder from every element of `toSort` with one shared state,
/// then returns the recorded operations in the reverse of the order in which
/// they were recorded. An empty input is returned unchanged.
///
/// \param toSort Set of operations to order.
/// \returns A new set holding the same operations in sorted order.
SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;

  // One shared state, so an operation reached from several roots is recorded
  // only once.
  DFSState state(toSort);
  for (Operation *root : toSort) {
    assert(toSort.count(root) == 1 && "NYI: multi-sets not supported");
    DFSPostorder(root, &state);
  }

  // The traversal records in post-order; flip it to get the final order.
  SetVector<Operation *> sorted;
  sorted.insert(state.topologicalCounts.rbegin(),
                state.topologicalCounts.rend());
  return sorted;
}