Expose callbacks for encoding of types/attributes

[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode.

Two callbacks are exposed: one on the BytecodeWriterConfig and one on the ParserConfig. When printing or parsing bytecode, clients can specify a callback that optionally writes or reads the encoding. If the callback fails, the fallback path executes the dialect's default printers and parsers.

The tests show how to leverage this functionality to support back-deployment and backward-compatibility use cases when round-tripping through bytecode a client dialect whose types/attributes depend on upstream dialects.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D153383
This commit is contained in:
Mehdi Amini
2023-07-28 10:43:51 -07:00
parent bb65caf90a
commit b299ec1666
20 changed files with 954 additions and 156 deletions

View File

@@ -451,7 +451,7 @@ struct BytecodeDialect {
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
LogicalResult load(DialectReader &reader, MLIRContext *ctx);
LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -505,10 +505,11 @@ struct BytecodeOperationName {
/// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping(
EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
EncodingReader &reader,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
// Parse the dialect and the number of entries in the group.
BytecodeDialect *dialect;
std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
return failure();
uint64_t numEntries;
@@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping(
return failure();
for (uint64_t i = 0; i < numEntries; ++i)
if (failed(entryCallback(dialect)))
if (failed(entryCallback(dialect->get())))
return failure();
return success();
}
@@ -532,7 +533,7 @@ public:
/// Initialize the resource section reader with the given section data.
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize(
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
BytecodeDialect *dialect;
std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
failed(dialect->load(dialectReader, ctx)))
failed((*dialect)->load(dialectReader, ctx)))
return failure();
Dialect *loadedDialect = dialect->getLoadedDialect();
Dialect *loadedDialect = (*dialect)->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
<< "dialect '" << dialect->name << "' is unknown";
<< "dialect '" << (*dialect)->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
<< "unexpected resources for dialect '" << dialect->name << "'";
<< "unexpected resources for dialect '" << (*dialect)->name << "'";
}
// Ensure that each resource is declared before being processed.
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize(
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
<< dialect->name << "'";
<< (*dialect)->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
@@ -796,15 +797,19 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
ResourceSectionReader &resourceReader, Location fileLoc,
uint64_t &bytecodeVersion)
ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
uint64_t &bytecodeVersion, Location fileLoc,
const ParserConfig &config)
: stringReader(stringReader), resourceReader(resourceReader),
fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
dialectsMap(dialectsMap), fileLoc(fileLoc),
bytecodeVersion(bytecodeVersion), parserConfig(config) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
LogicalResult
initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
@@ -878,6 +883,10 @@ private:
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
/// The map of the loaded dialects used to retrieve dialect information, such
/// as the dialect version.
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -887,27 +896,48 @@ private:
/// Current bytecode version being used.
uint64_t &bytecodeVersion;
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
ResourceSectionReader &resourceReader, EncodingReader &reader,
uint64_t &bytecodeVersion)
ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
EncodingReader &reader, uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), reader(reader),
bytecodeVersion(bytecodeVersion) {}
resourceReader(resourceReader), dialectsMap(dialectsMap),
reader(reader), bytecodeVersion(bytecodeVersion) {}
InFlightDiagnostic emitError(const Twine &msg) override {
InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
}
/// Return the version that was loaded for the dialect named `dialectName`.
/// Fails when the name has no entry in the dialects map, when loading the
/// dialect fails, or when no version is available once loading completes.
FailureOr<const DialectVersion *>
getDialectVersion(StringRef dialectName) const override {
  // Without an entry in the map there is nothing to query.
  auto it = dialectsMap.find(dialectName);
  if (it == dialectsMap.end())
    return failure();
  BytecodeDialect *bytecodeDialect = it->getValue();

  // Loading the dialect triggers reading the bytecode version from the
  // version buffer if that has not been processed yet.
  if (failed(bytecodeDialect->load(*this, getLoc().getContext())))
    return failure();

  // The load succeeded, but the dialect may still carry no version.
  if (!bytecodeDialect->loadedVersion)
    return failure();
  return bytecodeDialect->loadedVersion.get();
}
MLIRContext *getContext() const override { return getLoc().getContext(); }
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
DialectReader withEncodingReader(EncodingReader &encReader) {
DialectReader withEncodingReader(EncodingReader &encReader) const {
return DialectReader(attrTypeReader, stringReader, resourceReader,
encReader, bytecodeVersion);
dialectsMap, encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1010,6 +1040,7 @@ private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
};
@@ -1096,10 +1127,9 @@ private:
};
} // namespace
LogicalResult
AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData) {
LogicalResult AttrTypeReader::initialize(
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Parse the number of attribute and type entries.
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
return offsetReader.emitError(
"unexpected trailing data in the Attribute/Type offset section");
}
return success();
}
@@ -1216,32 +1247,54 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
DialectReader dialectReader(*this, stringReader, resourceReader, reader,
bytecodeVersion);
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
reader, bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
if constexpr (std::is_same_v<T, Type>) {
// Try parsing with callbacks first if available.
for (const auto &callback :
parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
if (failed(
callback->read(dialectReader, entry.dialect->name, entry.entry)))
return failure();
// Early return if parsing was successful.
if (!!entry.entry)
return success();
// Reset the reader if we failed to parse, so we can fall through the
// other parsing functions.
reader = EncodingReader(entry.data, reader.getLoc());
}
} else {
// Try parsing with callbacks first if available.
for (const auto &callback :
parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
if (failed(
callback->read(dialectReader, entry.dialect->name, entry.entry)))
return failure();
// Early return if parsing was successful.
if (!!entry.entry)
return success();
// Reset the reader if we failed to parse, so we can fall through the
// other parsing functions.
reader = EncodingReader(entry.data, reader.getLoc());
}
}
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
// Ask the dialect to parse the entry. If the dialect is versioned, parse
// using the versioned encoding readers.
if (entry.dialect->loadedVersion.get()) {
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(
dialectReader, *entry.dialect->loadedVersion);
else
entry.entry = entry.dialect->interface->readAttribute(
dialectReader, *entry.dialect->loadedVersion);
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
} else {
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
}
return success(!!entry.entry);
}
@@ -1262,7 +1315,8 @@ public:
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
attrTypeReader(stringReader, resourceReader, fileLoc, version),
attrTypeReader(stringReader, resourceReader, dialectsMap, version,
fileLoc, config),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1528,7 +1582,8 @@ private:
StringRef producer;
/// The table of IR units referenced within the bytecode file.
SmallVector<BytecodeDialect> dialects;
SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
llvm::StringMap<BytecodeDialect *> dialectsMap;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
LogicalResult BytecodeDialect::load(const DialectReader &reader,
MLIRContext *ctx) {
if (dialect)
return success();
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
dialects[i] = std::make_unique<BytecodeDialect>();
/// Before version kDialectVersioning, there wasn't any versioning available
/// for dialects, and the entryIdx represents the string itself.
if (version < bytecode::kDialectVersioning) {
if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
return failure();
continue;
}
// Parse ID representing dialect and version.
uint64_t dialectNameIdx;
bool versionAvailable;
@@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
versionAvailable)))
return failure();
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
dialects[i].name)))
dialects[i]->name)))
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
if (failed(
sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
if (failed(sectionReader.parseSection(sectionID,
dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
emitError(fileLoc, "expected dialect version section");
return failure();
}
}
dialectsMap[dialects[i]->name] = dialects[i].get();
}
// Parse the operation names, which are grouped by dialect.
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
reader, version);
dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we used to accept names such as
@@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
reader, version);
dialectsMap, reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"parsed use-list orders were invalid and could not be applied");
// Resolve dialect version.
for (const BytecodeDialect &byteCodeDialect : dialects) {
for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
// IR and perform upgrades.
if (!byteCodeDialect.loadedVersion)
if (!byteCodeDialect->loadedVersion)
continue;
if (byteCodeDialect.interface &&
failed(byteCodeDialect.interface->upgradeFromVersion(
*moduleOp, *byteCodeDialect.loadedVersion)))
if (byteCodeDialect->interface &&
failed(byteCodeDialect->interface->upgradeFromVersion(
*moduleOp, *byteCodeDialect->loadedVersion)))
return failure();
}
@@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
reader, version);
dialectsMap, reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();