[MLIR] print/parse resource handle key quoted and escaped (#119746)
Resource keys cannot be parsed from MLIR assembly if they contain special or
non-printable characters. However, nothing prevents you from specifying such
a key when you create, e.g., a DenseResourceElementsAttr, and such keys work
fine in other ways, including bytecode emission and parsing.
This PR fixes the parsing problem by quoting and escaping keys that contain
special or non-printable characters in MLIR assembly, in the same way as
symbols, e.g.:
```
module attributes {
fst = dense_resource<resource_fst> : tensor<2xf16>,
snd = dense_resource<"resource\09snd"> : tensor<2xf16>
} {}
{-#
dialect_resources: {
builtin: {
resource_fst: "0x0200000001000200",
"resource\09snd": "0x0200000008000900"
}
}
#-}
```
Because keys without special or non-printable characters are still printed
unquoted, the change is effectively backwards compatible.
The change is tested by:
1. Adding a test with a dense resource handle key with special
characters to `dense-resource-elements-attr.mlir`.
2. Adding special and unprintable characters to some resource keys in
the existing lit tests `pretty-resources-print.mlir` and
`mlir/test/Bytecode/resources.mlir`.
This commit is contained in:
@@ -202,7 +202,8 @@ public:
|
||||
/// special or non-printable characters in it.
|
||||
virtual void printSymbolName(StringRef symbolRef);
|
||||
|
||||
/// Print a handle to the given dialect resource.
|
||||
/// Print a handle to the given dialect resource. The handle key is quoted and
|
||||
/// escaped if it has any special or non-printable characters in it.
|
||||
virtual void printResourceHandle(const AsmDialectResourceHandle &resource);
|
||||
|
||||
/// Print an optional arrow followed by a type list.
|
||||
|
||||
@@ -248,13 +248,7 @@ public:
|
||||
|
||||
/// Parses a quoted string token if present.
|
||||
ParseResult parseOptionalString(std::string *string) override {
|
||||
if (!parser.getToken().is(Token::string))
|
||||
return failure();
|
||||
|
||||
if (string)
|
||||
*string = parser.getToken().getStringValue();
|
||||
parser.consumeToken();
|
||||
return success();
|
||||
return parser.parseOptionalString(string);
|
||||
}
|
||||
|
||||
/// Parses a Base64 encoded string of bytes.
|
||||
@@ -355,13 +349,7 @@ public:
|
||||
|
||||
/// Parse a keyword, if present, into 'keyword'.
|
||||
ParseResult parseOptionalKeyword(StringRef *keyword) override {
|
||||
// Check that the current token is a keyword.
|
||||
if (!parser.isCurrentTokenAKeyword())
|
||||
return failure();
|
||||
|
||||
*keyword = parser.getTokenSpelling();
|
||||
parser.consumeToken();
|
||||
return success();
|
||||
return parser.parseOptionalKeyword(keyword);
|
||||
}
|
||||
|
||||
/// Parse a keyword if it is one of the 'allowedKeywords'.
|
||||
@@ -387,13 +375,7 @@ public:
|
||||
|
||||
/// Parse an optional keyword or string and set instance into 'result'.
|
||||
ParseResult parseOptionalKeywordOrString(std::string *result) override {
|
||||
StringRef keyword;
|
||||
if (succeeded(parseOptionalKeyword(&keyword))) {
|
||||
*result = keyword.str();
|
||||
return success();
|
||||
}
|
||||
|
||||
return parseOptionalString(result);
|
||||
return parser.parseOptionalKeywordOrString(result);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
@@ -514,7 +496,7 @@ public:
|
||||
return parser.emitError() << "dialect '" << dialect->getNamespace()
|
||||
<< "' does not expect resource handles";
|
||||
}
|
||||
StringRef resourceName;
|
||||
std::string resourceName;
|
||||
return parser.parseResourceHandle(interface, resourceName);
|
||||
}
|
||||
|
||||
|
||||
@@ -271,6 +271,17 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
|
||||
return emitWrongTokenError(message);
|
||||
}
|
||||
|
||||
/// Parses a quoted string token if present.
|
||||
ParseResult Parser::parseOptionalString(std::string *string) {
|
||||
if (!getToken().is(Token::string))
|
||||
return failure();
|
||||
|
||||
if (string)
|
||||
*string = getToken().getStringValue();
|
||||
consumeToken();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse an optional integer value from the stream.
|
||||
OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
|
||||
// Parse `false` and `true` keywords as 0 and 1 respectively.
|
||||
@@ -412,15 +423,25 @@ ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult Parser::parseOptionalKeywordOrString(std::string *result) {
|
||||
StringRef keyword;
|
||||
if (succeeded(parseOptionalKeyword(&keyword))) {
|
||||
*result = keyword.str();
|
||||
return success();
|
||||
}
|
||||
|
||||
return parseOptionalString(result);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Resource Parsing
|
||||
|
||||
FailureOr<AsmDialectResourceHandle>
|
||||
Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
|
||||
StringRef &name) {
|
||||
std::string &name) {
|
||||
assert(dialect && "expected valid dialect interface");
|
||||
SMLoc nameLoc = getToken().getLoc();
|
||||
if (failed(parseOptionalKeyword(&name)))
|
||||
if (failed(parseOptionalKeywordOrString(&name)))
|
||||
return emitError("expected identifier key for 'resource' entry");
|
||||
auto &resources = getState().symbols.dialectResources;
|
||||
|
||||
@@ -451,7 +472,7 @@ Parser::parseResourceHandle(Dialect *dialect) {
|
||||
return emitError() << "dialect '" << dialect->getNamespace()
|
||||
<< "' does not expect resource handles";
|
||||
}
|
||||
StringRef resourceName;
|
||||
std::string resourceName;
|
||||
return parseResourceHandle(interface, resourceName);
|
||||
}
|
||||
|
||||
@@ -2530,8 +2551,8 @@ private:
|
||||
/// textual format.
|
||||
class ParsedResourceEntry : public AsmParsedResourceEntry {
|
||||
public:
|
||||
ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p)
|
||||
: key(key), keyLoc(keyLoc), value(value), p(p) {}
|
||||
ParsedResourceEntry(std::string key, SMLoc keyLoc, Token value, Parser &p)
|
||||
: key(std::move(key)), keyLoc(keyLoc), value(value), p(p) {}
|
||||
~ParsedResourceEntry() override = default;
|
||||
|
||||
StringRef getKey() const final { return key; }
|
||||
@@ -2607,7 +2628,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
StringRef key;
|
||||
std::string key;
|
||||
SMLoc keyLoc;
|
||||
Token value;
|
||||
Parser &p;
|
||||
@@ -2736,7 +2757,7 @@ ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() {
|
||||
return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
|
||||
// Parse the name of the resource entry.
|
||||
SMLoc keyLoc = getToken().getLoc();
|
||||
StringRef key;
|
||||
std::string key;
|
||||
if (failed(parseResourceHandle(handler, key)) ||
|
||||
parseToken(Token::colon, "expected ':'"))
|
||||
return failure();
|
||||
@@ -2763,8 +2784,8 @@ ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() {
|
||||
return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
|
||||
// Parse the name of the resource entry.
|
||||
SMLoc keyLoc = getToken().getLoc();
|
||||
StringRef key;
|
||||
if (failed(parseOptionalKeyword(&key)))
|
||||
std::string key;
|
||||
if (failed(parseOptionalKeywordOrString(&key)))
|
||||
return emitError(
|
||||
"expected identifier key for 'external_resources' entry");
|
||||
if (parseToken(Token::colon, "expected ':'"))
|
||||
|
||||
@@ -146,6 +146,9 @@ public:
|
||||
/// output a diagnostic and return failure.
|
||||
ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
|
||||
|
||||
/// Parses a quoted string token if present.
|
||||
ParseResult parseOptionalString(std::string *string);
|
||||
|
||||
/// Parse an optional integer value from the stream.
|
||||
OptionalParseResult parseOptionalInteger(APInt &result);
|
||||
|
||||
@@ -171,13 +174,16 @@ public:
|
||||
/// Parse a keyword, if present, into 'keyword'.
|
||||
ParseResult parseOptionalKeyword(StringRef *keyword);
|
||||
|
||||
/// Parse an optional keyword or string and set instance into 'result'.
|
||||
ParseResult parseOptionalKeywordOrString(std::string *result);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Resource Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parse a handle to a dialect resource within the assembly format.
|
||||
FailureOr<AsmDialectResourceHandle>
|
||||
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
|
||||
parseResourceHandle(const OpAsmDialectInterface *dialect, std::string &name);
|
||||
FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
@@ -2188,13 +2188,6 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
|
||||
os << ')';
|
||||
}
|
||||
|
||||
void AsmPrinter::Impl::printResourceHandle(
|
||||
const AsmDialectResourceHandle &resource) {
|
||||
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
|
||||
os << interface->getResourceKey(resource);
|
||||
state.getDialectResources()[resource.getDialect()].insert(resource);
|
||||
}
|
||||
|
||||
/// Returns true if the given dialect symbol data is simple enough to print in
|
||||
/// the pretty form. This is essentially when the symbol takes the form:
|
||||
/// identifier (`<` body `>`)?
|
||||
@@ -2279,6 +2272,13 @@ static void printElidedElementsAttr(raw_ostream &os) {
|
||||
os << R"(dense_resource<__elided__>)";
|
||||
}
|
||||
|
||||
void AsmPrinter::Impl::printResourceHandle(
|
||||
const AsmDialectResourceHandle &resource) {
|
||||
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
|
||||
::printKeywordOrString(interface->getResourceKey(resource), os);
|
||||
state.getDialectResources()[resource.getDialect()].insert(resource);
|
||||
}
|
||||
|
||||
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
|
||||
return state.getAliasState().getAlias(attr, os);
|
||||
}
|
||||
@@ -3373,41 +3373,41 @@ void OperationPrinter::printResourceFileMetadata(
|
||||
auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
|
||||
checkAddMetadataDict();
|
||||
|
||||
auto printFormatting = [&]() {
|
||||
// Emit the top-level resource entry if we haven't yet.
|
||||
if (!std::exchange(hadResource, true)) {
|
||||
if (needResourceComma)
|
||||
os << "," << newLine;
|
||||
os << " " << dictName << "_resources: {" << newLine;
|
||||
}
|
||||
// Emit the parent resource entry if we haven't yet.
|
||||
if (!std::exchange(hadEntry, true)) {
|
||||
if (needEntryComma)
|
||||
os << "," << newLine;
|
||||
os << " " << name << ": {" << newLine;
|
||||
} else {
|
||||
os << "," << newLine;
|
||||
}
|
||||
};
|
||||
|
||||
std::string resourceStr;
|
||||
auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; };
|
||||
std::optional<uint64_t> charLimit =
|
||||
printerFlags.getLargeResourceStringLimit();
|
||||
if (charLimit.has_value()) {
|
||||
std::string resourceStr;
|
||||
llvm::raw_string_ostream ss(resourceStr);
|
||||
valueFn(ss);
|
||||
|
||||
// Only print entry if it's string is small enough
|
||||
// Only print entry if its string is small enough.
|
||||
if (resourceStr.size() > charLimit.value())
|
||||
return;
|
||||
|
||||
printFormatting();
|
||||
os << " " << key << ": " << resourceStr;
|
||||
} else {
|
||||
printFormatting();
|
||||
os << " " << key << ": ";
|
||||
valueFn(os);
|
||||
// Don't recompute resourceStr when valueFn is called below.
|
||||
valueFn = printResourceStr;
|
||||
}
|
||||
|
||||
// Emit the top-level resource entry if we haven't yet.
|
||||
if (!std::exchange(hadResource, true)) {
|
||||
if (needResourceComma)
|
||||
os << "," << newLine;
|
||||
os << " " << dictName << "_resources: {" << newLine;
|
||||
}
|
||||
// Emit the parent resource entry if we haven't yet.
|
||||
if (!std::exchange(hadEntry, true)) {
|
||||
if (needEntryComma)
|
||||
os << "," << newLine;
|
||||
os << " " << name << ": {" << newLine;
|
||||
} else {
|
||||
os << "," << newLine;
|
||||
}
|
||||
os << " ";
|
||||
::printKeywordOrString(key, os);
|
||||
os << ": ";
|
||||
// Call printResourceStr or original valueFn, depending on charLimit.
|
||||
valueFn(os);
|
||||
};
|
||||
ResourceBuilder entryBuilder(printFn);
|
||||
provider.buildResources(op, providerArgs..., entryBuilder);
|
||||
|
||||
@@ -4,21 +4,21 @@
|
||||
module @TestDialectResources attributes {
|
||||
// CHECK: bytecode.test = dense_resource<decl_resource> : tensor<2xui32>
|
||||
// CHECK: bytecode.test2 = dense_resource<resource> : tensor<4xf64>
|
||||
// CHECK: bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
|
||||
// CHECK: bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
|
||||
bytecode.test = dense_resource<decl_resource> : tensor<2xui32>,
|
||||
bytecode.test2 = dense_resource<resource> : tensor<4xf64>,
|
||||
bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
|
||||
bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
|
||||
} {}
|
||||
|
||||
// CHECK: builtin: {
|
||||
// CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000"
|
||||
// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000"
|
||||
// CHECK-NEXT: "resource\09two": "0x08000000010000000000000002000000000000000300000000000000"
|
||||
|
||||
{-#
|
||||
dialect_resources: {
|
||||
builtin: {
|
||||
resource: "0x08000000010000000000000002000000000000000300000000000000",
|
||||
resource_2: "0x08000000010000000000000002000000000000000300000000000000"
|
||||
"resource\09two": "0x08000000010000000000000002000000000000000300000000000000"
|
||||
}
|
||||
}
|
||||
#-}
|
||||
|
||||
@@ -11,3 +11,18 @@
|
||||
}
|
||||
}
|
||||
#-}
|
||||
|
||||
// -----
|
||||
|
||||
// DenseResourceElementsHandle key blob\-"one" is quoted and escaped.
|
||||
// CHECK: attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>
|
||||
"test.user_op"() {attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>} : () -> ()
|
||||
|
||||
{-#
|
||||
dialect_resources: {
|
||||
builtin: {
|
||||
// CHECK: "blob\\-\22one\22": "0x0200000001000200"
|
||||
"blob\\-\22one\22": "0x0200000001000200"
|
||||
}
|
||||
}
|
||||
#-}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// CHECK: {-#
|
||||
// CHECK-NEXT: external_resources: {
|
||||
// CHECK-NEXT: external: {
|
||||
// CHECK-NEXT: bool: true,
|
||||
// CHECK-NEXT: "backslash\\tab\09": true,
|
||||
// CHECK-NEXT: string: "\22string\22"
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: other_stuff: {
|
||||
@@ -31,8 +31,8 @@
|
||||
external_resources: {
|
||||
external: {
|
||||
blob: "0x08000000010000000000000002000000000000000300000000000000",
|
||||
bool: true,
|
||||
string: "\"string\"" // with escape characters
|
||||
"backslash\\tab\09": true, // quoted key with escape characters
|
||||
string: "\"string\"" // string with escape characters
|
||||
},
|
||||
other_stuff: {
|
||||
bool: true
|
||||
|
||||
Reference in New Issue
Block a user