diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index 3f7f747ac20d..ff6cec6d4116 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -39,8 +39,11 @@ class EnumCase { class IntEnumAttrCaseBase : EnumCase, SignlessIntegerAttrBase { - let predicate = - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>; + let predicate = CPred<[{ + ::llvm::cast<::mlir::IntegerAttr>($_self).getValue().eq(::llvm::APInt(}] + # intType.bitwidth # ", " + # intVal # + "))">; } // Cases of integer enums with a specific type. By default, the string diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index 4f280bde1aec..edb7357e4e04 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -454,6 +454,10 @@ func.func @allowed_cases_pass() { %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 // CHECK: test.i32_enum_attr %1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32 + // CHECK: test.i32_enum_attr + %2 = "test.i32_enum_attr"() {attr = 2147483648: i32} : () -> i32 + // CHECK: test.i32_enum_attr + %3 = "test.i32_enum_attr"() {attr = 4294967295: i32} : () -> i32 return } diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td index 5b785a600aad..10e424a0f252 100644 --- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td +++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td @@ -17,9 +17,13 @@ include "mlir/IR/EnumAttr.td" def I32Case5: I32EnumAttrCase<"case5", 5>; def I32Case10: I32EnumAttrCase<"case10", 10>; +def I32CaseSignedMaxPlusOne + : I32EnumAttrCase<"caseSignedMaxPlusOne", 2147483648>; +def I32CaseUnsignedMax : I32EnumAttrCase<"caseUnsignedMax", 4294967295>; -def SomeI32Enum: I32EnumAttr< - "SomeI32Enum", "", [I32Case5, I32Case10]>; +def SomeI32Enum : I32EnumAttr<"SomeI32Enum", "", + [I32Case5, I32Case10, I32CaseSignedMaxPlusOne, + I32CaseUnsignedMax]>; def I64Case5: I64EnumAttrCase<"case5", 5>; def I64Case10: I64EnumAttrCase<"case10", 10>; diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 9941a203bc5c..06dc588f9020 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -648,8 +648,10 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName); - os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", - enumName); + os << formatv( + " return " + "static_cast<{0}>(::mlir::IntegerAttr::getValue().getZExtValue());\n", + enumName); os << "}\n"; }