Files
clang-p2996/mlir/lib/TableGen/OpClass.cpp
River Riddle e4e31e19bb [mlir][OpGen] Cache Identifiers for known attribute names in AbstractOperation.
Operations currently rely on the string name of attributes during attribute lookup/removal/replacement, in build methods, and more. This unfortunately means that some of the most used APIs in MLIR require string comparisons, additional hashing(+mutex locking) to construct Identifiers, and more. This revision remedies this by caching identifiers for all of the attributes of the operation in its corresponding AbstractOperation. Just updating the autogenerated usages brings up to a 15% reduction in compile time, greatly reducing the cost of interacting with the attributes of an operation. This number can grow even higher as we use these methods in handwritten C++ code.

Methods for accessing these cached identifiers are exposed via `<attr-name>AttrName` methods on the derived operation class. Moving forward, users should generally use these methods over raw strings when an attribute name is necessary.

Differential Revision: https://reviews.llvm.org/D104167
2021-06-22 19:56:05 +00:00

348 lines
11 KiB
C++

//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <unordered_set>
#define DEBUG_TYPE "mlir-tblgen-opclass"
using namespace mlir;
using namespace mlir::tblgen;
namespace {
// Returns the separator to emit between the given C++ `type` and whatever
// follows it: "" if `type` is empty or ends with '&' or '*', otherwise " ".
StringRef getSpaceAfterType(StringRef type) {
  if (type.empty())
    return "";
  const char lastChar = type.back();
  return (lastChar == '&' || lastChar == '*') ? "" : " ";
}
} // namespace
//===----------------------------------------------------------------------===//
// OpMethodParameter definitions
//===----------------------------------------------------------------------===//
// Prints this parameter to `os` as `<type> <name>`, preceded by an
// `/*optional*/` marker when the PP_Optional property is set. The
// ` = <default>` suffix is printed only when `emitDefault` is true and the
// parameter has a non-empty default value.
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
  if (properties & PP_Optional)
    os << "/*optional*/";

  os << type;
  os << getSpaceAfterType(type);
  os << name;

  if (!emitDefault || defaultValue.empty())
    return;
  os << " = " << defaultValue;
}
//===----------------------------------------------------------------------===//
// OpMethodParameters definitions
//===----------------------------------------------------------------------===//
// Factory methods to construct the correct type of `OpMethodParameters`
// object based on the arguments.

