[mlir][bytecode] Fix dialect version parsing.

We were querying the wrong EncReader along some paths that resulted in
failures depending on if one encountered an Attribute from an unloaded
dialect before encountering an operation from that dialect.

Also fix error where we were able to emit "custom" form for an attribute
without custom form in TestDialect.

Differential Revision: https://reviews.llvm.org/D150260
This commit is contained in:
Jacques Pienaar
2023-05-11 05:19:06 -07:00
parent c3a0df1903
commit 3449e7a832
3 changed files with 16 additions and 6 deletions

View File

@@ -271,6 +271,8 @@ public:
return parseBytes(static_cast<size_t>(length), sectionData);
}
Location getLoc() const { return fileLoc; }
private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
@@ -835,6 +837,13 @@ public:
return reader.emitError(msg);
}
DialectReader withEncodingReader(EncodingReader &encReader) {
return DialectReader(attrTypeReader, stringReader, resourceReader,
encReader);
}
Location getLoc() const { return reader.getLoc(); }
//===--------------------------------------------------------------------===//
// IR
//===--------------------------------------------------------------------===//
@@ -1054,7 +1063,6 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
DialectReader dialectReader(*this, stringReader, resourceReader, reader);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
@@ -1378,7 +1386,9 @@ LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
<< name
<< "' does not implement the bytecode interface, "
"but found a version entry";
loadedVersion = interface->readVersion(reader);
EncodingReader encReader(versionBuffer, reader.getLoc());
DialectReader versionReader = reader.withEncodingReader(encReader);
loadedVersion = interface->readVersion(versionReader);
if (!loadedVersion)
return failure();
}
@@ -1448,9 +1458,8 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
// Load the dialect and its version.
EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc);
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
versionReader);
reader);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),