This is part of the RFC for a better fold API: https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374

This patch implements the generation of generic adaptors through TableGen. These are essentially a generalization of Adaptors, as implemented previously, but instead of indexing into a `mlir::ValueRange`, they may index into any container, regardless of the element type. This allows the convenient getter methods of Adaptors to be reused on ranges that are the result of some kind of mapping function applied to an op's operands. In the case of the fold API in the RFC, this would be `ArrayRef<Attribute>`, which is a mapping of the operands to their possibly-constant values.

Implementation-wise, some special care was taken to neither cause a compile-time regression nor break any kind of source compatibility. For that purpose, the current adaptor class was split into three:

* A generic adaptor base class, within the detail namespace as it is an implementation detail, which implements all APIs independent of the range type used for the operands. This is all the attribute and region related code. Since it is not templated, its implementation does not have to be inline and can be put into the cpp source file.
* The actual generic adaptor, which has a template parameter for the range that should be indexed into for retrieving operands. It implements all the getters for operands, as they are dependent on the range type. It publicly inherits from the generic adaptor base class.
* A class named as adaptors have been named so far, inheriting from the generic adaptor class with `mlir::ValueRange` as the range to index into. It implements the rest of the API, specific to `mlir::ValueRange` adaptors, which has previously been part of the adaptor. This boils down to a constructor from the Op type as well as the verify function.

