From 263ec7221ef6ef96876d75b69ab48bf37dcbb25e Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 26 Mar 2025 20:26:14 -0500 Subject: [PATCH] [mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes (#132650) This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo, respectively. This doesn't change any of the tablegen files or any user-facing aspects of the enum attribute generation system, just reorganizes code in order to make main PR (#132148) shorter. --- mlir/include/mlir/TableGen/Attribute.h | 74 -------- mlir/include/mlir/TableGen/EnumInfo.h | 133 ++++++++++++++ mlir/include/mlir/TableGen/Pattern.h | 11 +- mlir/lib/TableGen/Attribute.cpp | 94 ---------- mlir/lib/TableGen/CMakeLists.txt | 1 + mlir/lib/TableGen/EnumInfo.cpp | 130 ++++++++++++++ mlir/lib/TableGen/Pattern.cpp | 12 +- .../mlir-tblgen/EnumPythonBindingGen.cpp | 47 ++--- mlir/tools/mlir-tblgen/EnumsGen.cpp | 166 +++++++++--------- .../tools/mlir-tblgen/LLVMIRConversionGen.cpp | 75 ++++---- mlir/tools/mlir-tblgen/OpDocGen.cpp | 17 +- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 35 ++-- mlir/tools/mlir-tblgen/RewriterGen.cpp | 8 +- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 91 +++++----- mlir/tools/mlir-tblgen/TosaUtilsGen.cpp | 5 +- 15 files changed, 504 insertions(+), 395 deletions(-) create mode 100644 mlir/include/mlir/TableGen/EnumInfo.h create mode 100644 mlir/lib/TableGen/EnumInfo.cpp diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 62720e74849f..dee81880baca 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -16,7 +16,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Constraint.h" -#include "llvm/ADT/StringRef.h" namespace llvm { class DefInit; @@ -136,79 +135,6 @@ private: const llvm::Record *def; }; -// Wrapper class providing helper methods for accessing enum attribute cases -// defined in TableGen. This is used for enum attribute case backed by both -// StringAttr and IntegerAttr. -class EnumAttrCase : public Attribute { -public: - explicit EnumAttrCase(const llvm::Record *record); - explicit EnumAttrCase(const llvm::DefInit *init); - - // Returns the symbol of this enum attribute case. - StringRef getSymbol() const; - - // Returns the textual representation of this enum attribute case. - StringRef getStr() const; - - // Returns the value of this enum attribute case. - int64_t getValue() const; - - // Returns the TableGen definition this EnumAttrCase was constructed from. - const llvm::Record &getDef() const; -}; - -// Wrapper class providing helper methods for accessing enum attributes defined -// in TableGen.This is used for enum attribute case backed by both StringAttr -// and IntegerAttr. -class EnumAttr : public Attribute { -public: - explicit EnumAttr(const llvm::Record *record); - explicit EnumAttr(const llvm::Record &record); - explicit EnumAttr(const llvm::DefInit *init); - - static bool classof(const Attribute *attr); - - // Returns true if this is a bit enum attribute. - bool isBitEnum() const; - - // Returns the enum class name. - StringRef getEnumClassName() const; - - // Returns the C++ namespaces this enum class should be placed in. - StringRef getCppNamespace() const; - - // Returns the underlying type. - StringRef getUnderlyingType() const; - - // Returns the name of the utility function that converts a value of the - // underlying type to the corresponding symbol. - StringRef getUnderlyingToSymbolFnName() const; - - // Returns the name of the utility function that converts a string to the - // corresponding symbol. - StringRef getStringToSymbolFnName() const; - - // Returns the name of the utility function that converts a symbol to the - // corresponding string. - StringRef getSymbolToStringFnName() const; - - // Returns the return type of the utility function that converts a symbol to - // the corresponding string. - StringRef getSymbolToStringFnRetType() const; - - // Returns the name of the utilit function that returns the max enum value - // used within the enum class. - StringRef getMaxEnumValFnName() const; - - // Returns all allowed cases for this enum attribute. - std::vector getAllCases() const; - - bool genSpecializedAttr() const; - const llvm::Record *getBaseAttrClass() const; - StringRef getSpecializedAttrClassName() const; - bool printBitEnumPrimaryGroups() const; -}; - // Name of infer type op interface. extern const char *inferTypeOpInterface; diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h new file mode 100644 index 000000000000..5bc7ffb6a8a3 --- /dev/null +++ b/mlir/include/mlir/TableGen/EnumInfo.h @@ -0,0 +1,133 @@ +//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// EnumInfo wrapper to simplify using a TableGen Record defining an Enum +// via EnumInfo and its `EnumCase`s. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ENUMINFO_H_ +#define MLIR_TABLEGEN_ENUMINFO_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Attribute.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class DefInit; +class Record; +} // namespace llvm + +namespace mlir::tblgen { + +// Wrapper class providing around enum cases defined in TableGen. +class EnumCase { +public: + explicit EnumCase(const llvm::Record *record); + explicit EnumCase(const llvm::DefInit *init); + + // Returns the symbol of this enum attribute case. + StringRef getSymbol() const; + + // Returns the textual representation of this enum attribute case. + StringRef getStr() const; + + // Returns the value of this enum attribute case. + int64_t getValue() const; + + // Returns the TableGen definition this EnumAttrCase was constructed from. + const llvm::Record &getDef() const; + +protected: + // The TableGen definition of this constraint. + const llvm::Record *def; +}; + +// Wrapper class providing helper methods for accessing enums defined +// in TableGen using EnumInfo. Some methods are only applicable when +// the enum is also an attribute, or only when it is a bit enum. +class EnumInfo { +public: + explicit EnumInfo(const llvm::Record *record); + explicit EnumInfo(const llvm::Record &record); + explicit EnumInfo(const llvm::DefInit *init); + + // Returns true if the given EnumInfo is a subclass of the named TableGen + // class. + bool isSubClassOf(StringRef className) const; + + // Returns true if this enum is an EnumAttrInfo, thus making it define an + // attribute. + bool isEnumAttr() const; + + // Create the `Attribute` wrapper around this EnumInfo if it is defining an + // attribute. + std::optional asEnumAttr() const; + + // Returns true if this is a bit enum. + bool isBitEnum() const; + + // Returns the enum class name. + StringRef getEnumClassName() const; + + // Returns the C++ namespaces this enum class should be placed in. + StringRef getCppNamespace() const; + + // Returns the summary of the enum. + StringRef getSummary() const; + + // Returns the description of the enum. + StringRef getDescription() const; + + // Returns the underlying type. + StringRef getUnderlyingType() const; + + // Returns the name of the utility function that converts a value of the + // underlying type to the corresponding symbol. + StringRef getUnderlyingToSymbolFnName() const; + + // Returns the name of the utility function that converts a string to the + // corresponding symbol. + StringRef getStringToSymbolFnName() const; + + // Returns the name of the utility function that converts a symbol to the + // corresponding string. + StringRef getSymbolToStringFnName() const; + + // Returns the return type of the utility function that converts a symbol to + // the corresponding string. + StringRef getSymbolToStringFnRetType() const; + + // Returns the name of the utilit function that returns the max enum value + // used within the enum class. + StringRef getMaxEnumValFnName() const; + + // Returns all allowed cases for this enum attribute. + std::vector getAllCases() const; + + // Only applicable for enum attributes. + + bool genSpecializedAttr() const; + const llvm::Record *getBaseAttrClass() const; + StringRef getSpecializedAttrClassName() const; + + // Only applicable for bit enums. + + bool printBitEnumPrimaryGroups() const; + + // Returns the TableGen definition this EnumAttrCase was constructed from. + const llvm::Record &getDef() const; + +protected: + // The TableGen definition of this constraint. + const llvm::Record *def; +}; + +} // namespace mlir::tblgen + +#endif diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 80f38fdeffee..1c9e128f0a0f 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" @@ -78,8 +79,8 @@ public: // Returns true if this DAG leaf is specifying a constant attribute. bool isConstantAttr() const; - // Returns true if this DAG leaf is specifying an enum attribute case. - bool isEnumAttrCase() const; + // Returns true if this DAG leaf is specifying an enum case. + bool isEnumCase() const; // Returns true if this DAG leaf is specifying a string attribute. bool isStringAttr() const; @@ -90,9 +91,9 @@ public: // Returns this DAG leaf as an constant attribute. Asserts if fails. ConstantAttr getAsConstantAttr() const; - // Returns this DAG leaf as an enum attribute case. - // Precondition: isEnumAttrCase() - EnumAttrCase getAsEnumAttrCase() const; + // Returns this DAG leaf as an enum case. + // Precondition: isEnumCase() + EnumCase getAsEnumCase() const; // Returns the matching condition template inside this DAG leaf. Assumes the // leaf is an operand/attribute matcher and asserts otherwise. diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index f9fc58a40f33..142d19426094 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const { return def->getValueAsString("value"); } -EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) { - assert(isSubClassOf("EnumAttrCaseInfo") && - "must be subclass of TableGen 'EnumAttrInfo' class"); -} - -EnumAttrCase::EnumAttrCase(const DefInit *init) - : EnumAttrCase(init->getDef()) {} - -StringRef EnumAttrCase::getSymbol() const { - return def->getValueAsString("symbol"); -} - -StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } - -int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } - -const Record &EnumAttrCase::getDef() const { return *def; } - -EnumAttr::EnumAttr(const Record *record) : Attribute(record) { - assert(isSubClassOf("EnumAttrInfo") && - "must be subclass of TableGen 'EnumAttr' class"); -} - -EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {} - -EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {} - -bool EnumAttr::classof(const Attribute *attr) { - return attr->isSubClassOf("EnumAttrInfo"); -} - -bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } - -StringRef EnumAttr::getEnumClassName() const { - return def->getValueAsString("className"); -} - -StringRef EnumAttr::getCppNamespace() const { - return def->getValueAsString("cppNamespace"); -} - -StringRef EnumAttr::getUnderlyingType() const { - return def->getValueAsString("underlyingType"); -} - -StringRef EnumAttr::getUnderlyingToSymbolFnName() const { - return def->getValueAsString("underlyingToSymbolFnName"); -} - -StringRef EnumAttr::getStringToSymbolFnName() const { - return def->getValueAsString("stringToSymbolFnName"); -} - -StringRef EnumAttr::getSymbolToStringFnName() const { - return def->getValueAsString("symbolToStringFnName"); -} - -StringRef EnumAttr::getSymbolToStringFnRetType() const { - return def->getValueAsString("symbolToStringFnRetType"); -} - -StringRef EnumAttr::getMaxEnumValFnName() const { - return def->getValueAsString("maxEnumValFnName"); -} - -std::vector EnumAttr::getAllCases() const { - const auto *inits = def->getValueAsListInit("enumerants"); - - std::vector cases; - cases.reserve(inits->size()); - - for (const Init *init : *inits) { - cases.emplace_back(cast(init)); - } - - return cases; -} - -bool EnumAttr::genSpecializedAttr() const { - return def->getValueAsBit("genSpecializedAttr"); -} - -const Record *EnumAttr::getBaseAttrClass() const { - return def->getValueAsDef("baseAttrClass"); -} - -StringRef EnumAttr::getSpecializedAttrClassName() const { - return def->getValueAsString("specializedAttrClassName"); -} - -bool EnumAttr::printBitEnumPrimaryGroups() const { - return def->getValueAsBit("printBitEnumPrimaryGroups"); -} - const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt index c4104e644147..a90c55847718 100644 --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC CodeGenHelpers.cpp Constraint.cpp Dialect.cpp + EnumInfo.cpp Format.cpp GenInfo.cpp Interfaces.cpp diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp new file mode 100644 index 000000000000..9f491d30f0e7 --- /dev/null +++ b/mlir/lib/TableGen/EnumInfo.cpp @@ -0,0 +1,130 @@ +//===- EnumInfo.cpp - EnumInfo wrapper class ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/EnumInfo.h" +#include "mlir/TableGen/Attribute.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +using llvm::DefInit; +using llvm::Init; +using llvm::Record; + +EnumCase::EnumCase(const Record *record) : def(record) { + assert(def->isSubClassOf("EnumAttrCaseInfo") && + "must be subclass of TableGen 'EnumAttrCaseInfo' class"); +} + +EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {} + +StringRef EnumCase::getSymbol() const { + return def->getValueAsString("symbol"); +} + +StringRef EnumCase::getStr() const { return def->getValueAsString("str"); } + +int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); } + +const Record &EnumCase::getDef() const { return *def; } + +EnumInfo::EnumInfo(const Record *record) : def(record) { + assert(isSubClassOf("EnumAttrInfo") && + "must be subclass of TableGen 'EnumAttrInfo' class"); +} + +EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {} + +EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {} + +bool EnumInfo::isSubClassOf(StringRef className) const { + return def->isSubClassOf(className); +} + +bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } + +std::optional EnumInfo::asEnumAttr() const { + if (isEnumAttr()) + return Attribute(def); + return std::nullopt; +} + +bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } + +StringRef EnumInfo::getEnumClassName() const { + return def->getValueAsString("className"); +} + +StringRef EnumInfo::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef EnumInfo::getDescription() const { + return def->getValueAsString("description"); +} + +StringRef EnumInfo::getCppNamespace() const { + return def->getValueAsString("cppNamespace"); +} + +StringRef EnumInfo::getUnderlyingType() const { + return def->getValueAsString("underlyingType"); +} + +StringRef EnumInfo::getUnderlyingToSymbolFnName() const { + return def->getValueAsString("underlyingToSymbolFnName"); +} + +StringRef EnumInfo::getStringToSymbolFnName() const { + return def->getValueAsString("stringToSymbolFnName"); +} + +StringRef EnumInfo::getSymbolToStringFnName() const { + return def->getValueAsString("symbolToStringFnName"); +} + +StringRef EnumInfo::getSymbolToStringFnRetType() const { + return def->getValueAsString("symbolToStringFnRetType"); +} + +StringRef EnumInfo::getMaxEnumValFnName() const { + return def->getValueAsString("maxEnumValFnName"); +} + +std::vector EnumInfo::getAllCases() const { + const auto *inits = def->getValueAsListInit("enumerants"); + + std::vector cases; + cases.reserve(inits->size()); + + for (const Init *init : *inits) { + cases.emplace_back(cast(init)); + } + + return cases; +} + +bool EnumInfo::genSpecializedAttr() const { + return isSubClassOf("EnumAttrInfo") && + def->getValueAsBit("genSpecializedAttr"); +} + +const Record *EnumInfo::getBaseAttrClass() const { + return def->getValueAsDef("baseAttrClass"); +} + +StringRef EnumInfo::getSpecializedAttrClassName() const { + return def->getValueAsString("specializedAttrClassName"); +} + +bool EnumInfo::printBitEnumPrimaryGroups() const { + return def->getValueAsBit("printBitEnumPrimaryGroups"); +} + +const Record &EnumInfo::getDef() const { return *def; } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index ac8c49c72d38..73e2803c21da 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const { bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } -bool DagLeaf::isEnumAttrCase() const { - return isSubClassOf("EnumAttrCaseInfo"); -} +bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); } bool DagLeaf::isStringAttr() const { return isa(def); } @@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const { return ConstantAttr(cast(def)); } -EnumAttrCase DagLeaf::getAsEnumAttrCase() const { - assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); - return EnumAttrCase(cast(def)); +EnumCase DagLeaf::getAsEnumCase() const { + assert(isEnumCase() && "the DAG leaf must be an enum attribute case"); + return EnumCase(cast(def)); } std::string DagLeaf::getConditionTemplate() const { @@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, verifyBind(infoMap.bindValue(treeArgName), treeArgName); } else { auto constraint = leaf.getAsConstraint(); - bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || + bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() || leaf.isConstantAttr() || constraint.getKind() == Constraint::Kind::CK_Attr; diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index 3f660ae151c7..5d4d9e90fff6 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -15,6 +15,7 @@ #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Dialect.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" @@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) { } /// Emits the Python class for the given enum. -static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { - os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), - enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); - if (!enumAttr.getSummary().empty()) - os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); +static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) { + os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(), + enumInfo.isBitEnum() ? "IntFlag" : "IntEnum"); + if (!enumInfo.getSummary().empty()) + os << formatv(" \"\"\"{0}\"\"\"\n", enumInfo.getSummary()); os << "\n"; - for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { + for (const EnumCase &enumCase : enumInfo.getAllCases()) { os << formatv(" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) @@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { os << "\n"; - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { os << formatv(" def __iter__(self):\n" " return iter([case for case in type(self) if " "(self & case) is case])\n"); @@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { } os << formatv(" def __str__(self):\n"); - if (enumAttr.isBitEnum()) + if (enumInfo.isBitEnum()) os << formatv(" if len(self) > 1:\n" " return \"{0}\".join(map(str, self))\n", - enumAttr.getDef().getValueAsString("separator")); - for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { - os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(), + enumInfo.getDef().getValueAsString("separator")); + for (const EnumCase &enumCase : enumInfo.getAllCases()) { + os << formatv(" if self is {0}.{1}:\n", enumInfo.getEnumClassName(), makePythonEnumCaseName(enumCase.getSymbol())); os << formatv(" return \"{0}\"\n", enumCase.getStr()); } os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", - enumAttr.getEnumClassName()); + enumInfo.getEnumClassName()); os << "\n"; } @@ -98,17 +99,21 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) { /// Emits an attribute builder for the given enum attribute to support automatic /// conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. -static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { +static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { + std::optional enumAttrInfo = enumInfo.asEnumAttr(); + if (!enumAttrInfo) + return false; + int64_t bitwidth; - if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) { + if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) { llvm::errs() << "failed to identify bitwidth of " - << enumAttr.getUnderlyingType(); + << enumInfo.getUnderlyingType(); return true; } - os << formatv("@register_attribute_builder(\"{0}\")\n", - enumAttr.getAttrDefName()); - os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower()); + enumAttrInfo->getAttrDefName()); + os << formatv("def _{0}(x, context):\n", + enumAttrInfo->getAttrDefName().lower()); os << formatv(" return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " "context=context), int(x))\n\n", @@ -136,9 +141,9 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { os << fileHeader; for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { - EnumAttr enumAttr(*it); - emitEnumClass(enumAttr, os); - emitAttributeBuilder(enumAttr, os); + EnumInfo enumInfo(*it); + emitEnumClass(enumInfo, os); + emitAttributeBuilder(enumInfo, os); } for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumAttr")) { diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index d11aa9b27c2d..fa6fad156b74 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -12,6 +12,7 @@ #include "FormatGen.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/BitVector.h" @@ -30,8 +31,8 @@ using llvm::Record; using llvm::RecordKeeper; using namespace mlir; using mlir::tblgen::Attribute; -using mlir::tblgen::EnumAttr; -using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::EnumCase; +using mlir::tblgen::EnumInfo; using mlir::tblgen::FmtContext; using mlir::tblgen::tgfmt; @@ -45,7 +46,7 @@ static std::string makeIdentifier(StringRef str) { static void emitEnumClass(const Record &enumDef, StringRef enumName, StringRef underlyingType, StringRef description, - const std::vector &enumerants, + const std::vector &enumerants, raw_ostream &os) { os << "// " << description << "\n"; os << "enum class " << enumName; @@ -66,12 +67,13 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName, os << "};\n\n"; } -static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName, +static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName, StringRef cppNamespace, raw_ostream &os) { - if (enumAttr.getUnderlyingType().empty() || - enumAttr.getConstBuilderTemplate().empty()) + std::optional enumAttrInfo = enumInfo.asEnumAttr(); + if (enumInfo.getUnderlyingType().empty() || + (enumAttrInfo && enumAttrInfo->getConstBuilderTemplate().empty())) return; - auto cases = enumAttr.getAllCases(); + auto cases = enumInfo.getAllCases(); // Check which cases shouldn't be printed using a keyword. llvm::BitVector nonKeywordCases(cases.size()); @@ -128,8 +130,9 @@ namespace llvm { inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ auto valueStr = stringifyEnum(value); )"; + os << formatv(parsedAndPrinterStart, qualName, cppNamespace, - enumAttr.getSummary()); + enumInfo.getSummary()); // If all cases require a string, always wrap. if (nonKeywordCases.all()) { @@ -157,9 +160,9 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ // If this is a bit enum, conservatively print the string form if the value // is not a power of two (i.e. not a single bit case) and not a known case. - } else if (enumAttr.isBitEnum()) { + } else if (enumInfo.isBitEnum()) { // Process the known multi-bit cases that use valid keywords. - SmallVector validMultiBitCases; + SmallVector validMultiBitCases; for (auto [index, caseVal] : llvm::enumerate(cases)) { uint64_t value = caseVal.getValue(); if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index)) @@ -167,7 +170,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ } if (!validMultiBitCases.empty()) { os << " switch (value) {\n"; - for (EnumAttrCase *caseVal : validMultiBitCases) { + for (EnumCase *caseVal : validMultiBitCases) { StringRef symbol = caseVal->getSymbol(); os << llvm::formatv(" case {0}::{1}:\n", qualName, llvm::isDigit(symbol.front()) ? ("_" + symbol) @@ -224,9 +227,9 @@ template<> struct DenseMapInfo<{0}> {{ } static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef maxEnumValFnName = enumInfo.getMaxEnumValFnName(); + auto enumerants = enumInfo.getAllCases(); unsigned maxEnumVal = 0; for (const auto &enumerant : enumerants) { @@ -245,10 +248,10 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { os << "}\n\n"; } -// Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt +// Returns the EnumCase whose value is zero if exists; returns std::nullopt // otherwise. -static std::optional -getAllBitsUnsetCase(llvm::ArrayRef cases) { +static std::optional +getAllBitsUnsetCase(llvm::ArrayRef cases) { for (auto attrCase : cases) { if (attrCase.getValue() == 0) return attrCase; @@ -268,9 +271,9 @@ getAllBitsUnsetCase(llvm::ArrayRef cases) { // inline constexpr bitEnumSet( bits, bit, // bool value=true); static void emitOperators(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); int64_t validBits = enumDef.getValueAsInt("validBits"); const char *const operators = R"( inline constexpr {0} operator|({0} a, {0} b) {{ @@ -303,11 +306,11 @@ inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) } static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef symToStrFnName = enumInfo.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType(); + auto enumerants = enumInfo.getAllCases(); os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName, symToStrFnRetType); @@ -324,19 +327,19 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { } static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef symToStrFnName = enumInfo.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType(); StringRef separator = enumDef.getValueAsString("separator"); - auto enumerants = enumAttr.getAllCases(); + auto enumerants = enumInfo.getAllCases(); auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName, symToStrFnRetType); os << formatv(" auto val = static_cast<{0}>(symbol);\n", - enumAttr.getUnderlyingType()); + enumInfo.getUnderlyingType()); // If we have unknown bit set, return an empty string to signal errors. int64_t validBits = enumDef.getValueAsInt("validBits"); os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit " @@ -365,21 +368,23 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { )"; // Optionally elide bits that are members of groups that will also be printed // for more concise output. - if (enumAttr.printBitEnumPrimaryGroups()) { + if (enumInfo.printBitEnumPrimaryGroups()) { os << " // Print bit enum groups before individual bits\n"; // Emit comparisons for group bit cases in reverse tablegen declaration // order, removing bits for groups with all bits present. for (const auto &enumerant : llvm::reverse(enumerants)) { if ((enumerant.getValue() != 0) && - enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) { + (enumerant.getDef().isSubClassOf("BitEnumCaseGroup") || + enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup"))) { os << formatv(formatCompareRemove, enumerant.getValue(), - enumerant.getStr(), enumAttr.getUnderlyingType()); + enumerant.getStr(), enumInfo.getUnderlyingType()); } } // Emit comparisons for individual bit cases in tablegen declaration order. for (const auto &enumerant : enumerants) { if ((enumerant.getValue() != 0) && - enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit")) + (enumerant.getDef().isSubClassOf("BitEnumCaseBit") || + enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))) os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr()); } } else { @@ -396,10 +401,10 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { } static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef strToSymFnName = enumInfo.getStringToSymbolFnName(); + auto enumerants = enumInfo.getAllCases(); os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n", enumName, strToSymFnName); @@ -416,13 +421,13 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { } static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); + StringRef strToSymFnName = enumInfo.getStringToSymbolFnName(); StringRef separator = enumDef.getValueAsString("separator"); StringRef separatorTrimmed = separator.trim(); - auto enumerants = enumAttr.getAllCases(); + auto enumerants = enumInfo.getAllCases(); auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n", @@ -463,17 +468,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); + StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); + auto enumerants = enumInfo.getAllCases(); // Avoid generating the underlying value to symbol conversion function if // there is an enumerant without explicit value. - if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) { - return enumerant.getValue() < 0; - })) + if (llvm::any_of(enumerants, + [](EnumCase enumerant) { return enumerant.getValue() < 0; })) return; os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName, @@ -493,10 +497,10 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef, } static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); - const Record *baseAttrDef = enumAttr.getBaseAttrClass(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); + const Record *baseAttrDef = enumInfo.getBaseAttrClass(); Attribute baseAttr(baseAttrDef); // Emit classof method @@ -520,7 +524,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n", attrClassName, enumName); - StringRef underlyingType = enumAttr.getUnderlyingType(); + StringRef underlyingType = enumInfo.getUnderlyingType(); // Assuming that it is IntegerAttr constraint int64_t bitwidth = 64; @@ -552,11 +556,11 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); + StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); + auto enumerants = enumInfo.getAllCases(); auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName, @@ -574,16 +578,16 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, } static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - StringRef description = enumAttr.getSummary(); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); - StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); + StringRef description = enumInfo.getSummary(); + StringRef strToSymFnName = enumInfo.getStringToSymbolFnName(); + StringRef symToStrFnName = enumInfo.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType(); + StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); + auto enumerants = enumInfo.getAllCases(); SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -595,7 +599,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); // Emit conversion function declarations - if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) { + if (llvm::all_of(enumerants, [](EnumCase enumerant) { return enumerant.getValue() >= 0; })) { os << formatv( @@ -606,7 +610,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, strToSymFnName); - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { emitOperators(enumDef, os); } else { emitMaxValueFn(enumDef, os); @@ -644,8 +648,8 @@ public: {0} getValue() const; }; )"; - if (enumAttr.genSpecializedAttr()) { - StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); + if (enumInfo.genSpecializedAttr()) { + StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); StringRef baseAttrClassName = "IntegerAttr"; os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); } @@ -656,7 +660,7 @@ public: // Generate a generic parser and printer for the enum. std::string qualName = std::string(formatv("{0}::{1}", cppNamespace, enumName)); - emitParserPrinter(enumAttr, qualName, cppNamespace, os); + emitParserPrinter(enumInfo, qualName, cppNamespace, os); // Emit DenseMapInfo for this enum class emitDenseMapInfo(qualName, underlyingType, cppNamespace, os); @@ -673,8 +677,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { } static void emitEnumDef(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef cppNamespace = enumAttr.getCppNamespace(); + EnumInfo enumInfo(enumDef); + StringRef cppNamespace = enumInfo.getCppNamespace(); SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -682,7 +686,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { for (auto ns : namespaces) os << "namespace " << ns << " {\n"; - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { emitSymToStrFnForBitEnum(enumDef, os); emitStrToSymFnForBitEnum(enumDef, os); emitUnderlyingToSymFnForBitEnum(enumDef, os); @@ -692,7 +696,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { emitUnderlyingToSymFnForIntEnum(enumDef, os); } - if (enumAttr.genSpecializedAttr()) + if (enumInfo.genSpecializedAttr()) emitSpecializedAttrDef(enumDef, os); for (auto ns : llvm::reverse(namespaces)) diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 9e19f479d673..96af14d36817 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" @@ -335,13 +336,13 @@ static bool emitOpMLIRBuilders(const RecordKeeper &records, raw_ostream &os) { namespace { // Wrapper class around a Tablegen definition of an LLVM enum attribute case. -class LLVMEnumAttrCase : public tblgen::EnumAttrCase { +class LLVMEnumCase : public tblgen::EnumCase { public: - using tblgen::EnumAttrCase::EnumAttrCase; + using tblgen::EnumCase::EnumCase; // Constructs a case from a non LLVM-specific enum attribute case. - explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other) - : tblgen::EnumAttrCase(&other.getDef()) {} + explicit LLVMEnumCase(const tblgen::EnumCase &other) + : tblgen::EnumCase(&other.getDef()) {} // Returns the C++ enumerant for the LLVM API. StringRef getLLVMEnumerant() const { @@ -350,9 +351,9 @@ public: }; // Wraper class around a Tablegen definition of an LLVM enum attribute. -class LLVMEnumAttr : public tblgen::EnumAttr { +class LLVMEnumInfo : public tblgen::EnumInfo { public: - using tblgen::EnumAttr::EnumAttr; + using tblgen::EnumInfo::EnumInfo; // Returns the C++ enum name for the LLVM API. StringRef getLLVMClassName() const { @@ -360,19 +361,19 @@ public: } // Returns all associated cases viewed as LLVM-specific enum cases. - std::vector getAllCases() const { - std::vector cases; + std::vector getAllCases() const { + std::vector cases; - for (auto &c : tblgen::EnumAttr::getAllCases()) + for (auto &c : tblgen::EnumInfo::getAllCases()) cases.emplace_back(c); return cases; } - std::vector getAllUnsupportedCases() const { + std::vector getAllUnsupportedCases() const { const auto *inits = def->getValueAsListInit("unsupported"); - std::vector cases; + std::vector cases; cases.reserve(inits->size()); for (const llvm::Init *init : *inits) @@ -383,9 +384,9 @@ public: }; // Wraper class around a Tablegen definition of a C-style LLVM enum attribute. -class LLVMCEnumAttr : public tblgen::EnumAttr { +class LLVMCEnumInfo : public tblgen::EnumInfo { public: - using tblgen::EnumAttr::EnumAttr; + using tblgen::EnumInfo::EnumInfo; // Returns the C++ enum name for the LLVM API. StringRef getLLVMClassName() const { @@ -393,10 +394,10 @@ public: } // Returns all associated cases viewed as LLVM-specific enum cases. - std::vector getAllCases() const { - std::vector cases; + std::vector getAllCases() const { + std::vector cases; - for (auto &c : tblgen::EnumAttr::getAllCases()) + for (auto &c : tblgen::EnumInfo::getAllCases()) cases.emplace_back(c); return cases; @@ -408,10 +409,10 @@ public: // switch-based logic to convert from the MLIR LLVM dialect enum attribute case // (Enum) to the corresponding LLVM API enumerant static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { - LLVMEnumAttr enumAttr(record); - StringRef llvmClass = enumAttr.getLLVMClassName(); - StringRef cppClassName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); + LLVMEnumInfo enumInfo(record); + StringRef llvmClass = enumInfo.getLLVMClassName(); + StringRef cppClassName = enumInfo.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute to its LLVM counterpart. os << formatv( @@ -419,7 +420,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { llvmClass, cppClassName, cppNamespace); os << " switch (value) {\n"; - for (const auto &enumerant : enumAttr.getAllCases()) { + for (const auto &enumerant : enumInfo.getAllCases()) { StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); StringRef cppEnumerant = enumerant.getSymbol(); os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName, @@ -429,7 +430,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { os << " }\n"; os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", - enumAttr.getEnumClassName()); + enumInfo.getEnumClassName()); os << "}\n\n"; } @@ -437,7 +438,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { // switch-based logic to convert from the MLIR LLVM dialect enum attribute case // (Enum) to the corresponding LLVM API C-style enumerant static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) { - LLVMCEnumAttr enumAttr(record); + LLVMCEnumInfo enumAttr(record); StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef cppClassName = enumAttr.getEnumClassName(); StringRef cppNamespace = enumAttr.getCppNamespace(); @@ -467,10 +468,10 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) { // containing switch-based logic to convert from the LLVM API enumerant to MLIR // LLVM dialect enum attribute (Enum). static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { - LLVMEnumAttr enumAttr(record); - StringRef llvmClass = enumAttr.getLLVMClassName(); - StringRef cppClassName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); + LLVMEnumInfo enumInfo(record); + StringRef llvmClass = enumInfo.getLLVMClassName(); + StringRef cppClassName = enumInfo.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} " @@ -478,23 +479,23 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { cppNamespace, cppClassName, llvmClass); os << " switch (value) {\n"; - for (const auto &enumerant : enumAttr.getAllCases()) { + for (const auto &enumerant : enumInfo.getAllCases()) { StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); StringRef cppEnumerant = enumerant.getSymbol(); os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant); os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName, cppEnumerant); } - for (const auto &enumerant : enumAttr.getAllUnsupportedCases()) { + for (const auto &enumerant : enumInfo.getAllUnsupportedCases()) { StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant); os << formatv(" llvm_unreachable(\"unsupported case {0}::{1}\");\n", - enumAttr.getLLVMClassName(), llvmEnumerant); + enumInfo.getLLVMClassName(), llvmEnumerant); } os << " }\n"; os << formatv(" llvm_unreachable(\"unknown {0} type\");", - enumAttr.getLLVMClassName()); + enumInfo.getLLVMClassName()); os << "}\n\n"; } @@ -502,10 +503,10 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { // containing switch-based logic to convert from the LLVM API C-style enumerant // to MLIR LLVM dialect enum attribute (Enum). static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { - LLVMCEnumAttr enumAttr(record); - StringRef llvmClass = enumAttr.getLLVMClassName(); - StringRef cppClassName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); + LLVMCEnumInfo enumInfo(record); + StringRef llvmClass = enumInfo.getLLVMClassName(); + StringRef cppClassName = enumInfo.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. os << formatv( @@ -514,7 +515,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { cppNamespace, cppClassName); os << " switch (value) {\n"; - for (const auto &enumerant : enumAttr.getAllCases()) { + for (const auto &enumerant : enumInfo.getAllCases()) { StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); StringRef cppEnumerant = enumerant.getSymbol(); os << formatv(" case static_cast({0}::{1}):\n", llvmClass, @@ -525,7 +526,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { os << " }\n"; os << formatv(" llvm_unreachable(\"unknown {0} type\");", - enumAttr.getLLVMClassName()); + enumInfo.getLLVMClassName()); os << "}\n\n"; } diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp index dbaad84cda5d..f53aebb302dc 100644 --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -17,6 +17,7 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/DenseMap.h" @@ -384,14 +385,14 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &records, raw_ostream &os, // Enum Documentation //===----------------------------------------------------------------------===// -static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { +static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) { os << formatv("\n### {0}\n", def.getEnumClassName()); // Emit the summary if present. emitSummary(def.getSummary(), os); // Emit case documentation. - std::vector cases = def.getAllCases(); + std::vector cases = def.getAllCases(); os << "\n#### Cases:\n\n"; os << "| Symbol | Value | String |\n" << "| :----: | :---: | ------ |"; @@ -406,7 +407,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) { os << "\n"; for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo")) - emitEnumDoc(EnumAttr(def), os); + emitEnumDoc(EnumInfo(def), os); } //===----------------------------------------------------------------------===// @@ -441,7 +442,7 @@ static void maybeNest(bool nest, llvm::function_ref fn, static void emitBlock(ArrayRef attributes, StringRef inputFilename, ArrayRef attrDefs, ArrayRef ops, ArrayRef types, ArrayRef typeDefs, - ArrayRef enums, raw_ostream &os) { + ArrayRef enums, raw_ostream &os) { if (!ops.empty()) { os << "\n## Operations\n"; emitSourceLink(inputFilename, os); @@ -490,7 +491,7 @@ static void emitBlock(ArrayRef attributes, StringRef inputFilename, if (!enums.empty()) { os << "\n## Enums\n"; - for (const EnumAttr &def : enums) + for (const EnumInfo &def : enums) emitEnumDoc(def, os); } } @@ -499,7 +500,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename, ArrayRef attributes, ArrayRef attrDefs, ArrayRef ops, ArrayRef types, ArrayRef typeDefs, - ArrayRef enums, raw_ostream &os) { + ArrayRef enums, raw_ostream &os) { os << "\n# '" << dialect.getName() << "' Dialect\n"; emitSummary(dialect.getSummary(), os); emitDescription(dialect.getDescription(), os); @@ -532,7 +533,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) { std::vector dialectOps; std::vector dialectTypes; std::vector dialectTypeDefs; - std::vector dialectEnums; + std::vector dialectEnums; SmallDenseSet seen; auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) { @@ -576,7 +577,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) { addIfInDialect(def, Type(def), dialectTypes); dialectEnums.reserve(enumDefs.size()); for (const Record *def : enumDefs) - addIfNotSeen(def, EnumAttr(def), dialectEnums); + addIfNotSeen(def, EnumInfo(def), dialectEnums); // Sort alphabetically ignorning dialect for ops and section name for // sections. diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index fe724e86d670..3a7a7aaf3a5d 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -11,6 +11,7 @@ #include "OpClass.h" #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Class.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" @@ -424,17 +425,17 @@ struct OperationFormat { //===----------------------------------------------------------------------===// // Parser Gen -/// Returns true if we can format the given attribute as an EnumAttr in the +/// Returns true if we can format the given attribute as an enum in the /// parser format. static bool canFormatEnumAttr(const NamedAttribute *attr) { Attribute baseAttr = attr->attr.getBaseAttr(); - const EnumAttr *enumAttr = dyn_cast(&baseAttr); - if (!enumAttr) + if (!baseAttr.isEnumAttr()) return false; + EnumInfo enumInfo(&baseAttr.getDef()); // The attribute must have a valid underlying type and a constant builder. - return !enumAttr->getUnderlyingType().empty() && - !enumAttr->getConstBuilderTemplate().empty(); + return !enumInfo.getUnderlyingType().empty() && + !baseAttr.getConstBuilderTemplate().empty(); } /// Returns if we should format the given attribute as an SymbolNameAttr. @@ -1150,21 +1151,21 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, FmtContext &attrTypeCtx, bool parseAsOptional, bool useProperties, StringRef opCppClassName) { Attribute baseAttr = var->attr.getBaseAttr(); - const EnumAttr &enumAttr = cast(baseAttr); - std::vector cases = enumAttr.getAllCases(); + EnumInfo enumInfo(&baseAttr.getDef()); + std::vector cases = enumInfo.getAllCases(); // Generate the code for building an attribute for this enum. std::string attrBuilderStr; { llvm::raw_string_ostream os(attrBuilderStr); - os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx, + os << tgfmt(baseAttr.getConstBuilderTemplate(), &attrTypeCtx, "*attrOptional"); } // Build a string containing the cases that can be formatted as a keyword. std::string validCaseKeywordsStr = "{"; llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); - for (const EnumAttrCase &attrCase : cases) + for (const EnumCase &attrCase : cases) if (canFormatStringAsKeyword(attrCase.getStr())) validCaseKeywordsOS << '"' << attrCase.getStr() << "\","; validCaseKeywordsOS.str().back() = '}'; @@ -1194,8 +1195,8 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name); } - body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(), - enumAttr.getStringToSymbolFnName(), attrBuilderStr, + body << formatv(enumAttrParserCode, var->name, enumInfo.getCppNamespace(), + enumInfo.getStringToSymbolFnName(), attrBuilderStr, validCaseKeywordsStr, errorMessage, attrAssignment); } @@ -2264,13 +2265,13 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, MethodBody &body) { Attribute baseAttr = var->attr.getBaseAttr(); - const EnumAttr &enumAttr = cast(baseAttr); - std::vector cases = enumAttr.getAllCases(); + const EnumInfo enumInfo(&baseAttr.getDef()); + std::vector cases = enumInfo.getAllCases(); body << formatv(enumAttrBeginPrinterCode, (var->attr.isOptional() ? "*" : "") + op.getGetterName(var->name), - enumAttr.getSymbolToStringFnName()); + enumInfo.getSymbolToStringFnName()); // Get a string containing all of the cases that can't be represented with a // keyword. @@ -2283,7 +2284,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, // Otherwise if this is a bit enum attribute, don't allow cases that may // overlap with other cases. For simplicity sake, only allow cases with a // single bit value. - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { for (auto it : llvm::enumerate(cases)) { int64_t value = it.value().getValue(); if (value < 0 || !llvm::isPowerOf2_64(value)) @@ -2295,8 +2296,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, // case value to determine when to print in the string form. if (nonKeywordCases.any()) { body << " switch (caseValue) {\n"; - StringRef cppNamespace = enumAttr.getCppNamespace(); - StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); + StringRef enumName = enumInfo.getEnumClassName(); for (auto it : llvm::enumerate(cases)) { if (nonKeywordCases.test(it.index())) continue; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index c74cb9943671..c8a12b9e21b9 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1377,12 +1377,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, return handleConstantAttr(constAttr.getAttribute(), constAttr.getConstantValue()); } - if (leaf.isEnumAttrCase()) { - auto enumCase = leaf.getAsEnumAttrCase(); + if (leaf.isEnumCase()) { + auto enumCase = leaf.getAsEnumCase(); // This is an enum case backed by an IntegerAttr. We need to get its value // to build the constant. std::string val = std::to_string(enumCase.getValue()); - return handleConstantAttr(enumCase, val); + return handleConstantAttr(Attribute(&enumCase.getDef()), val); } LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); @@ -1782,7 +1782,7 @@ void PatternEmitter::supplyValuesForOpArgs( auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. auto patArgName = node.getArgName(argIndex); - if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { + if (leaf.isConstantAttr() || leaf.isEnumCase()) { // TODO: Refactor out into map to avoid recomputing these. if (!isa(opArg)) PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 75b8829be4da..7a6189c09f42 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" @@ -45,8 +46,8 @@ using llvm::SMLoc; using llvm::StringMap; using llvm::StringRef; using mlir::tblgen::Attribute; -using mlir::tblgen::EnumAttr; -using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::EnumCase; +using mlir::tblgen::EnumInfo; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::NamespaceEmitter; @@ -335,18 +336,18 @@ static mlir::GenRegistration static void emitAvailabilityQueryForIntEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::vector enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::vector enumerants = enumInfo.getAllCases(); // Mapping from availability class name to (enumerant, availability // specification) pairs. - llvm::StringMap, 1>> + llvm::StringMap, 1>> classCaseMap; // Place all availability specifications to their corresponding // availability classes. - for (const EnumAttrCase &enumerant : enumerants) + for (const EnumCase &enumerant : enumerants) for (const Availability &avail : getAvailabilities(enumerant.getDef())) classCaseMap[avail.getClass()].push_back({enumerant, avail}); @@ -359,14 +360,14 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef, os << " switch (value) {\n"; for (const auto &caseSpecPair : classCasePair.getValue()) { - EnumAttrCase enumerant = caseSpecPair.first; + EnumCase enumerant = caseSpecPair.first; Availability avail = caseSpecPair.second; os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, enumerant.getSymbol(), avail.getMergeInstancePreparation(), avail.getMergeInstanceType(), avail.getMergeInstance()); } // Only emit default if uncovered cases. - if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) + if (classCasePair.getValue().size() < enumInfo.getAllCases().size()) os << " default: break;\n"; os << " }\n" << " return std::nullopt;\n" @@ -376,19 +377,19 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef, static void emitAvailabilityQueryForBitEnum(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - std::vector enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + std::string underlyingType = std::string(enumInfo.getUnderlyingType()); + std::vector enumerants = enumInfo.getAllCases(); // Mapping from availability class name to (enumerant, availability // specification) pairs. - llvm::StringMap, 1>> + llvm::StringMap, 1>> classCaseMap; // Place all availability specifications to their corresponding // availability classes. - for (const EnumAttrCase &enumerant : enumerants) + for (const EnumCase &enumerant : enumerants) for (const Availability &avail : getAvailabilities(enumerant.getDef())) classCaseMap[avail.getClass()].push_back({enumerant, avail}); @@ -406,7 +407,7 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, os << " switch (value) {\n"; for (const auto &caseSpecPair : classCasePair.getValue()) { - EnumAttrCase enumerant = caseSpecPair.first; + EnumCase enumerant = caseSpecPair.first; Availability avail = caseSpecPair.second; os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, enumerant.getSymbol(), avail.getMergeInstancePreparation(), @@ -420,10 +421,10 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, } static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); - auto enumerants = enumAttr.getAllCases(); + EnumInfo enumInfo(enumDef); + StringRef enumName = enumInfo.getEnumClassName(); + StringRef cppNamespace = enumInfo.getCppNamespace(); + auto enumerants = enumInfo.getAllCases(); llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -435,7 +436,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { // Place all availability specifications to their corresponding // availability classes. - for (const EnumAttrCase &enumerant : enumerants) + for (const EnumCase &enumerant : enumerants) for (const Availability &avail : getAvailabilities(enumerant.getDef())) { StringRef className = avail.getClass(); if (handledClasses.count(className)) @@ -462,8 +463,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { } static void emitEnumDef(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef cppNamespace = enumAttr.getCppNamespace(); + EnumInfo enumInfo(enumDef); + StringRef cppNamespace = enumInfo.getCppNamespace(); llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -471,7 +472,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { for (auto ns : namespaces) os << "namespace " << ns << " {\n"; - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { emitAvailabilityQueryForBitEnum(enumDef, os); } else { emitAvailabilityQueryForIntEnum(enumDef, os); @@ -535,7 +536,7 @@ static void emitAttributeSerialization(const Attribute &attr, os << tabs << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { - EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + EnumInfo baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " "Builder({1}).getI32IntegerAttr(static_cast(" @@ -544,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr, baseEnum.getEnumClassName()); } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") || attr.isSubClassOf("SPIRV_I32EnumAttr")) { - EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + EnumInfo baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv(" {0}.push_back(static_cast(" "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n", @@ -831,7 +832,7 @@ static void emitAttributeDeserialization(const Attribute &attr, StringRef words, StringRef wordIndex, raw_ostream &os) { if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { - EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + EnumInfo baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>(" @@ -840,7 +841,7 @@ static void emitAttributeDeserialization(const Attribute &attr, baseEnum.getEnumClassName(), words, wordIndex); } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") || attr.isSubClassOf("SPIRV_I32EnumAttr")) { - EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); + EnumInfo baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "opBuilder.getAttr<{2}::{3}Attr>(" @@ -1246,9 +1247,9 @@ static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { "attributeName();\n"); } -static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, +static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo, raw_ostream &os) { - auto enumName = enumAttr.getEnumClassName(); + auto enumName = enumInfo.getEnumClassName(); os << formatv("template <> inline StringRef attributeName<{0}>() {{\n", enumName); os << " " @@ -1266,8 +1267,8 @@ static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) { os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n"; emitEnumGetAttrNameFnDecl(os); for (const auto *def : defs) { - EnumAttr enumAttr(*def); - emitEnumGetAttrNameFnDefn(enumAttr, os); + EnumInfo enumInfo(*def); + emitEnumGetAttrNameFnDefn(enumInfo, os); } os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n"; return false; @@ -1306,9 +1307,9 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") && !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr")) continue; - EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); + EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum")); - for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) + for (const EnumCase &enumerant : enumInfo.getAllCases()) for (const Availability &caseAvail : getAvailabilities(enumerant.getDef())) availClasses.try_emplace(caseAvail.getClass(), caseAvail); @@ -1348,14 +1349,14 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") && !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr")) continue; - EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); + EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum")); // (enumerant, availability specification) pairs for this availability // class. - SmallVector, 1> caseSpecs; + SmallVector, 1> caseSpecs; // Collect all cases' availability specs. - for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) + for (const EnumCase &enumerant : enumInfo.getAllCases()) for (const Availability &caseAvail : getAvailabilities(enumerant.getDef())) if (availClassName == caseAvail.getClass()) @@ -1366,19 +1367,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { if (caseSpecs.empty()) continue; - if (enumAttr.isBitEnum()) { + if (enumInfo.isBitEnum()) { // For BitEnumAttr, we need to iterate over each bit to query its // availability spec. os << formatv(" for (unsigned i = 0; " "i < std::numeric_limits<{0}>::digits; ++i) {{\n", - enumAttr.getUnderlyingType()); + enumInfo.getUnderlyingType()); os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " "static_cast<{0}::{1}>(1 << i);\n", - enumAttr.getCppNamespace(), enumAttr.getEnumClassName(), + enumInfo.getCppNamespace(), enumInfo.getEnumClassName(), srcOp.getGetterName(namedAttr.name)); os << formatv( " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", - enumAttr.getUnderlyingType()); + enumInfo.getUnderlyingType()); } else { // For IntEnumAttr, we just need to query the value as a whole. os << " {\n"; @@ -1386,7 +1387,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { srcOp.getGetterName(namedAttr.name)); } os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", - enumAttr.getCppNamespace(), avail.getQueryFnName()); + enumInfo.getCppNamespace(), avail.getQueryFnName()); os << " if (tblgen_instance) " // TODO` here once ODS supports // dialect-specific contents so that we can use not implementing the @@ -1434,14 +1435,14 @@ static bool emitCapabilityImplication(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, records); - EnumAttr enumAttr( + EnumInfo enumInfo( records.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum")); os << "ArrayRef " "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n" << " switch (cap) {\n" << " default: return {};\n"; - for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) { + for (const EnumCase &enumerant : enumInfo.getAllCases()) { const Record &def = enumerant.getDef(); if (!def.getValue("implies")) continue; @@ -1452,7 +1453,7 @@ static bool emitCapabilityImplication(const RecordKeeper &records, << ": {static const spirv::Capability implies[" << impliedCapsDefs.size() << "] = {"; llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) { - os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol(); + os << "spirv::Capability::" << EnumCase(capDef).getSymbol(); }); os << "}; return ArrayRef(implies, " << impliedCapsDefs.size() << "); }\n"; diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp index 491f9143edb0..ddc149810ebd 100644 --- a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp @@ -12,6 +12,7 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" @@ -42,8 +43,8 @@ using llvm::SMLoc; using llvm::StringMap; using llvm::StringRef; using mlir::tblgen::Attribute; -using mlir::tblgen::EnumAttr; -using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::EnumCase; +using mlir::tblgen::EnumInfo; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::NamespaceEmitter;