Files
clang-p2996/mlir/lib/Transforms/ViewOpGraph.cpp
Mehdi Amini 308571074c Mass update the MLIR license header to mention "Part of the LLVM project"
This is an artifact from merging MLIR into LLVM, the file headers are
now aligned with the rest of the project.
2020-01-26 03:58:30 +00:00

167 lines
5.5 KiB
C++

//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewOpGraph.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/CommandLine.h"
// Command-line threshold used when building node labels: an ElementsAttr or
// ArrayAttr holding more elements than this value is printed in elided form
// ("[...]"-style) instead of in full. Splat elements attributes are always
// printed in full regardless of this limit. Defaults to 16.
static llvm::cl::opt<int> elideIfLarger(
"print-op-graph-elide-if-larger",
llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
llvm::cl::init(16));
using namespace mlir;
namespace llvm {
// Specialize GraphTraits to treat Block as a graph of Operations as nodes and
// uses as edges.
// Nodes are the operations of the block, visited in block order; an edge goes
// from an operation to each operation that uses it (its users).
// NOTE(review): a user living outside the block (e.g. in another block or in a
// nested region) becomes an edge target that nodes_begin/nodes_end never
// enumerate -- confirm this is acceptable for the emitted graph.
template <> struct GraphTraits<Block *> {
using GraphType = Block *;
using NodeRef = Operation *;
// The children (successors) of a node are the users of the operation.
using ChildIteratorType = Operation::user_iterator;
static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
// Operation's destructor is private so use Operation* instead and use
// mapped iterator.
static Operation *AddressOf(Operation &op) { return &op; }
// Walks the block's operations, yielding Operation * through AddressOf.
using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
static nodes_iterator nodes_begin(Block *b) {
return nodes_iterator(b->begin(), &AddressOf);
}
static nodes_iterator nodes_end(Block *b) {
return nodes_iterator(b->end(), &AddressOf);
}
};
// Specialize DOTGraphTraits to produce more readable output.
// Only the node label is customized; every other DOT hook is inherited from
// DefaultDOTGraphTraits.
template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
// Builds the multi-line label shown for `op`; defined out of line below.
static std::string getNodeLabel(Operation *op, Block *);
};
std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
  // The label is assembled from the op's own printers, section by section:
  // the operation name, its location (skipped when unknown), the
  // comma-separated result types, then one "name: value" line per attribute.
  // Large container attributes are elided according to
  // -print-op-graph-elide-if-larger.
  std::string label;
  raw_string_ostream labelStream(label);

  labelStream << op->getName() << "\n";
  auto loc = op->getLoc();
  if (!loc.isa<UnknownLoc>())
    labelStream << loc << "\n";

  interleaveComma(op->getResultTypes(), labelStream);
  labelStream << "\n";

  for (auto namedAttr : op->getAttrs()) {
    auto value = namedAttr.second;
    labelStream << '\n' << namedAttr.first << ": ";

    auto elements = value.dyn_cast<ElementsAttr>();
    auto array = value.dyn_cast<ArrayAttr>();
    if (value.isa<SplatElementsAttr>()) {
      // Splats are compact, so they are printed in full whatever their size.
      value.print(labelStream);
    } else if (elements && elements.getNumElements() > elideIfLarger) {
      // Keep the bracket nesting and the type; drop the element payload.
      auto rank = elements.getType().getRank();
      labelStream << std::string(rank, '[') << "..." << std::string(rank, ']')
                  << " : " << elements.getType();
    } else if (array && static_cast<int64_t>(array.size()) > elideIfLarger) {
      labelStream << "[...]";
    } else {
      value.print(labelStream);
    }
  }
  return labelStream.str();
}
} // end namespace llvm
namespace {
// PrintOpPass is a simple pass that writes one graph per block of every region
// attached to the top-level operations of a module (recursing into nested
// modules).
// Note: this is a module pass only to avoid interleaving on the same ostream
// due to multi-threading over functions.
struct PrintOpPass : public ModulePass<PrintOpPass> {
  /// `os` receives every emitted graph, `shortNames` is forwarded to
  /// llvm::WriteGraph, and `title` is prepended to each graph's title.
  explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
                       const Twine &title = "")
      : os(os), title(title.str()), shortNames(shortNames) {}

  /// Returns the name used in the graph title for `op`. This is its symbol
  /// name when it carries one. Otherwise it is the operation name suffixed
  /// with a counter that increases for every unnamed op this pass instance
  /// sees.
  std::string getOpName(Operation &op) {
    auto symbolAttr =
        op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
    // Convert explicitly rather than relying on StringRef's implicit
    // conversion to std::string.
    if (symbolAttr)
      return symbolAttr.getValue().str();
    ++unnamedOpCtr;
    return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
  }

  // Print all the ops in a module.
  void processModule(ModuleOp module) {
    for (Operation &op : module) {
      // Modules may actually be nested, recurse on nesting.
      if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
        processModule(nestedModule);
        continue;
      }
      auto opName = getOpName(op);
      for (Region &region : op.getRegions()) {
        // size() on the region's intrusive block list walks the whole list.
        // Compute the single-block check once per region instead of once per
        // block; the latter is quadratic in the number of blocks.
        const bool hasSingleBlock = region.getBlocks().size() == 1;
        for (auto indexedBlock : llvm::enumerate(region)) {
          // Suffix the block number only when the region has several blocks.
          std::string blockName =
              hasSingleBlock ? std::string()
                             : "__" + llvm::utostr(indexedBlock.index());
          llvm::WriteGraph(os, &indexedBlock.value(), shortNames,
                           Twine(title) + opName + blockName);
        }
      }
    }
  }

  void runOnModule() override { processModule(getModule()); }

private:
  raw_ostream &os;
  std::string title;
  int unnamedOpCtr = 0;
  bool shortNames;
};
} // namespace
void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
                     const Twine &title, llvm::GraphProgram::Name program) {
  // GraphTraits/DOTGraphTraits are specialized for Block *, so hand the
  // block's address to llvm::ViewGraph together with the caller's options.
  Block *graph = &block;
  llvm::ViewGraph(graph, name, shortNames, title, program);
}
raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
                              const Twine &title) {
  // Emit the block's op graph in DOT form to `os`, and return the stream
  // that llvm::WriteGraph hands back.
  Block *graph = &block;
  raw_ostream &result = llvm::WriteGraph(os, graph, shortNames, title);
  return result;
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
                             const Twine &title) {
  // Build the concrete pass and hand it back through the generic module-pass
  // base; the caller takes ownership.
  std::unique_ptr<OpPassBase<ModuleOp>> pass =
      std::make_unique<PrintOpPass>(os, shortNames, title);
  return pass;
}
// Static registration of PrintOpPass under the pass argument "print-op-graph".
// Despite the "per region" description, PrintOpPass writes one graph for every
// block of each region; regions with several blocks get a "__<index>" suffix
// on the graph title.
static PassRegistration<PrintOpPass> pass("print-op-graph",
"Print op graph per region");