From 0f052a972ebe9bdf2c1eb56bddf6abb04eec50d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sat, 9 Sep 2023 00:06:38 +0200 Subject: [PATCH] [mlir] Make `StringRefParameter` roundtrippable (#65813) The current printer of `StringRefParameter` simply prints out the content of the string as is without escaping it any way. This leads to it generating invalid syntax, causing parser errors when read in again. This PR fixes that by adding `printString` to `AsmPrinter`, allowing one to print a string that can be parsed with `parseString`, using the same escaping syntax as `StringAttr`. --- mlir/include/mlir/IR/AttrTypeBase.td | 2 +- mlir/include/mlir/IR/OpImplementation.h | 4 ++++ mlir/lib/IR/AsmPrinter.cpp | 9 +++++++++ mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir | 4 +++- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td index 3e356373cbd7..42a611ee8e42 100644 --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -363,7 +363,7 @@ class DefaultValuedParameter : class StringRefParameter : AttrOrTypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let printer = [{$_printer << '"' << $_self << '"';}]; + let printer = [{$_printer.printString($_self);}]; let cppStorageType = "std::string"; let defaultValue = value; } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f894ee64a27b..8864ef02cd3c 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -184,6 +184,10 @@ public: /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); + /// Print the given string as a quoted string, escaping any special or + /// non-printable characters in it. + virtual void printString(StringRef string); + /// Print the given string as a symbol reference, i.e. a form representable by /// a SymbolRefAttr. A symbol reference is represented as a string prefixed /// with '@'. The reference is surrounded with ""'s and escaped if it has any diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c662edd59203..7b0da30541b1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -779,6 +779,7 @@ private: os << "%"; } void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} void printSymbolName(StringRef) override {} void printSuccessor(Block *) override {} @@ -919,6 +920,7 @@ private: /// determining potential aliases. void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printSymbolName(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} @@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) { ::printKeywordOrString(keyword, impl->getStream()); } +void AsmPrinter::printString(StringRef keyword) { + assert(impl && "expected AsmPrinter::printString to be overriden"); + *this << '"'; + printEscapedString(keyword, getStream()); + *this << '"'; +} + void AsmPrinter::printSymbolName(StringRef symbolRef) { assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); ::printSymbolReference(symbolRef, impl->getStream()); diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir index 12289b4d7325..160c388cedf7 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -70,6 +70,7 @@ attributes { // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string<"non default"> +// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F"> func.func private @test_roundtrip_default_parsers_struct( !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> @@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct( !test.custom_type_string<"bar" bar>, !test.optional_type_string, !test.optional_type_string<"default">, - !test.optional_type_string<"non default"> + !test.optional_type_string<"non default">, + !test.optional_type_string<"containing\n \"escape\" characters\0f"> )