The callgraph currently contains a special external node that is used both as the quasi-caller of every externally callable node and as the callee of every call that could not be resolved. This has one negative side effect, however, which is the motivation for this patch: every externally callable node that contains a call that could not be resolved (e.g. an indirect call) is put into one giant SCC when iterating over the SCCs of the call graph. This patch fixes that issue by creating a second special callgraph node that acts as the callee for any unresolved call. This breaks the cycles produced in the callgraph, yielding proper SCCs for all direct calls. Differential Revision: https://reviews.llvm.org/D133585
235 lines
8.1 KiB
C++
235 lines
8.1 KiB
C++
//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file contains interfaces and analyses for defining a nested callgraph.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CallGraphNode
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if this node refers to the indirect/external node.
|
|
bool CallGraphNode::isExternal() const { return !callableRegion; }
|
|
|
|
/// Return the callable region this node represents. This can only be called
|
|
/// on non-external nodes.
|
|
Region *CallGraphNode::getCallableRegion() const {
|
|
assert(!isExternal() && "the external node has no callable region");
|
|
return callableRegion;
|
|
}
|
|
|
|
/// Adds an reference edge to the given node. This is only valid on the
|
|
/// external node.
|
|
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
|
|
assert(isExternal() && "abstract edges are only valid on external nodes");
|
|
addEdge(node, Edge::Kind::Abstract);
|
|
}
|
|
|
|
/// Add an outgoing call edge from this node.
|
|
void CallGraphNode::addCallEdge(CallGraphNode *node) {
|
|
addEdge(node, Edge::Kind::Call);
|
|
}
|
|
|
|
/// Adds a reference edge to the given child node.
|
|
void CallGraphNode::addChildEdge(CallGraphNode *child) {
|
|
addEdge(child, Edge::Kind::Child);
|
|
}
|
|
|
|
/// Returns true if this node has any child edges.
|
|
bool CallGraphNode::hasChildren() const {
|
|
return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
|
|
}
|
|
|
|
/// Add an edge to 'node' with the given kind.
|
|
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
|
|
edges.insert({node, kind});
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CallGraph
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Recursively compute the callgraph edges for the given operation. Computed
|
|
/// edges are placed into the given callgraph object.
|
|
static void computeCallGraph(Operation *op, CallGraph &cg,
|
|
SymbolTableCollection &symbolTable,
|
|
CallGraphNode *parentNode, bool resolveCalls) {
|
|
if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
|
|
// If there is no parent node, we ignore this operation. Even if this
|
|
// operation was a call, there would be no callgraph node to attribute it
|
|
// to.
|
|
if (resolveCalls && parentNode)
|
|
parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
|
|
return;
|
|
}
|
|
|
|
// Compute the callgraph nodes and edges for each of the nested operations.
|
|
if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
|
|
if (auto *callableRegion = callable.getCallableRegion())
|
|
parentNode = cg.getOrAddNode(callableRegion, parentNode);
|
|
else
|
|
return;
|
|
}
|
|
|
|
for (Region ®ion : op->getRegions())
|
|
for (Operation &nested : region.getOps())
|
|
computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
|
|
}
|
|
|
|
/// Build the callgraph for `op` and everything nested beneath it. Both
/// special nodes (external caller and unknown callee) are created without a
/// callable region.
CallGraph::CallGraph(Operation *op)
    : externalCallerNode(/*callableRegion=*/nullptr),
      unknownCalleeNode(/*callableRegion=*/nullptr) {
  // The operation tree is walked twice: the first walk only registers the
  // callables, the second one resolves the calls. The split is needed because
  // nested callables have to be registered before any call can be resolved
  // to them.
  SymbolTableCollection symbolTable;
  for (bool resolveCalls : {false, true})
    computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
                     resolveCalls);
}
|
|
|
|
/// Get or add a call graph node for the given region.
|
|
CallGraphNode *CallGraph::getOrAddNode(Region *region,
|
|
CallGraphNode *parentNode) {
|
|
assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
|
|
"expected parent operation to be callable");
|
|
std::unique_ptr<CallGraphNode> &node = nodes[region];
|
|
if (!node) {
|
|
node.reset(new CallGraphNode(region));
|
|
|
|
// Add this node to the given parent node if necessary.
|
|
if (parentNode) {
|
|
parentNode->addChildEdge(node.get());
|
|
} else {
|
|
// Otherwise, connect all callable nodes to the external node, this allows
|
|
// for conservatively including all callable nodes within the graph.
|
|
// FIXME This isn't correct, this is only necessary for callable nodes
|
|
// that *could* be called from external sources. This requires extending
|
|
// the interface for callables to check if they may be referenced
|
|
// externally.
|
|
externalCallerNode.addAbstractEdge(node.get());
|
|
}
|
|
}
|
|
return node.get();
|
|
}
|
|
|
|
/// Lookup a call graph node for the given region, or nullptr if none is
|
|
/// registered.
|
|
CallGraphNode *CallGraph::lookupNode(Region *region) const {
|
|
auto it = nodes.find(region);
|
|
return it == nodes.end() ? nullptr : it->second.get();
|
|
}
|
|
|
|
/// Resolve the callable for given callee to a node in the callgraph, or the
|
|
/// unknown callee node if a valid node was not resolved.
|
|
CallGraphNode *
|
|
CallGraph::resolveCallable(CallOpInterface call,
|
|
SymbolTableCollection &symbolTable) const {
|
|
Operation *callable = call.resolveCallable(&symbolTable);
|
|
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
|
|
if (auto *node = lookupNode(callableOp.getCallableRegion()))
|
|
return node;
|
|
|
|
return getUnknownCalleeNode();
|
|
}
|
|
|
|
/// Erase the given node from the callgraph.
|
|
void CallGraph::eraseNode(CallGraphNode *node) {
|
|
// Erase any children of this node first.
|
|
if (node->hasChildren()) {
|
|
for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
|
|
if (edge.isChild())
|
|
eraseNode(edge.getTarget());
|
|
}
|
|
// Erase any edges to this node from any other nodes.
|
|
for (auto &it : nodes) {
|
|
it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
|
|
return edge.getTarget() == node;
|
|
});
|
|
}
|
|
nodes.erase(node->getCallableRegion());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Printing
|
|
|
|
/// Dump the graph in a human readable format.
|
|
void CallGraph::dump() const { print(llvm::errs()); }
|
|
/// Print the graph to `os` in a human readable, comment-prefixed format:
/// first every registered node together with its outgoing edges, then the
/// strongly connected components of the graph.
void CallGraph::print(raw_ostream &os) const {
  os << "// ---- CallGraph ----\n";

  // Writes a printable name for `node`: a fixed tag for the two special
  // nodes, otherwise the parent operation name, the region number and (when
  // non-empty) the attribute dictionary of the parent operation.
  auto emitNodeName = [&](const CallGraphNode *node) {
    if (node == getExternalCallerNode()) {
      os << "<External-Caller-Node>";
      return;
    }
    if (node == getUnknownCalleeNode()) {
      os << "<Unknown-Callee-Node>";
      return;
    }

    auto *region = node->getCallableRegion();
    auto *owner = region->getParentOp();
    os << "'" << owner->getName() << "' - Region #"
       << region->getRegionNumber();
    auto ownerAttrs = owner->getAttrDictionary();
    if (!ownerAttrs.empty())
      os << " : " << ownerAttrs;
  };

  for (auto &entry : nodes) {
    const CallGraphNode *current = entry.second.get();

    // Header line for this node.
    os << "// - Node : ";
    emitNodeName(current);
    os << "\n";

    // One line per outgoing edge. Edges that are neither call nor child
    // edges get no kind prefix.
    for (auto &edge : *current) {
      const char *kindLabel = "";
      if (edge.isCall())
        kindLabel = "Call";
      else if (edge.isChild())
        kindLabel = "Child";

      os << "// -- " << kindLabel << "-Edge : ";
      emitNodeName(edge.getTarget());
      os << "\n";
    }
    os << "//\n";
  }

  os << "// -- SCCs --\n";

  for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
    os << "// - SCC : \n";
    for (auto &member : scc) {
      os << "// -- Node :";
      emitNodeName(member);
      os << "\n";
    }
    os << "\n";
  }

  os << "// -------------------\n";
}
|