diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 25d6f3da6a86..b24a2470470f 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -43,6 +43,26 @@ convertFromAttribute(int32_t &storage, Attribute attr, /// Convert the provided int32_t to an IntegerAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, int32_t storage); +/// Convert an IntegerAttr attribute to an int8_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(int8_t &storage, Attribute attr, + function_ref emitError); + +/// Convert the provided int8_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, int8_t storage); + +/// Convert an IntegerAttr attribute to an uint8_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(uint8_t &storage, Attribute attr, + function_ref emitError); + +/// Convert the provided uint8_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, uint8_t storage); + /// Extract the string from `attr` into `storage`. If `attr` is not a /// `StringAttr`, return failure and emit an error into the diagnostic from /// `emitError`. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 8b56d81c8eec..8710b970e8d7 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -135,6 +135,19 @@ public: /// hook on the AsmParser. virtual void printFloat(const APFloat &value); + /// Print the given integer value. This is useful to force a uint8_t/int8_t to + /// be printed as an integer instead of a char. + template + std::enable_if_t, void> printInteger(IntT value) { + // Handle int8_t/uint8_t specially to avoid printing as char + if constexpr (std::is_same_v || + std::is_same_v) { + getStream() << static_cast(value); + } else { + getStream() << value; + } + } + virtual void printType(Type type); virtual void printAttribute(Attribute attr); diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 25a45489c7b5..1aa19d0ecfa3 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -219,6 +219,7 @@ class IntProp : let optionalParser = [{ return $_parser.parseOptionalInteger($_storage); }]; + let printer = "$_printer.printInteger($_storage)"; let writeToMlirBytecode = [{ $_writer.writeVarInt($_storage); }]; diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index d56c75ede984..5b0a3e22139e 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -48,6 +48,40 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int32_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 32), storage); } +LogicalResult +mlir::convertFromAttribute(int8_t &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getSExtValue(); + return success(); +} + +Attribute mlir::convertToAttribute(MLIRContext *ctx, int8_t storage) { + /// Convert the provided int8_t to an IntegerAttr attribute. + return IntegerAttr::get(IntegerType::get(ctx, 8), storage); +} + +LogicalResult +mlir::convertFromAttribute(uint8_t &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getZExtValue(); + return success(); +} + +Attribute mlir::convertToAttribute(MLIRContext *ctx, uint8_t storage) { + /// Convert the provided uint8_t to an IntegerAttr attribute. + return IntegerAttr::get(IntegerType::get(ctx, 8), storage); +} + LogicalResult mlir::convertFromAttribute(std::string &storage, Attribute attr, function_ref emitError) { diff --git a/mlir/test/IR/properties.mlir b/mlir/test/IR/properties.mlir index b339a03812ba..dde9100cde14 100644 --- a/mlir/test/IR/properties.mlir +++ b/mlir/test/IR/properties.mlir @@ -59,9 +59,15 @@ test.with_default_valued_properties 1 "foo" 0 unit // CHECK: test.with_optional_properties // CHECK-SAME: simple = 0 // GENERIC: "test.with_optional_properties"() -// GENERIC-SAME: <{hasDefault = [], hasUnit = false, longSyntax = [], maybeUnit = [], nested = [], nonTrivialStorage = [], simple = [0]}> : () -> () +// GENERIC-SAME: <{hasDefault = [], hasUnit = false, longSyntax = [], maybeUnit = [], nested = [], nonTrivialStorage = [], simple = [0], simplei8 = [], simpleui8 = []}> : () -> () test.with_optional_properties simple = 0 +// CHECK: test.with_optional_properties +// CHECK-SAME: simple = 1 simplei8 = -1 simpleui8 = 255 +// GENERIC: "test.with_optional_properties"() +// GENERIC-SAME: <{hasDefault = [], hasUnit = false, longSyntax = [], maybeUnit = [], nested = [], nonTrivialStorage = [], simple = [1], simplei8 = [-1 : i8], simpleui8 = [-1 : i8]}> : () -> () +test.with_optional_properties simple = 1 simplei8 = -1 simpleui8 = 255 + // CHECK: test.with_optional_properties{{$}} // GENERIC: "test.with_optional_properties"() // GENERIC-SAME: simple = [] @@ -70,7 +76,7 @@ test.with_optional_properties // CHECK: test.with_optional_properties // CHECK-SAME: anAttr = 0 simple = 1 nonTrivialStorage = "foo" hasDefault = some<0> nested = some<1> longSyntax = some<"bar"> hasUnit maybeUnit = some // GENERIC: "test.with_optional_properties"() -// GENERIC-SAME: <{anAttr = 0 : i32, hasDefault = [0], hasUnit, longSyntax = ["bar"], maybeUnit = [unit], nested = {{\[}}[1]], nonTrivialStorage = ["foo"], simple = [1]}> : () -> () +// GENERIC-SAME: <{anAttr = 0 : i32, hasDefault = [0], hasUnit, longSyntax = ["bar"], maybeUnit = [unit], nested = {{\[}}[1]], nonTrivialStorage = ["foo"], simple = [1], simplei8 = [], simpleui8 = []}> : () -> () test.with_optional_properties anAttr = 0 simple = 1 diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 8a4981a90831..8c332adb3565 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3406,6 +3406,8 @@ def TestOpWithOptionalProperties : TEST_Op<"with_optional_properties"> { let assemblyFormat = [{ (`anAttr` `=` $anAttr^)? (`simple` `=` $simple^)? + (`simplei8` `=` $simplei8^)? + (`simpleui8` `=` $simpleui8^)? (`nonTrivialStorage` `=` $nonTrivialStorage^)? (`hasDefault` `=` $hasDefault^)? (`nested` `=` $nested^)? @@ -3417,6 +3419,8 @@ def TestOpWithOptionalProperties : TEST_Op<"with_optional_properties"> { let arguments = (ins OptionalAttr:$anAttr, OptionalProp:$simple, + OptionalProp>:$simplei8, + OptionalProp>:$simpleui8, OptionalProp:$nonTrivialStorage, // Confirm that properties with default values now default to nullopt and have // the long syntax.