The last class having the exact same API surface and name as Adaptors did previously leads to full source compatibility.
Differential Revision: https://reviews.llvm.org/D140660
384 lines
13 KiB
C++
384 lines
13 KiB
C++
//===- Class.cpp - Helper classes for Op C++ code emission --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/TableGen/Class.h"
|
|
#include "mlir/TableGen/Format.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tblgen;
|
|
|
|
/// Returns space to be emitted after the given C++ `type`. return "" if the
|
|
/// ends with '&' or '*', or is empty, else returns " ".
|
|
static StringRef getSpaceAfterType(StringRef type) {
|
|
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MethodParameter definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the parameter in its declaration form: the same text that
/// `writeDefTo` produces, followed by ` = <defaultValue>` when the parameter
/// has a default value.
void MethodParameter::writeDeclTo(raw_indented_ostream &os) const {
  // The declaration only differs from the definition by the default value.
  writeDefTo(os);
  if (!hasDefaultValue())
    return;
  os << " = " << defaultValue;
}
|
|
|
|
/// Writes the parameter in its definition form: an `/*optional*/` marker if
/// the parameter is flagged optional, then the type, a separating space where
/// one is needed, and the name. Default values are not emitted here.
void MethodParameter::writeDefTo(raw_indented_ostream &os) const {
  if (optional) {
    os << "/*optional*/";
  }
  const StringRef separator = getSpaceAfterType(type);
  os << type;
  os << separator;
  os << name;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MethodParameters definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes every parameter in declaration form (including default values),
/// separated by ", ".
void MethodParameters::writeDeclTo(raw_indented_ostream &os) const {
  llvm::interleaveComma(parameters, os, [&os](const auto &param) {
    param.writeDeclTo(os);
  });
}
|
|
/// Writes every parameter in definition form (without default values),
/// separated by ", ".
void MethodParameters::writeDefTo(raw_indented_ostream &os) const {
  llvm::interleaveComma(parameters, os, [&os](const auto &param) {
    param.writeDefTo(os);
  });
}
|
|
|
|
bool MethodParameters::subsumes(const MethodParameters &other) const {
|
|
// These parameters do not subsume the others if there are fewer parameters
|
|
// or their types do not match.
|
|
if (parameters.size() < other.parameters.size())
|
|
return false;
|
|
if (!std::equal(
|
|
other.parameters.begin(), other.parameters.end(), parameters.begin(),
|
|
[](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
|
|
return false;
|
|
|
|
// If all the common parameters have the same type, we can elide the other
|
|
// method if this method has the same number of parameters as other or if the
|
|
// first paramater after the common parameters has a default value (and, as
|
|
// required by C++, subsequent parameters will have default values too).
|
|
return parameters.size() == other.parameters.size() ||
|
|
parameters[other.parameters.size()].hasDefaultValue();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MethodSignature definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
bool MethodSignature::makesRedundant(const MethodSignature &other) const {
|
|
return methodName == other.methodName &&
|
|
parameters.subsumes(other.parameters);
|
|
}
|
|
|
|
/// Writes the signature as used in a declaration: return type, method name
/// and the parenthesized parameter list in declaration form.
void MethodSignature::writeDeclTo(raw_indented_ostream &os) const {
  const StringRef separator = getSpaceAfterType(returnType);
  os << returnType << separator;
  os << methodName << "(";
  parameters.writeDeclTo(os);
  os << ")";
}
|
|
|
|
/// Writes the signature as used in an out-of-line definition: return type,
/// the method name qualified with `namePrefix::` when `namePrefix` is
/// non-empty, and the parenthesized parameter list in definition form.
void MethodSignature::writeDefTo(raw_indented_ostream &os,
                                 StringRef namePrefix) const {
  os << returnType << getSpaceAfterType(returnType);
  // Only qualify the name when there is something to qualify it with.
  if (!namePrefix.empty())
    os << namePrefix << "::";
  os << methodName << "(";
  parameters.writeDefTo(os);
  os << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MethodBody definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Creates a method body. `declOnly` is stored as-is; `stringOs` is
/// constructed from `body`, and `os` is constructed from `stringOs`.
MethodBody::MethodBody(bool declOnly)
    : declOnly(declOnly),
      stringOs(body),
      os(stringOs) {
}
|
|
|
|
/// Writes the body text with all leading newlines removed. A non-empty body
/// that does not already end in a newline is terminated with one.
void MethodBody::writeTo(raw_indented_ostream &os) const {
  const StringRef text = StringRef(body).ltrim('\n');
  os << text;
  const bool needsTrailingNewline = !text.empty() && text.back() != '\n';
  if (needsTrailingNewline)
    os << "\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the in-class declaration of the method: optional `static` and
/// `constexpr` specifiers, the signature, and a `const` qualifier for const
/// methods. Non-inline methods are terminated with `;`; inline methods are
/// followed by their body in braces.
void Method::writeDeclTo(raw_indented_ostream &os) const {
  // Leading specifiers.
  if (isStatic())
    os << "static ";
  if (properties & ConstexprValue)
    os << "constexpr ";

  methodSignature.writeDeclTo(os);
  if (isConst())
    os << " const";

  if (isInline()) {
    // Inline methods carry their definition inside the class.
    os << " {\n";
    methodBody.writeTo(os);
    os << "}\n\n";
  } else {
    os << ";\n";
  }
}
|
|
|
|
/// Writes the out-of-line definition of the method, qualified by
/// `namePrefix`. Nothing is written for declaration-only or inline methods.
void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
  // Declaration-only methods have no definition at all.
  if (properties & Declaration)
    return;
  // Inline methods were already defined as part of their declaration.
  if (isInline())
    return;

  methodSignature.writeDefTo(os, namePrefix);
  if (isConst()) {
    os << " const";
  }
  os << " {\n";
  methodBody.writeTo(os);
  os << "}\n\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Constructor definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the in-class declaration of the constructor. Non-inline
/// constructors are terminated with `;`. Inline constructors are followed by
/// the member initializer list (if any) and the body in braces.
void Constructor::writeDeclTo(raw_indented_ostream &os) const {
  if (properties & ConstexprValue)
    os << "constexpr ";
  methodSignature.writeDeclTo(os);

  if (!isInline()) {
    os << ";\n\n";
    return;
  }

  os << ' ';
  // Emit ": a(x), b(y) " only when there is at least one initializer.
  if (!initializers.empty()) {
    os << ": ";
    llvm::interleaveComma(initializers, os, [&os](const auto &memberInit) {
      memberInit.writeTo(os);
    });
    os << ' ';
  }
  os << "{";
  methodBody.writeTo(os);
  os << "}\n\n";
}
|
|
|
|
/// Writes the out-of-line definition of the constructor, qualified by
/// `namePrefix`, including the member initializer list (if any). Nothing is
/// written for declaration-only or inline constructors.
void Constructor::writeDefTo(raw_indented_ostream &os,
                             StringRef namePrefix) const {
  // Declaration-only constructors have no definition at all.
  if (properties & Declaration)
    return;
  // Inline constructors were already defined as part of their declaration.
  if (isInline())
    return;

  methodSignature.writeDefTo(os, namePrefix);
  os << ' ';
  // Emit ": a(x), b(y) " only when there is at least one initializer.
  if (!initializers.empty()) {
    os << ": ";
    llvm::interleaveComma(initializers, os, [&os](const auto &memberInit) {
      memberInit.writeTo(os);
    });
    os << ' ';
  }
  os << "{";
  methodBody.writeTo(os);
  os << "}\n\n";
}
|
|
|
|
/// Writes a single member initializer in the form `name(value)`.
void Constructor::MemberInitializer::writeTo(raw_indented_ostream &os) const {
  os << name;
  os << '(';
  os << value;
  os << ')';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Visibility definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
namespace tblgen {
|
|
raw_ostream &operator<<(raw_ostream &os, Visibility visibility) {
|
|
switch (visibility) {
|
|
case Visibility::Public:
|
|
return os << "public";
|
|
case Visibility::Protected:
|
|
return os << "protected";
|
|
case Visibility::Private:
|
|
return os << "private";
|
|
}
|
|
return os;
|
|
}
|
|
} // namespace tblgen
|
|
} // namespace mlir
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ParentClass definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the parent class as it appears in a base-specifier list: the
/// visibility keyword, the class name and, if there are any template
/// parameters, the comma-separated parameters wrapped in `<` and `>`.
void ParentClass::writeTo(raw_indented_ostream &os) const {
  os << visibility << ' ' << name;
  if (!templateParams.empty()) {
    // The scope object emits "<" now and ">" when it is destroyed at the end
    // of this block, without changing the indentation.
    auto scope = os.scope("<", ">", /*indent=*/false);
    llvm::interleaveComma(templateParams, os,
                          [&os](const auto &param) { os << param; });
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UsingDeclaration definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the using declaration. If there are template parameters, a
/// `template <typename ...>` header is emitted on its own line first. The
/// ` = value` part is only emitted when a value is present.
void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
  const bool isAliasTemplate = !templateParams.empty();
  if (isAliasTemplate) {
    os << "template <";
    llvm::interleaveComma(templateParams, os, [&os](StringRef typeParam) {
      os << "typename " << typeParam;
    });
    os << ">\n";
  }

  os << "using " << name;
  const bool hasAliasedValue = !value.empty();
  if (hasAliasedValue)
    os << " = " << value;
  os << ";\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Field definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the field declaration as `<type> <name>;` followed by a newline.
void Field::writeDeclTo(raw_indented_ostream &os) const {
  os << type << ' ';
  os << name << ";\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VisibilityDeclaration definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the access specifier label (e.g. `public:`). The stream is
/// unindented for the label itself and re-indented afterwards.
void VisibilityDeclaration::writeDeclTo(raw_indented_ostream &os) const {
  os.unindent();
  os << visibility;
  os << ":\n";
  os.indent();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtraClassDeclaration definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes the extra class declaration text, reindented by the stream.
void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const {
  const auto &declarationText = extraClassDeclaration;
  os.printReindented(declarationText);
}
|
|
|
|
/// Writes the extra class definition text, reindented by the stream.
/// `namePrefix` is not used by this declaration kind.
void ExtraClassDeclaration::writeDefTo(raw_indented_ostream &os,
                                       StringRef namePrefix) const {
  const auto &definitionText = extraClassDefinition;
  os.printReindented(definitionText);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Class definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends `parent` to the list of parent classes and returns a reference to
/// the stored element.
ParentClass &Class::addParent(ParentClass parent) {
  return parents.emplace_back(std::move(parent));
}
|
|
|
|
/// Writes the full class declaration: an optional `template <typename ...>`
/// header, the `struct`/`class` head with its parent classes, and all
/// declarations inside the braces.
void Class::writeDeclTo(raw_indented_ostream &os) const {
  // Template header, if the class is templated.
  const bool isTemplated = !templateParams.empty();
  if (isTemplated) {
    os << "template <";
    llvm::interleaveComma(templateParams, os, [&os](StringRef typeParam) {
      os << "typename " << typeParam;
    });
    os << ">\n";
  }

  // Class head.
  const char *classKey = isStruct ? "struct" : "class";
  os << classKey << ' ' << className << ' ';

  // Base-specifier list, if there are any parents.
  if (!parents.empty()) {
    os << ": ";
    llvm::interleaveComma(parents, os, [&os](const auto &parentClass) {
      parentClass.writeTo(os);
    });
    os << ' ';
  }

  // The scope object emits "{\n" now and "};\n" when this function returns,
  // indenting everything written in between.
  auto classScope = os.scope("{\n", "};\n", /*indent=*/true);

  // Emit every declaration inside the class body.
  for (const auto &declaration : declarations) {
    declaration->writeDeclTo(os);
  }
}
|
|
|
|
/// Writes the out-of-line definitions of all declarations, using the class
/// name as the name prefix.
void Class::writeDefTo(raw_indented_ostream &os) const {
  for (const auto &declaration : declarations) {
    declaration->writeDefTo(os, className);
  }
}
|
|
|
|
void Class::finalize() {
|
|
// Sort the methods by public and private. Remove them from the pending list
|
|
// of methods.
|
|
SmallVector<std::unique_ptr<Method>> publicMethods, privateMethods;
|
|
for (auto &method : methods) {
|
|
if (method->isPrivate())
|
|
privateMethods.push_back(std::move(method));
|
|
else
|
|
publicMethods.push_back(std::move(method));
|
|
}
|
|
methods.clear();
|
|
|
|
// If the last visibility declaration wasn't `public`, add one that is. Then,
|
|
// declare the public methods.
|
|
if (!publicMethods.empty() && getLastVisibilityDecl() != Visibility::Public)
|
|
declare<VisibilityDeclaration>(Visibility::Public);
|
|
for (auto &method : publicMethods)
|
|
declarations.push_back(std::move(method));
|
|
|
|
// If the last visibility declaration wasn't `private`, add one that is. Then,
|
|
// declare the private methods.
|
|
if (!privateMethods.empty() && getLastVisibilityDecl() != Visibility::Private)
|
|
declare<VisibilityDeclaration>(Visibility::Private);
|
|
for (auto &method : privateMethods)
|
|
declarations.push_back(std::move(method));
|
|
|
|
// All fields added to the pending list are private and declared at the bottom
|
|
// of the class. If the last visibility declaration wasn't `private`, add one
|
|
// that is, then declare the fields.
|
|
if (!fields.empty() && getLastVisibilityDecl() != Visibility::Private)
|
|
declare<VisibilityDeclaration>(Visibility::Private);
|
|
for (auto &field : fields)
|
|
declare<Field>(std::move(field));
|
|
fields.clear();
|
|
}
|
|
|
|
/// Returns the visibility set by the most recent visibility declaration. If
/// there is none, returns the default for the class kind: public for a
/// struct, private otherwise.
Visibility Class::getLastVisibilityDecl() const {
  // Walk the declarations back to front; the first hit is the latest one.
  for (auto &declaration : llvm::reverse(declarations)) {
    if (isa<VisibilityDeclaration>(declaration))
      return cast<VisibilityDeclaration>(*declaration).getVisibility();
  }
  if (isStruct)
    return Visibility::Public;
  return Visibility::Private;
}
|
|
|
|
/// Inserts `newMethod` into `methods` unless an existing method already makes
/// it redundant, in which case nullptr is returned and `methods` is left
/// untouched. Otherwise, every existing method made redundant by `newMethod`
/// is erased first, and a pointer to the newly stored method is returned.
Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,
                              std::unique_ptr<Method> newMethod) {
  // Bail out if the new method adds nothing over an existing one.
  for (const std::unique_ptr<Method> &existing : methods) {
    if (existing->makesRedundant(*newMethod))
      return nullptr;
  }

  // Drop every existing method that the new one supersedes.
  llvm::erase_if(methods, [&newMethod](const auto &existing) {
    return newMethod->makesRedundant(*existing);
  });

  methods.push_back(std::move(newMethod));
  Method *inserted = methods.back().get();
  return inserted;
}
|
|
|
|
/// Adds `newMethod` to the pending methods via `insertAndPruneMethods` and
/// returns that call's result (nullptr if the method was redundant).
Method *Class::addMethodAndPrune(Method &&newMethod) {
  auto ownedMethod = std::make_unique<Method>(std::move(newMethod));
  return insertAndPruneMethods(methods, std::move(ownedMethod));
}
|
|
|
|
/// Adds `newCtor` to the pending methods via `insertAndPruneMethods`. Returns
/// the stored constructor, or nullptr if it was redundant and not inserted.
Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
  std::unique_ptr<Method> ownedCtor =
      std::make_unique<Constructor>(std::move(newCtor));
  Method *inserted = insertAndPruneMethods(methods, std::move(ownedCtor));
  // `inserted` may be null, hence the null-tolerant cast.
  return dyn_cast_or_null<Constructor>(inserted);
}
|