// No arguments: creates a default-constructed resolved parameter list.
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
  std::unique_ptr<OpMethodParameters> resolved =
      std::make_unique<OpMethodResolvedParameters>();
  return resolved;
}
// A single parameter string: creates an unresolved parameter list that keeps
// `params` as raw text.
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef params) {
  std::unique_ptr<OpMethodParameters> unresolved =
      std::make_unique<OpMethodUnresolvedParameters>(params);
  return unresolved;
}
// A vector of parameters: creates a resolved parameter list, moving `params`
// into it.
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
  std::unique_ptr<OpMethodParameters> resolved =
      std::make_unique<OpMethodResolvedParameters>(std::move(params));
  return resolved;
}
// A (type, name, default value) triple: creates a resolved parameter list
// from these three values.
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef type, StringRef name,
                           StringRef defaultValue) {
  std::unique_ptr<OpMethodParameters> resolved =
      std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
  return resolved;
}
//===----------------------------------------------------------------------===//
// OpMethodUnresolvedParameters definitions
//===----------------------------------------------------------------------===//
// Prints the stored parameter text unchanged (default values included), as
// used in a method declaration.
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
  StringRef declText = parameters;
  os << declText;
}
// Prints the stored parameter text for use in a method definition, i.e. with
// the default values removed.
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
  // Each iteration keeps the text up to the next '=' and then skips ahead to
  // just past the next ',' (dropping the initializer in between).
  // TODO: We are using '=' and ',' as delimiters for parameter
  // initializers. This is incorrect for initializer list with more than one
  // element. Change to a more robust approach.
  llvm::SmallVector<StringRef, 4> keptPieces;
  StringRef remaining = parameters;
  while (!remaining.empty()) {
    auto beforeAndAfterEq = remaining.split("=");
    keptPieces.push_back(beforeAndAfterEq.first);
    remaining = beforeAndAfterEq.second.split(',').second;
  }

  // Re-join the kept pieces with ", ".
  llvm::interleaveComma(keptPieces, os);
}
//===----------------------------------------------------------------------===//
// OpMethodResolvedParameters definitions
//===----------------------------------------------------------------------===//
// Returns true if a method with these parameters makes a method with parameters
// `other` redundant. This should return true only if all possible calls to the
// other method can be replaced by calls to this method.
bool OpMethodResolvedParameters::makesRedundant(
const OpMethodResolvedParameters &other) const {
const size_t otherNumParams = other.getNumParameters();
const size_t thisNumParams = getNumParameters();
// All calls to the other method can be replaced this method only if this
// method has the same or more arguments number of arguments as the other, and
// the common arguments have the same type.
if (thisNumParams < otherNumParams)
return false;
for (int idx : llvm::seq<int>(0, otherNumParams))
if (parameters[idx].getType() != other.parameters[idx].getType())
return false;
// If all the common arguments have the same type, we can elide the other
// method if this method has the same number of arguments as other or the
// first argument after the common ones has a default value (and by C++
// requirement, all the later ones will also have a default value).
return thisNumParams == otherNumParams ||
parameters[otherNumParams].hasDefaultValue();
}
// Prints all parameters in their declaration form, separated by ", ".
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
  const char *separator = "";
  for (const OpMethodParameter &param : parameters) {
    os << separator;
    param.writeDeclTo(os);
    separator = ", ";
  }
}
// Prints all parameters in their definition form, separated by ", ".
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
  const char *separator = "";
  for (const OpMethodParameter &param : parameters) {
    os << separator;
    param.writeDefTo(os);
    separator = ", ";
  }
}
//===----------------------------------------------------------------------===//
// OpMethodSignature definitions
//===----------------------------------------------------------------------===//
// Returns true if a method with this signature makes a method with the
// `other` signature redundant. Only resolved parameters are supported: if
// either signature holds any other kind of parameters the answer is false.
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
  // Methods with different names never shadow each other.
  if (methodName != other.methodName)
    return false;

  const auto *ownParams =
      dyn_cast<OpMethodResolvedParameters>(parameters.get());
  if (!ownParams)
    return false;

  const auto *otherParams =
      dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
  return otherParams && ownParams->makesRedundant(*otherParams);
}
// Prints the signature as it appears in a declaration:
// `<returnType> <methodName>(<parameter declarations>)`.
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
  os << returnType << getSpaceAfterType(returnType);
  os << methodName << "(";
  parameters->writeDeclTo(os);
  os << ")";
}
// Prints the signature as it appears in a definition:
// `<returnType> <namePrefix>::<methodName>(<parameter definitions>)`. The
// `<namePrefix>::` qualifier is omitted when `namePrefix` is empty.
void OpMethodSignature::writeDefTo(raw_ostream &os,
                                   StringRef namePrefix) const {
  os << returnType << getSpaceAfterType(returnType);
  if (!namePrefix.empty())
    os << namePrefix << "::";
  os << methodName << "(";
  parameters->writeDefTo(os);
  os << ")";
}
//===----------------------------------------------------------------------===//
// OpMethodBody definitions
//===----------------------------------------------------------------------===//
// Constructs a method body. When `declOnly` is true the body is marked as not
// effective, in which case the `operator<<` overloads append nothing to it.
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
// Appends the string form of `content` to the body. Does nothing when the
// body is not effective. Returns `*this` to allow chaining.
OpMethodBody &OpMethodBody::operator<<(Twine content) {
  if (!isEffective)
    return *this;
  body += content.str();
  return *this;
}
// Appends the decimal representation of `content` to the body. Does nothing
// when the body is not effective. Returns `*this` to allow chaining.
OpMethodBody &OpMethodBody::operator<<(int content) {
  if (!isEffective)
    return *this;
  body += std::to_string(content);
  return *this;
}
// Appends the formatted string produced by `content` to the body. Does
// nothing when the body is not effective. Returns `*this` to allow chaining.
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
  if (!isEffective)
    return *this;
  body += content.str();
  return *this;
}
// Prints the body to `os` with leading newlines stripped, making sure the
// output ends with a newline (a lone newline is printed for an empty body).
void OpMethodBody::writeTo(raw_ostream &os) const {
  StringRef trimmed = StringRef(body).ltrim('\n');
  os << trimmed;

  const bool endsWithNewline = !trimmed.empty() && trimmed.back() == '\n';
  if (!endsWithNewline)
    os << "\n";
}
//===----------------------------------------------------------------------===//
// OpMethod definitions
//===----------------------------------------------------------------------===//
// Prints the in-class declaration of this method, indented by two spaces.
// Non-inline methods end with ';'; inline methods are followed by their body
// wrapped in braces.
void OpMethod::writeDeclTo(raw_ostream &os) const {
  os.indent(2);

  // Leading specifiers.
  if (isStatic())
    os << "static ";
  if ((properties & MP_Constexpr) == MP_Constexpr)
    os << "constexpr ";

  methodSignature.writeDeclTo(os);

  if (isInline()) {
    // Inline methods carry their body directly in the class declaration.
    os << " {\n";
    methodBody.writeTo(os.indent(2));
    os.indent(2) << "}";
    return;
  }
  os << ";";
}
// Prints the out-of-line definition of this method, qualifying the method
// name with `namePrefix`. Prints nothing for declaration-only methods and for
// inline methods (those are not given a separate definition).
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
  const bool hasNoSeparateDef = (properties & MP_Declaration) || isInline();
  if (hasNoSeparateDef)
    return;

  methodSignature.writeDefTo(os, namePrefix);
  os << " {\n";
  methodBody.writeTo(os);
  os << "}";
}
//===----------------------------------------------------------------------===//
// OpConstructor definitions
//===----------------------------------------------------------------------===//
// Appends `name(value)` to the constructor's member initializer list. The
// first initializer is introduced with " : " and later ones with ", ".
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
  StringRef leadIn = memberInitializers.empty() ? " : " : ", ";
  memberInitializers +=
      llvm::formatv("{0}{1}({2})", leadIn, name, value).str();
}
// Prints the out-of-line definition of this constructor: the signature
// qualified with `namePrefix`, the accumulated member initializers, and the
// body in braces. Prints nothing for declaration-only constructors.
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
  if (properties & MP_Declaration)
    return;

  methodSignature.writeDefTo(os, namePrefix);
  os << " " << memberInitializers;
  os << " {\n";
  methodBody.writeTo(os);
  os << "}";
}
//===----------------------------------------------------------------------===//
// Class definitions
//===----------------------------------------------------------------------===//
// Creates a class named `name`.
Class::Class(StringRef name) : className(name) {}
// Records a new field rendered as `<type> <name>`, followed by
// ` = <defaultValue>` when `defaultValue` is non-empty.
void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
  std::string fieldDecl = formatv("{0} {1}", type, name).str();
  if (!defaultValue.empty())
    fieldDecl += formatv(" = {0}", defaultValue).str();
  fields.push_back(std::move(fieldDecl));
}
// Prints the class declaration: a `public:` section with every non-private
// method, then a `private:` section with the private methods (if any) and the
// fields.
void Class::writeDeclTo(raw_ostream &os) const {
  // Prints one method declaration on its own line.
  auto emitMethodDecl = [&os](const OpMethod &method) {
    method.writeDeclTo(os);
    os << '\n';
  };

  os << "class " << className << " {\n";
  os << "public:\n";

  // First pass: print public methods and note whether any private ones exist.
  bool sawPrivateMethod = false;
  forAllMethods([&](const OpMethod &method) {
    if (method.isPrivate()) {
      sawPrivateMethod = true;
      return;
    }
    emitMethodDecl(method);
  });
  os << '\n';

  os << "private:\n";
  if (sawPrivateMethod) {
    // Second pass: print only the private methods.
    forAllMethods([&](const OpMethod &method) {
      if (!method.isPrivate())
        return;
      emitMethodDecl(method);
    });
    os << '\n';
  }

  for (const auto &fieldDecl : fields)
    os.indent(2) << fieldDecl << ";\n";
  os << "};\n";
}
// Asks every method to print its definition, qualified with this class's
// name. Two newlines are written after each method.
void Class::writeDefTo(raw_ostream &os) const {
  auto emitDefinition = [this, &os](const OpMethod &method) {
    method.writeDefTo(os, className);
    os << "\n\n";
  };
  forAllMethods(emitDefinition);
}
//===----------------------------------------------------------------------===//
// OpClass definitions
//===----------------------------------------------------------------------===//
// Creates an op class named `name`. `extraClassDeclaration` is stored and,
// when non-empty, written into the class body by `writeDeclTo`.
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
// Registers `trait` for this op class. A trait that was already registered
// is ignored; new traits are kept in insertion order.
void OpClass::addTrait(Twine trait) {
  std::string traitName = trait.str();
  const bool isNewTrait = traitsSet.insert(traitName).second;
  if (!isNewTrait)
    return;
  traitsVec.push_back(std::move(traitName));
}
// Prints the op class declaration: a class deriving from `::mlir::Op` with
// the class itself and all registered traits as template arguments, the
// common using-declarations, the non-private methods, the extra class
// declaration (if any), and finally a `private:` section when there are
// private methods.
void OpClass::writeDeclTo(raw_ostream &os) const {
  // Prints one method declaration on its own line.
  auto emitMethodDecl = [&os](const OpMethod &method) {
    method.writeDeclTo(os);
    os << "\n";
  };

  // Class head with the trait list.
  os << "class " << className << " : public ::mlir::Op<" << className;
  for (const auto &traitName : traitsVec)
    os << ", " << traitName;
  os << "> {\npublic:\n";
  os << " using Op::Op;\n";
  os << " using Op::print;\n";
  os << " using Adaptor = " << className << "Adaptor;\n";

  // First pass: print public methods and note whether any private ones exist.
  bool sawPrivateMethod = false;
  forAllMethods([&](const OpMethod &method) {
    if (method.isPrivate()) {
      sawPrivateMethod = true;
      return;
    }
    emitMethodDecl(method);
  });

  // TODO: Add line control markers to make errors easier to debug.
  if (!extraClassDeclaration.empty())
    os << extraClassDeclaration << "\n";

  if (sawPrivateMethod) {
    os << "\nprivate:\n";
    // Second pass: print only the private methods.
    forAllMethods([&](const OpMethod &method) {
      if (!method.isPrivate())
        return;
      emitMethodDecl(method);
    });
  }

  os << "};\n";
}