The current translation uses the old "ugly"/"raw" form, which used PDLValue for the arguments and results. This commit updates the C++ generation to use the recently added sugar that allows directly using the desired types for the arguments and results of PDL functions. In addition, this commit also properly imports the C++ class for ODS operations, constraints, and interfaces. This allows for a much more convenient C++ API than was previously possible with the raw/low-level types. Differential Revision: https://reviews.llvm.org/D124817
178 lines
6.3 KiB
C++
178 lines
6.3 KiB
C++
//===- Context.cpp --------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/ODS/Context.h"
|
|
#include "mlir/Tools/PDLL/ODS/Constraint.h"
|
|
#include "mlir/Tools/PDLL/ODS/Dialect.h"
|
|
#include "mlir/Tools/PDLL/ODS/Operation.h"
|
|
#include "llvm/Support/ScopedPrinter.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll::ods;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Context
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constructor and destructor are defaulted here in the .cpp rather than in the
// header. NOTE(review): presumably so that the std::unique_ptr members (see
// dialects / attributeConstraints / typeConstraints used below) are destroyed
// where Dialect, Operation and the constraint classes are complete types —
// confirm against Context.h.
Context::Context() = default;
Context::~Context() = default;
|
|
|
|
/// Return the attribute constraint registered under `name`, creating it from
/// `summary` and `cppClass` the first time the name is seen. Registering an
/// existing name again must supply the same summary and C++ class; this is
/// checked only in assert-enabled builds.
const AttributeConstraint &
Context::insertAttributeConstraint(StringRef name, StringRef summary,
                                   StringRef cppClass) {
  // operator[] default-inserts an empty slot for unseen names.
  std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
  if (!constraint) {
    constraint.reset(new AttributeConstraint(name, summary, cppClass));
  } else {
    // The check covers both fields, so the message names both.
    assert(constraint->getCppClass() == cppClass &&
           constraint->getSummary() == summary &&
           "constraint with the same name was already registered with a "
           "different class or summary");
  }
  return *constraint;
}
|
|
|
|
/// Return the type constraint registered under `name`, creating it from
/// `summary` and `cppClass` the first time the name is seen. Like attribute
/// constraints, registering an existing name again must supply the same
/// summary and C++ class; this is checked only in assert-enabled builds.
const TypeConstraint &Context::insertTypeConstraint(StringRef name,
                                                    StringRef summary,
                                                    StringRef cppClass) {
  // operator[] default-inserts an empty slot for unseen names.
  std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
  if (!constraint) {
    constraint.reset(new TypeConstraint(name, summary, cppClass));
  } else {
    // Previously a conflicting re-registration was silently ignored and the
    // first definition kept; flag the mismatch in debug builds instead.
    assert(constraint->getCppClass() == cppClass &&
           constraint->getSummary() == summary &&
           "constraint with the same name was already registered with a "
           "different class or summary");
  }
  return *constraint;
}
|
|
|
|
/// Get the dialect called `name`, constructing and registering it the first
/// time that name is requested. Later calls return the same instance.
Dialect &Context::insertDialect(StringRef name) {
  std::unique_ptr<Dialect> &slot = dialects[name];
  if (slot == nullptr)
    slot.reset(new Dialect(name));
  return *slot;
}
|
|
|
|
/// Find a previously inserted dialect by `name`. Returns nullptr when no
/// dialect with that name has been registered.
const Dialect *Context::lookupDialect(StringRef name) const {
  auto found = dialects.find(name);
  if (found == dialects.end())
    return nullptr;
  return found->second.get();
}
|
|
|
|
std::pair<Operation *, bool>
|
|
Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
|
|
StringRef nativeClassName,
|
|
bool supportsResultTypeInferrence, SMLoc loc) {
|
|
std::pair<StringRef, StringRef> dialectAndName = name.split('.');
|
|
return insertDialect(dialectAndName.first)
|
|
.insertOperation(name, summary, desc, nativeClassName,
|
|
supportsResultTypeInferrence, loc);
|
|
}
|
|
|
|
/// Find the operation `name` by first resolving its dialect (the prefix of
/// `name` before the first '.'). Returns nullptr if that dialect is unknown;
/// otherwise defers to Dialect::lookupOperation.
const Operation *Context::lookupOperation(StringRef name) const {
  const Dialect *owner = lookupDialect(name.split('.').first);
  return owner ? owner->lookupOperation(name) : nullptr;
}
|
|
|
|
template <typename T>
|
|
SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
|
|
SmallVector<T *> storage;
|
|
for (auto &entry : map)
|
|
storage.push_back(entry.second.get());
|
|
llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
|
|
return lhs->getName() < rhs->getName();
|
|
});
|
|
return storage;
|
|
}
|
|
|
|
/// Print a human-readable dump of this context to `os`. The dump lists every
/// dialect (with its operations' attributes, operands and results), then all
/// attribute constraints, then all type constraints. Each group is sorted by
/// name, and everything is nested inside a single `ODSContext { ... }` object.
void Context::print(raw_ostream &os) const {
  // Print `cst`, wrapped according to how many values it may match.
  auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
    switch (kind) {
    case VariableLengthKind::Optional:
      os << "Optional<" << cst << ">";
      break;
    case VariableLengthKind::Single:
      os << cst;
      break;
    case VariableLengthKind::Variadic:
      os << "Variadic<" << cst << ">";
      break;
    }
  };

  llvm::ScopedPrinter printer(os);

  // Print a `<label> { name : Constraint, ... }` line for an operation's
  // operands or results. Nothing is printed when `values` is empty.
  auto printOperandsOrResults = [&](StringRef label,
                                    ArrayRef<OperandOrResult> values) {
    if (values.empty())
      return;
    printer.startLine() << label << " { ";
    llvm::interleaveComma(values, os, [&](const OperandOrResult &value) {
      os << value.getName() << " : ";
      printVariableLengthCst(value.getConstraint().getDemangledName(),
                             value.getVariableLengthKind());
    });
    os << " }\n";
  };

  // Print the summary/C++-class object shared by attribute and type
  // constraints, headed by `<kindName> `<demangled name>``.
  auto printConstraint = [&](StringRef kindName, const auto &cst) {
    printer.startLine() << kindName << " `" << cst.getDemangledName()
                        << "` {\n";
    printer.indent();

    printer.startLine() << "Summary: " << cst.getSummary() << "\n";
    printer.startLine() << "CppClass: " << cst.getCppClass() << "\n";
    printer.objectEnd();
  };

  // Opens `ODSContext {` here. The scope's destructor emits the matching `}`
  // when this function returns, so the context object must NOT also be closed
  // with an explicit printer.objectEnd() (that would print a stray brace).
  llvm::DictScope odsScope(printer, "ODSContext");

  for (const Dialect *dialect : sortMapByName(dialects)) {
    printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
    printer.indent();

    for (const Operation *op : sortMapByName(dialect->getOperations())) {
      printer.startLine() << "Operation `" << op->getName() << "` {\n";
      printer.indent();

      // Attributes are either required (printed bare) or optional.
      ArrayRef<Attribute> attributes = op->getAttributes();
      if (!attributes.empty()) {
        printer.startLine() << "Attributes { ";
        llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
          os << attr.getName() << " : ";

          auto kind = attr.isOptional() ? VariableLengthKind::Optional
                                        : VariableLengthKind::Single;
          printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
        });
        os << " }\n";
      }

      printOperandsOrResults("Operands", op->getOperands());
      printOperandsOrResults("Results", op->getResults());

      printer.objectEnd();
    }
    printer.objectEnd();
  }

  for (const AttributeConstraint *cst : sortMapByName(attributeConstraints))
    printConstraint("AttributeConstraint", *cst);
  for (const TypeConstraint *cst : sortMapByName(typeConstraints))
    printConstraint("TypeConstraint", *cst);
}
|