Files
clang-p2996/mlir/lib/TableGen/Operator.cpp
Lei Zhang 020f9eb68c [DRR] Allow interleaved operands and attributes
Previously DRR assumed that attributes appear after operands. This was a
previous requirement of ODS, but that changed some time ago. Fix
DRR to also support interleaved operands and attributes.

PiperOrigin-RevId: 275983485
2019-10-21 20:48:17 -07:00

320 lines
10 KiB
C++

//===- Operator.cpp - Operator class --------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Record;
// Builds an Operator view over the TableGen record `def`: resolves the
// dialect, derives the C++ class name, and eagerly extracts the op structure.
tblgen::Operator::Operator(const llvm::Record &def)
    : dialect(def.getValueAsDef("opDialect")), def(def) {
  // The def name is split at its first `_`:
  //  * `<prefix>_<rest>`: the dialect prefix is dropped and `<rest>` is used;
  //  * leading `_` (empty prefix): the whole def name, underscore included,
  //    is the class name;
  //  * nothing after the first `_` (no `_` at all, or a trailing one): the
  //    prefix is used as the class name.
  StringRef defName = def.getName();
  auto nameParts = defName.split('_');
  StringRef prefix = nameParts.first;
  StringRef rest = nameParts.second;
  if (prefix.empty())
    cppClassName = defName;
  else if (rest.empty())
    cppClassName = prefix;
  else
    cppClassName = rest;
  populateOpStructure();
}
// Returns the op's full name: `<dialect>.<opName>`, or just `<opName>` when
// the owning dialect has an empty name.
std::string tblgen::Operator::getOperationName() const {
  auto dialectName = dialect.getName();
  auto opName = def.getValueAsString("opName");
  if (!dialectName.empty())
    return llvm::formatv("{0}.{1}", dialectName, opName);
  return opName;
}
// Returns the name of the dialect this op belongs to.
StringRef tblgen::Operator::getDialectName() const {
  return dialect.getName();
}
// Returns the unqualified C++ class name computed by the constructor.
StringRef tblgen::Operator::getCppClassName() const {
  return cppClassName;
}
// Returns the C++ class name qualified by the dialect's C++ namespace
// (`<ns>::<Class>`), or the bare class name when that namespace is empty.
std::string tblgen::Operator::getQualCppClassName() const {
  auto cppNamespace = dialect.getCppNamespace();
  if (!cppNamespace.empty())
    return llvm::formatv("{0}::{1}", cppNamespace, cppClassName);
  return cppClassName;
}
// Returns the number of entries in the op's `results` dag.
int tblgen::Operator::getNumResults() const {
  return def.getValueAsDag("results")->getNumArgs();
}
// Returns the `extraClassDeclaration` code of the op, or an empty StringRef
// when that field is left unset in the record.
StringRef tblgen::Operator::getExtraClassDeclaration() const {
  const char *const fieldName = "extraClassDeclaration";
  if (!def.isValueUnset(fieldName))
    return def.getValueAsString(fieldName);
  return {};
}
const llvm::Record &tblgen::Operator::getDef() const { return def; }
bool tblgen::Operator::isVariadic() const {
return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
}
bool tblgen::Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders");
}
// Mutable iteration over the op's results, in the order of the `results` dag.
auto tblgen::Operator::result_begin() -> value_iterator {
  return results.begin();
}
auto tblgen::Operator::result_end() -> value_iterator {
  return results.end();
}
auto tblgen::Operator::getResults() -> value_range {
  return value_range(result_begin(), result_end());
}
// Returns the type constraint of the `index`-th entry of the `results` dag.
tblgen::TypeConstraint
tblgen::Operator::getResultTypeConstraint(int index) const {
  auto *resultsDag = def.getValueAsDag("results");
  auto *resultInit = cast<DefInit>(resultsDag->getArg(index));
  return TypeConstraint(resultInit);
}
// Returns the name of the `index`-th result (empty if it is unnamed).
StringRef tblgen::Operator::getResultName(int index) const {
  return def.getValueAsDag("results")->getArgNameStr(index);
}
// Counts the entries of `values` whose type constraint is variadic.
template <typename RangeT>
static unsigned countVariadicConstraints(const RangeT &values) {
  unsigned count = 0;
  for (const auto &value : values)
    if (value.constraint.isVariadic())
      ++count;
  return count;
}
// Returns how many of the op's results are variadic.
unsigned tblgen::Operator::getNumVariadicResults() const {
  return countVariadicConstraints(results);
}
// Returns how many of the op's operands are variadic.
unsigned tblgen::Operator::getNumVariadicOperands() const {
  return countVariadicConstraints(operands);
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
return arguments.begin();
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const {
return arguments.end();
}
tblgen::Operator::arg_range tblgen::Operator::getArgs() const {
return {arg_begin(), arg_end()};
}
// Returns the name of the `index`-th entry of the `arguments` dag, or an
// empty StringRef if that argument is unnamed.
StringRef tblgen::Operator::getArgName(int index) const {
  DagInit *argumentValues = def.getValueAsDag("arguments");
  // Operands may be declared without a name; DagInit::getArgName() returns
  // nullptr for them, so dereferencing it would crash. getArgNameStr() maps
  // that case to an empty StringRef instead.
  return argumentValues->getArgNameStr(index);
}
// Returns true if a native or internal trait of this op is named `trait`.
// Traits of any other kind are never matched.
bool tblgen::Operator::hasTrait(StringRef trait) const {
  for (const auto &opTrait : getTraits()) {
    if (const auto *native = dyn_cast<tblgen::NativeOpTrait>(&opTrait)) {
      if (native->getTrait() == trait)
        return true;
      continue;
    }
    if (const auto *internal = dyn_cast<tblgen::InternalOpTrait>(&opTrait))
      if (internal->getTrait() == trait)
        return true;
  }
  return false;
}
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
return regions[index];
}
// Read-only iteration over the op's traits, in the order of the `traits` list.
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
  return traits.begin();
}
auto tblgen::Operator::trait_end() const -> const_trait_iterator {
  return traits.end();
}
auto tblgen::Operator::getTraits() const
    -> llvm::iterator_range<const_trait_iterator> {
  return llvm::make_range(trait_begin(), trait_end());
}
// Read-only iteration over all attributes: native attributes in declaration
// order first, followed by derived attributes.
auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
  return attributes.begin();
}
auto tblgen::Operator::attribute_end() const -> attribute_iterator {
  return attributes.end();
}
auto tblgen::Operator::getAttributes() const
    -> llvm::iterator_range<attribute_iterator> {
  return llvm::make_range(attribute_begin(), attribute_end());
}
// Mutable iteration over the op's operands, in declaration order.
auto tblgen::Operator::operand_begin() -> value_iterator {
  return operands.begin();
}
auto tblgen::Operator::operand_end() -> value_iterator {
  return operands.end();
}
auto tblgen::Operator::getOperands() -> value_range {
  return value_range(operand_begin(), operand_end());
}
// Returns the `index`-th argument (operand or native attribute) in the order
// of the `arguments` dag; `index` is not bounds-checked.
auto tblgen::Operator::getArg(int index) const -> Argument {
  Argument selected = arguments[index];
  return selected;
}
// Extracts the op structure from `def`: fills `operands`, `attributes`
// (native then derived), `arguments`, `results`, `traits` and `regions`, and
// counts `numNativeAttributes`. Malformed definitions are reported through
// PrintFatalError.
void tblgen::Operator::populateOpStructure() {
  auto &recordKeeper = def.getRecords();
  auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
  auto attrClass = recordKeeper.getClass("Attr");
  auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
  numNativeAttributes = 0;

  // Operands and native attributes may be interleaved in the `arguments` dag.
  // Derived attributes are not allowed there; they are collected from the
  // record's fields afterwards and appended to `attributes` behind all native
  // attributes.
  DagInit *argumentValues = def.getValueAsDag("arguments");
  unsigned numArgs = argumentValues->getNumArgs();

  // Handle operands and native attributes.
  for (unsigned i = 0; i != numArgs; ++i) {
    auto arg = argumentValues->getArg(i);
    auto givenName = argumentValues->getArgNameStr(i);
    auto argDefInit = dyn_cast<DefInit>(arg);
    if (!argDefInit)
      PrintFatalError(def.getLoc(),
                      Twine("undefined type for argument #") + Twine(i));
    Record *argDef = argDefInit->getDef();

    if (argDef->isSubClassOf(typeConstraintClass)) {
      operands.push_back(
          NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
    } else if (argDef->isSubClassOf(attrClass)) {
      if (givenName.empty())
        PrintFatalError(argDef->getLoc(), "attributes must be named");
      if (argDef->isSubClassOf(derivedAttrClass))
        PrintFatalError(argDef->getLoc(),
                        "derived attributes not allowed in argument list");
      attributes.push_back({givenName, Attribute(argDef)});
      ++numNativeAttributes;
    } else {
      PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
                                    "from TypeConstraint or Attr are allowed");
    }
  }

  // Handle derived attributes.
  for (const auto &val : def.getValues()) {
    if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
      if (!record->isSubClassOf(attrClass))
        continue;
      if (!record->isSubClassOf(derivedAttrClass))
        PrintFatalError(def.getLoc(),
                        "unexpected Attr where only DerivedAttr is allowed");

      if (record->getClasses().size() != 1) {
        PrintFatalError(
            def.getLoc(),
            "unsupported attribute modelling, only single class expected");
      }
      attributes.push_back(
          {cast<llvm::StringInit>(val.getNameInit())->getValue(),
           Attribute(cast<DefInit>(val.getValue()))});
    }
  }

  // Populate `arguments` in declaration order. This must run only after
  // `operands` and `attributes` are final (derived attributes included):
  // `arguments` stores pointers to their elements, and appending to a vector
  // may reallocate its storage, which would leave earlier pointers dangling.
  unsigned operandIndex = 0, attrIndex = 0;
  for (unsigned i = 0; i != numArgs; ++i) {
    // Every argument passed validation above (PrintFatalError does not
    // return), so it is a DefInit deriving from TypeConstraint or Attr.
    Record *argDef = cast<DefInit>(argumentValues->getArg(i))->getDef();
    if (argDef->isSubClassOf(typeConstraintClass))
      arguments.emplace_back(&operands[operandIndex++]);
    else
      arguments.emplace_back(&attributes[attrIndex++]);
  }

  auto *resultsDag = def.getValueAsDag("results");
  auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
  if (!outsOp || outsOp->getDef()->getName() != "outs") {
    PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
  }

  // Handle results.
  for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
    auto name = resultsDag->getArgNameStr(i);
    auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
    if (!resultDef) {
      PrintFatalError(def.getLoc(),
                      Twine("undefined type for result #") + Twine(i));
    }
    results.push_back({name, TypeConstraint(resultDef)});
  }

  // Handle traits. A missing trait list only skips trait population; it must
  // not skip the region handling below.
  if (auto *traitListInit = def.getValueAsListInit("traits")) {
    traits.reserve(traitListInit->size());
    for (auto traitInit : *traitListInit)
      traits.push_back(OpTrait::create(traitInit));
  }

  // Handle regions
  auto *regionsDag = def.getValueAsDag("regions");
  auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
  if (!regionsOp || regionsOp->getDef()->getName() != "region") {
    PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
  }

  for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
    auto name = regionsDag->getArgNameStr(i);
    auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
    if (!regionInit) {
      PrintFatalError(def.getLoc(),
                      Twine("undefined kind for region #") + Twine(i));
    }
    regions.push_back({name, Region(regionInit->getDef())});
  }
}
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
bool tblgen::Operator::hasDescription() const {
return def.getValue("description") != nullptr;
}
StringRef tblgen::Operator::getDescription() const {
return def.getValueAsString("description");
}
bool tblgen::Operator::hasSummary() const {
return def.getValue("summary") != nullptr;
}
StringRef tblgen::Operator::getSummary() const {
return def.getValueAsString("summary");
}