[mlir:Bytecode] Add support for encoding resources

Resources are encoded in two separate sections similarly to
attributes/types, one for the actual data and one for the data
offsets. Unlike other sections, the resource sections are optional
given that in many cases they won't be present. For testing,
bytecode serialization is added for DenseResourceElementsAttr.

Differential Revision: https://reviews.llvm.org/D132729
This commit is contained in:
River Riddle
2022-09-06 20:47:57 -07:00
parent e166d2e00b
commit 6ab2bcffe4
14 changed files with 994 additions and 47 deletions

View File

@@ -20,6 +20,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
@@ -40,11 +41,32 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "AttrTypeOffset (3)";
case bytecode::Section::kIR:
return "IR (4)";
case bytecode::Section::kResource:
return "Resource (5)";
case bytecode::Section::kResourceOffset:
return "ResourceOffset (6)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
}
/// Returns true if the given top-level section ID is optional.
static bool isSectionOptional(bytecode::Section::ID sectionID) {
switch (sectionID) {
case bytecode::Section::kString:
case bytecode::Section::kDialect:
case bytecode::Section::kAttrType:
case bytecode::Section::kAttrTypeOffset:
case bytecode::Section::kIR:
return false;
case bytecode::Section::kResource:
case bytecode::Section::kResourceOffset:
return true;
default:
llvm_unreachable("unknown section ID");
}
}
//===----------------------------------------------------------------------===//
// EncodingReader
//===----------------------------------------------------------------------===//
@@ -65,11 +87,34 @@ public:
/// Returns the remaining size of the bytecode.
size_t size() const { return dataEnd - dataIt; }
/// Align the current reader position to the specified alignment.
LogicalResult alignTo(unsigned alignment) {
if (!llvm::isPowerOf2_32(alignment))
return emitError("expected alignment to be a power-of-two");
// Shift the reader position to the next alignment boundary.
while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
uint8_t padding;
if (failed(parseByte(padding)))
return failure();
if (padding != bytecode::kAlignmentByte) {
return emitError("expected alignment byte (0xCB), but got: '0x" +
llvm::utohexstr(padding) + "'");
}
}
// TODO: Check that the current data pointer is actually at the expected
// alignment.
return success();
}
/// Emit an error using the given arguments.
template <typename... Args>
InFlightDiagnostic emitError(Args &&...args) const {
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
}
InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
/// Parse a single byte from the stream.
template <typename T>
@@ -101,6 +146,17 @@ public:
return success();
}
/// Parse an aligned blob of data, where the alignment was encoded alongside
/// the data.
LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
uint64_t &alignment) {
uint64_t dataSize;
if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
failed(alignTo(alignment)))
return failure();
return parseBytes(dataSize, data);
}
/// Parse a variable length encoded integer from the byte stream. The first
/// encoded byte contains a prefix in the low bits indicating the encoded
/// length of the value. This length prefix is a bit sequence of '0's followed
@@ -177,13 +233,31 @@ public:
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID &sectionID,
ArrayRef<uint8_t> &sectionData) {
uint8_t sectionIDAndHasAlignment;
uint64_t length;
if (failed(parseByte(sectionID)) || failed(parseVarInt(length)))
if (failed(parseByte(sectionIDAndHasAlignment)) ||
failed(parseVarInt(length)))
return failure();
// Extract the section ID and whether the section is aligned. The high bit
// of the ID is the alignment flag.
sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
0b01111111);
bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
// Check that the section is actually valid before trying to process its
// data.
if (sectionID >= bytecode::Section::kNumSections)
return emitError("invalid section ID: ", unsigned(sectionID));
// Parse the actua section data now that we have its length.
// Process the section alignment if present.
if (hasAlignment) {
uint64_t alignment;
if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
return failure();
}
// Parse the actual section data.
return parseBytes(static_cast<size_t>(length), sectionData);
}
@@ -346,6 +420,14 @@ struct BytecodeDialect {
return success();
}
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
Dialect *getLoadedDialect() const {
assert(dialect &&
"expected `load` to be invoked before `getLoadedDialect`");
return *dialect;
}
/// The loaded dialect entry. This field is None if we haven't attempted to
/// load, nullptr if we failed to load, otherwise the loaded dialect.
Optional<Dialect *> dialect;
@@ -393,6 +475,225 @@ static LogicalResult parseDialectGrouping(
return success();
}
//===----------------------------------------------------------------------===//
// ResourceSectionReader
//===----------------------------------------------------------------------===//
namespace {
/// This class is used to read the resource section from the bytecode.
class ResourceSectionReader {
public:
/// Initialize the resource section reader with the given section data.
LogicalResult initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
/// Parse a dialect resource handle from the resource section.
LogicalResult parseResourceHandle(EncodingReader &reader,
AsmDialectResourceHandle &result) {
return parseEntry(reader, dialectResources, result, "resource handle");
}
private:
/// The table of dialect resources within the bytecode file.
SmallVector<AsmDialectResourceHandle> dialectResources;
};
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
EncodingReader &reader, StringSectionReader &stringReader)
: key(key), kind(kind), reader(reader), stringReader(stringReader) {}
~ParsedResourceEntry() override = default;
StringRef getKey() const final { return key; }
InFlightDiagnostic emitError() const final { return reader.emitError(); }
AsmResourceEntryKind getKind() const final { return kind; }
FailureOr<bool> parseAsBool() const final {
if (kind != AsmResourceEntryKind::Bool)
return emitError() << "expected a bool resource entry, but found a "
<< toString(kind) << " entry instead";
bool value;
if (failed(reader.parseByte(value)))
return failure();
return value;
}
FailureOr<std::string> parseAsString() const final {
if (kind != AsmResourceEntryKind::String)
return emitError() << "expected a string resource entry, but found a "
<< toString(kind) << " entry instead";
StringRef string;
if (failed(stringReader.parseString(reader, string)))
return failure();
return string.str();
}
FailureOr<AsmResourceBlob>
parseAsBlob(BlobAllocatorFn allocator) const final {
if (kind != AsmResourceEntryKind::Blob)
return emitError() << "expected a blob resource entry, but found a "
<< toString(kind) << " entry instead";
ArrayRef<uint8_t> data;
uint64_t alignment;
if (failed(reader.parseBlobAndAlignment(data, alignment)))
return failure();
// Allocate memory for the blob using the provided allocator and copy the
// data into it.
// FIXME: If the current holder of the bytecode can ensure its lifetime
// (e.g. when mmap'd), we should not copy the data. We should use the data
// from the bytecode directly.
AsmResourceBlob blob = allocator(data.size(), alignment);
assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
blob.isMutable() &&
"blob allocator did not return a properly aligned address");
memcpy(blob.getMutableData().data(), data.data(), data.size());
return blob;
}
private:
StringRef key;
AsmResourceEntryKind kind;
EncodingReader &reader;
StringSectionReader &stringReader;
};
} // namespace
template <typename T>
static LogicalResult
parseResourceGroup(Location fileLoc, bool allowEmpty,
EncodingReader &offsetReader, EncodingReader &resourceReader,
StringSectionReader &stringReader, T *handler,
function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
uint64_t numResources;
if (failed(offsetReader.parseVarInt(numResources)))
return failure();
for (uint64_t i = 0; i < numResources; ++i) {
StringRef key;
AsmResourceEntryKind kind;
uint64_t resourceOffset;
ArrayRef<uint8_t> data;
if (failed(stringReader.parseString(offsetReader, key)) ||
failed(offsetReader.parseVarInt(resourceOffset)) ||
failed(offsetReader.parseByte(kind)) ||
failed(resourceReader.parseBytes(resourceOffset, data)))
return failure();
// Process the resource key.
if ((processKeyFn && failed(processKeyFn(key))))
return failure();
// If the resource data is empty and we allow it, don't error out when
// parsing below, just skip it.
if (allowEmpty && data.empty())
continue;
// Ignore the entry if we don't have a valid handler.
if (!handler)
continue;
// Otherwise, parse the resource value.
EncodingReader entryReader(data, fileLoc);
ParsedResourceEntry entry(key, kind, entryReader, stringReader);
if (failed(handler->parseResource(entry)))
return failure();
if (!entryReader.empty()) {
return entryReader.emitError(
"unexpected trailing bytes in resource entry '", key, "'");
}
}
return success();
}
LogicalResult
ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Read the number of external resource providers.
uint64_t numExternalResourceGroups;
if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
return failure();
// Utility functor that dispatches to `parseResourceGroup`, but implicitly
// provides most of the arguments.
auto parseGroup = [&](auto *handler, bool allowEmpty = false,
function_ref<LogicalResult(StringRef)> keyFn = {}) {
return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
stringReader, handler, keyFn);
};
// Read the external resources from the bytecode.
for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
StringRef key;
if (failed(stringReader.parseString(offsetReader, key)))
return failure();
// Get the handler for these resources.
// TODO: Should we require handling external resources in some scenarios?
AsmResourceParser *handler = config.getResourceParser(key);
if (!handler) {
emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
<< "'";
}
if (failed(parseGroup(handler)))
return failure();
}
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
failed(dialect->load(resourceReader, ctx)))
return failure();
Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
<< "dialect '" << dialect->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
<< "unexpected resources for dialect '" << dialect->name << "'";
}
// Ensure that each resource is declared before being processed.
auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
FailureOr<AsmDialectResourceHandle> handle =
handler->declareResource(key);
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
<< dialect->name << "'";
}
dialectResources.push_back(*handle);
return success();
};
// Parse the resources for this dialect. We allow empty resources because we
// just treat these as declarations.
if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// Attribute/Type Reader
//===----------------------------------------------------------------------===//
@@ -419,8 +720,10 @@ class AttrTypeReader {
using TypeEntry = Entry<Type>;
public:
AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
: stringReader(stringReader), fileLoc(fileLoc) {}
AttrTypeReader(StringSectionReader &stringReader,
ResourceSectionReader &resourceReader, Location fileLoc)
: stringReader(stringReader), resourceReader(resourceReader),
fileLoc(fileLoc) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -483,6 +786,10 @@ private:
/// custom encoded attribute/type entries.
StringSectionReader &stringReader;
/// The resource section reader used to resolve resource references when
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -494,9 +801,10 @@ private:
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader, EncodingReader &reader)
StringSectionReader &stringReader,
ResourceSectionReader &resourceReader, EncodingReader &reader)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
reader(reader) {}
resourceReader(resourceReader), reader(reader) {}
InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
@@ -514,6 +822,13 @@ public:
return attrTypeReader.parseType(reader, result);
}
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
AsmDialectResourceHandle handle;
if (failed(resourceReader.parseResourceHandle(reader, handle)))
return failure();
return handle;
}
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -575,6 +890,7 @@ public:
private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
EncodingReader &reader;
};
} // namespace
@@ -707,7 +1023,7 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
}
// Ask the dialect to parse the entry.
DialectReader dialectReader(*this, stringReader, reader);
DialectReader dialectReader(*this, stringReader, resourceReader, reader);
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
@@ -724,7 +1040,8 @@ namespace {
class BytecodeReader {
public:
BytecodeReader(Location fileLoc, const ParserConfig &config)
: config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
: config(config), fileLoc(fileLoc),
attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -761,6 +1078,13 @@ private:
return attrTypeReader.parseType(reader, result);
}
//===--------------------------------------------------------------------===//
// Resource Section
LogicalResult
parseResourceSection(Optional<ArrayRef<uint8_t>> resourceData,
Optional<ArrayRef<uint8_t>> resourceOffsetData);
//===--------------------------------------------------------------------===//
// IR Section
@@ -863,6 +1187,9 @@ private:
SmallVector<BytecodeDialect> dialects;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
ResourceSectionReader resourceReader;
/// The table of strings referenced within the bytecode file.
StringSectionReader stringReader;
@@ -914,11 +1241,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
}
sectionDatas[sectionID] = sectionData;
}
// Check that all of the sections were found.
// Check that all of the required sections were found.
for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
if (!sectionDatas[i]) {
bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
return reader.emitError("missing data for top-level section: ",
toString(bytecode::Section::ID(i)));
toString(sectionID));
}
}
@@ -931,6 +1259,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
return failure();
// Process the resource section if present.
if (failed(parseResourceSection(
sectionDatas[bytecode::Section::kResource],
sectionDatas[bytecode::Section::kResourceOffset])))
return failure();
// Process the attribute and type section.
if (failed(attrTypeReader.initialize(
dialects, *sectionDatas[bytecode::Section::kAttrType],
@@ -1008,6 +1342,31 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
return *opName->opName;
}
//===----------------------------------------------------------------------===//
// Resource Section
LogicalResult BytecodeReader::parseResourceSection(
Optional<ArrayRef<uint8_t>> resourceData,
Optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Ensure both sections are either present or not.
if (resourceData.has_value() != resourceOffsetData.has_value()) {
if (resourceOffsetData)
return emitError(fileLoc, "unexpected resource offset section when "
"resource section is not present");
return emitError(
fileLoc,
"expected resource offset section when resource section is present");
}
// If the resource sections are absent, there is nothing to do.
if (!resourceData)
return success();
// Initialize the resource reader with the resource sections.
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData);
}
//===----------------------------------------------------------------------===//
// IR Section