Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode. Two callbacks are exposed, respectively, to the BytecodeWriterConfig and to the ParserConfig. At bytecode parsing/printing, clients have the ability to specify a callback to be used to optionally read/write the encoding. On failure, fallback path will execute the default parsers and printers for the dialect. Testing shows how to leverage this functionality to support back-deployment and backward-compatibility usecases when roundtripping to bytecode a client dialect with type/attributes dependencies on upstream. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D153383
This commit is contained in:
@@ -451,7 +451,7 @@ struct BytecodeDialect {
|
||||
/// Returns failure if the dialect couldn't be loaded *and* the provided
|
||||
/// context does not allow unregistered dialects. The provided reader is used
|
||||
/// for error emission if necessary.
|
||||
LogicalResult load(DialectReader &reader, MLIRContext *ctx);
|
||||
LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
|
||||
|
||||
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
|
||||
/// only be called after `load`.
|
||||
@@ -505,10 +505,11 @@ struct BytecodeOperationName {
|
||||
|
||||
/// Parse a single dialect group encoded in the byte stream.
|
||||
static LogicalResult parseDialectGrouping(
|
||||
EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
|
||||
EncodingReader &reader,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
|
||||
// Parse the dialect and the number of entries in the group.
|
||||
BytecodeDialect *dialect;
|
||||
std::unique_ptr<BytecodeDialect> *dialect;
|
||||
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
|
||||
return failure();
|
||||
uint64_t numEntries;
|
||||
@@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping(
|
||||
return failure();
|
||||
|
||||
for (uint64_t i = 0; i < numEntries; ++i)
|
||||
if (failed(entryCallback(dialect)))
|
||||
if (failed(entryCallback(dialect->get())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
@@ -532,7 +533,7 @@ public:
|
||||
/// Initialize the resource section reader with the given section data.
|
||||
LogicalResult
|
||||
initialize(Location fileLoc, const ParserConfig &config,
|
||||
MutableArrayRef<BytecodeDialect> dialects,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
|
||||
@@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
|
||||
|
||||
LogicalResult ResourceSectionReader::initialize(
|
||||
Location fileLoc, const ParserConfig &config,
|
||||
MutableArrayRef<BytecodeDialect> dialects,
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
|
||||
@@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize(
|
||||
// Read the dialect resources from the bytecode.
|
||||
MLIRContext *ctx = fileLoc->getContext();
|
||||
while (!offsetReader.empty()) {
|
||||
BytecodeDialect *dialect;
|
||||
std::unique_ptr<BytecodeDialect> *dialect;
|
||||
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
|
||||
failed(dialect->load(dialectReader, ctx)))
|
||||
failed((*dialect)->load(dialectReader, ctx)))
|
||||
return failure();
|
||||
Dialect *loadedDialect = dialect->getLoadedDialect();
|
||||
Dialect *loadedDialect = (*dialect)->getLoadedDialect();
|
||||
if (!loadedDialect) {
|
||||
return resourceReader.emitError()
|
||||
<< "dialect '" << dialect->name << "' is unknown";
|
||||
<< "dialect '" << (*dialect)->name << "' is unknown";
|
||||
}
|
||||
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
|
||||
if (!handler) {
|
||||
return resourceReader.emitError()
|
||||
<< "unexpected resources for dialect '" << dialect->name << "'";
|
||||
<< "unexpected resources for dialect '" << (*dialect)->name << "'";
|
||||
}
|
||||
|
||||
// Ensure that each resource is declared before being processed.
|
||||
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize(
|
||||
if (failed(handle)) {
|
||||
return resourceReader.emitError()
|
||||
<< "unknown 'resource' key '" << key << "' for dialect '"
|
||||
<< dialect->name << "'";
|
||||
<< (*dialect)->name << "'";
|
||||
}
|
||||
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
|
||||
dialectResources.push_back(*handle);
|
||||
@@ -796,15 +797,19 @@ class AttrTypeReader {
|
||||
|
||||
public:
|
||||
AttrTypeReader(StringSectionReader &stringReader,
|
||||
ResourceSectionReader &resourceReader, Location fileLoc,
|
||||
uint64_t &bytecodeVersion)
|
||||
ResourceSectionReader &resourceReader,
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
|
||||
uint64_t &bytecodeVersion, Location fileLoc,
|
||||
const ParserConfig &config)
|
||||
: stringReader(stringReader), resourceReader(resourceReader),
|
||||
fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
|
||||
dialectsMap(dialectsMap), fileLoc(fileLoc),
|
||||
bytecodeVersion(bytecodeVersion), parserConfig(config) {}
|
||||
|
||||
/// Initialize the attribute and type information within the reader.
|
||||
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData);
|
||||
LogicalResult
|
||||
initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData);
|
||||
|
||||
/// Resolve the attribute or type at the given index. Returns nullptr on
|
||||
/// failure.
|
||||
@@ -878,6 +883,10 @@ private:
|
||||
/// parsing custom encoded attribute/type entries.
|
||||
ResourceSectionReader &resourceReader;
|
||||
|
||||
/// The map of the loaded dialects used to retrieve dialect information, such
|
||||
/// as the dialect version.
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
|
||||
|
||||
/// The set of attribute and type entries.
|
||||
SmallVector<AttrEntry> attributes;
|
||||
SmallVector<TypeEntry> types;
|
||||
@@ -887,27 +896,48 @@ private:
|
||||
|
||||
/// Current bytecode version being used.
|
||||
uint64_t &bytecodeVersion;
|
||||
|
||||
/// Reference to the parser configuration.
|
||||
const ParserConfig &parserConfig;
|
||||
};
|
||||
|
||||
class DialectReader : public DialectBytecodeReader {
|
||||
public:
|
||||
DialectReader(AttrTypeReader &attrTypeReader,
|
||||
StringSectionReader &stringReader,
|
||||
ResourceSectionReader &resourceReader, EncodingReader &reader,
|
||||
uint64_t &bytecodeVersion)
|
||||
ResourceSectionReader &resourceReader,
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
|
||||
EncodingReader &reader, uint64_t &bytecodeVersion)
|
||||
: attrTypeReader(attrTypeReader), stringReader(stringReader),
|
||||
resourceReader(resourceReader), reader(reader),
|
||||
bytecodeVersion(bytecodeVersion) {}
|
||||
resourceReader(resourceReader), dialectsMap(dialectsMap),
|
||||
reader(reader), bytecodeVersion(bytecodeVersion) {}
|
||||
|
||||
InFlightDiagnostic emitError(const Twine &msg) override {
|
||||
InFlightDiagnostic emitError(const Twine &msg) const override {
|
||||
return reader.emitError(msg);
|
||||
}
|
||||
|
||||
FailureOr<const DialectVersion *>
|
||||
getDialectVersion(StringRef dialectName) const override {
|
||||
// First check if the dialect is available in the map.
|
||||
auto dialectEntry = dialectsMap.find(dialectName);
|
||||
if (dialectEntry == dialectsMap.end())
|
||||
return failure();
|
||||
// If the dialect was found, try to load it. This will trigger reading the
|
||||
// bytecode version from the version buffer if it wasn't already processed.
|
||||
// Return failure if either of those two actions could not be completed.
|
||||
if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
|
||||
dialectEntry->getValue()->loadedVersion.get() == nullptr)
|
||||
return failure();
|
||||
return dialectEntry->getValue()->loadedVersion.get();
|
||||
}
|
||||
|
||||
MLIRContext *getContext() const override { return getLoc().getContext(); }
|
||||
|
||||
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
|
||||
|
||||
DialectReader withEncodingReader(EncodingReader &encReader) {
|
||||
DialectReader withEncodingReader(EncodingReader &encReader) const {
|
||||
return DialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
encReader, bytecodeVersion);
|
||||
dialectsMap, encReader, bytecodeVersion);
|
||||
}
|
||||
|
||||
Location getLoc() const { return reader.getLoc(); }
|
||||
@@ -1010,6 +1040,7 @@ private:
|
||||
AttrTypeReader &attrTypeReader;
|
||||
StringSectionReader &stringReader;
|
||||
ResourceSectionReader &resourceReader;
|
||||
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
|
||||
EncodingReader &reader;
|
||||
uint64_t &bytecodeVersion;
|
||||
};
|
||||
@@ -1096,10 +1127,9 @@ private:
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
ArrayRef<uint8_t> sectionData,
|
||||
ArrayRef<uint8_t> offsetSectionData) {
|
||||
LogicalResult AttrTypeReader::initialize(
|
||||
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
|
||||
ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
|
||||
EncodingReader offsetReader(offsetSectionData, fileLoc);
|
||||
|
||||
// Parse the number of attribute and type entries.
|
||||
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
return offsetReader.emitError(
|
||||
"unexpected trailing data in the Attribute/Type offset section");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1216,32 +1247,54 @@ template <typename T>
|
||||
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
|
||||
EncodingReader &reader,
|
||||
StringRef entryType) {
|
||||
DialectReader dialectReader(*this, stringReader, resourceReader, reader,
|
||||
bytecodeVersion);
|
||||
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
|
||||
reader, bytecodeVersion);
|
||||
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
|
||||
return failure();
|
||||
|
||||
if constexpr (std::is_same_v<T, Type>) {
|
||||
// Try parsing with callbacks first if available.
|
||||
for (const auto &callback :
|
||||
parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
|
||||
if (failed(
|
||||
callback->read(dialectReader, entry.dialect->name, entry.entry)))
|
||||
return failure();
|
||||
// Early return if parsing was successful.
|
||||
if (!!entry.entry)
|
||||
return success();
|
||||
|
||||
// Reset the reader if we failed to parse, so we can fall through the
|
||||
// other parsing functions.
|
||||
reader = EncodingReader(entry.data, reader.getLoc());
|
||||
}
|
||||
} else {
|
||||
// Try parsing with callbacks first if available.
|
||||
for (const auto &callback :
|
||||
parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
|
||||
if (failed(
|
||||
callback->read(dialectReader, entry.dialect->name, entry.entry)))
|
||||
return failure();
|
||||
// Early return if parsing was successful.
|
||||
if (!!entry.entry)
|
||||
return success();
|
||||
|
||||
// Reset the reader if we failed to parse, so we can fall through the
|
||||
// other parsing functions.
|
||||
reader = EncodingReader(entry.data, reader.getLoc());
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the dialect implements the bytecode interface.
|
||||
if (!entry.dialect->interface) {
|
||||
return reader.emitError("dialect '", entry.dialect->name,
|
||||
"' does not implement the bytecode interface");
|
||||
}
|
||||
|
||||
// Ask the dialect to parse the entry. If the dialect is versioned, parse
|
||||
// using the versioned encoding readers.
|
||||
if (entry.dialect->loadedVersion.get()) {
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(
|
||||
dialectReader, *entry.dialect->loadedVersion);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(
|
||||
dialectReader, *entry.dialect->loadedVersion);
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
|
||||
} else {
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
}
|
||||
return success(!!entry.entry);
|
||||
}
|
||||
|
||||
@@ -1262,7 +1315,8 @@ public:
|
||||
llvm::MemoryBufferRef buffer,
|
||||
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
|
||||
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
|
||||
attrTypeReader(stringReader, resourceReader, fileLoc, version),
|
||||
attrTypeReader(stringReader, resourceReader, dialectsMap, version,
|
||||
fileLoc, config),
|
||||
// Use the builtin unrealized conversion cast operation to represent
|
||||
// forward references to values that aren't yet defined.
|
||||
forwardRefOpState(UnknownLoc::get(config.getContext()),
|
||||
@@ -1528,7 +1582,8 @@ private:
|
||||
StringRef producer;
|
||||
|
||||
/// The table of IR units referenced within the bytecode file.
|
||||
SmallVector<BytecodeDialect> dialects;
|
||||
SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
|
||||
llvm::StringMap<BytecodeDialect *> dialectsMap;
|
||||
SmallVector<BytecodeOperationName> opNames;
|
||||
|
||||
/// The reader used to process resources within the bytecode.
|
||||
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Section
|
||||
|
||||
LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
|
||||
LogicalResult BytecodeDialect::load(const DialectReader &reader,
|
||||
MLIRContext *ctx) {
|
||||
if (dialect)
|
||||
return success();
|
||||
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
|
||||
@@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
|
||||
|
||||
// Parse each of the dialects.
|
||||
for (uint64_t i = 0; i < numDialects; ++i) {
|
||||
dialects[i] = std::make_unique<BytecodeDialect>();
|
||||
/// Before version kDialectVersioning, there wasn't any versioning available
|
||||
/// for dialects, and the entryIdx represent the string itself.
|
||||
if (version < bytecode::kDialectVersioning) {
|
||||
if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
|
||||
if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse ID representing dialect and version.
|
||||
uint64_t dialectNameIdx;
|
||||
bool versionAvailable;
|
||||
@@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
|
||||
versionAvailable)))
|
||||
return failure();
|
||||
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
|
||||
dialects[i].name)))
|
||||
dialects[i]->name)))
|
||||
return failure();
|
||||
if (versionAvailable) {
|
||||
bytecode::Section::ID sectionID;
|
||||
if (failed(
|
||||
sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
|
||||
if (failed(sectionReader.parseSection(sectionID,
|
||||
dialects[i]->versionBuffer)))
|
||||
return failure();
|
||||
if (sectionID != bytecode::Section::kDialectVersions) {
|
||||
emitError(fileLoc, "expected dialect version section");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
dialectsMap[dialects[i]->name] = dialects[i].get();
|
||||
}
|
||||
|
||||
// Parse the operation names, which are grouped by dialect.
|
||||
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
|
||||
if (!opName->opName) {
|
||||
// Load the dialect and its version.
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
reader, version);
|
||||
dialectsMap, reader, version);
|
||||
if (failed(opName->dialect->load(dialectReader, getContext())))
|
||||
return failure();
|
||||
// If the opName is empty, this is because we use to accept names such as
|
||||
@@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
|
||||
|
||||
// Initialize the resource reader with the resource sections.
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
reader, version);
|
||||
dialectsMap, reader, version);
|
||||
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
|
||||
*resourceData, *resourceOffsetData,
|
||||
dialectReader, bufferOwnerRef);
|
||||
@@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
|
||||
"parsed use-list orders were invalid and could not be applied");
|
||||
|
||||
// Resolve dialect version.
|
||||
for (const BytecodeDialect &byteCodeDialect : dialects) {
|
||||
for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
|
||||
// Parsing is complete, give an opportunity to each dialect to visit the
|
||||
// IR and perform upgrades.
|
||||
if (!byteCodeDialect.loadedVersion)
|
||||
if (!byteCodeDialect->loadedVersion)
|
||||
continue;
|
||||
if (byteCodeDialect.interface &&
|
||||
failed(byteCodeDialect.interface->upgradeFromVersion(
|
||||
*moduleOp, *byteCodeDialect.loadedVersion)))
|
||||
if (byteCodeDialect->interface &&
|
||||
failed(byteCodeDialect->interface->upgradeFromVersion(
|
||||
*moduleOp, *byteCodeDialect->loadedVersion)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
|
||||
// interface and control the serialization.
|
||||
if (wasRegistered) {
|
||||
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
|
||||
reader, version);
|
||||
dialectsMap, reader, version);
|
||||
if (failed(
|
||||
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user