[mlir:Bytecode] Add initial support for dialect defined attribute/type encodings

Dialects can opt-in to providing custom encodings by implementing the
`BytecodeDialectInterface`. This interface provides hooks, namely
`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used
by the bytecode reader and writer. These hooks are provided a reader and writer
implementation that can be used to encode various constructs in the underlying
bytecode format. A unique feature of this interface is that dialects may choose
to only encode a subset of their attributes and types in a custom bytecode
format, which can simplify adding new or experimental components that aren't
fully baked.

Differential Revision: https://reviews.llvm.org/D132498
This commit is contained in:
River Riddle
2022-08-23 12:56:02 -07:00
parent b3449392f5
commit 02c2ecb9c6
14 changed files with 791 additions and 24 deletions

View File

@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeReader.h"
#include "../Encoding.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
@@ -66,7 +67,7 @@ public:
/// Emit an error using the given arguments.
template <typename... Args>
LogicalResult emitError(Args &&...args) const {
InFlightDiagnostic emitError(Args &&...args) const {
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
}
@@ -326,6 +327,11 @@ struct BytecodeDialect {
"-allow-unregistered-dialect with the MLIR tool used.");
}
dialect = loadedDialect;
// If the dialect was actually loaded, check to see if it has a bytecode
// interface.
if (loadedDialect)
interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
return success();
}
@@ -333,6 +339,11 @@ struct BytecodeDialect {
/// load, nullptr if we failed to load, otherwise the loaded dialect.
Optional<Dialect *> dialect;
/// The bytecode interface of the dialect, or nullptr if the dialect does not
/// implement the bytecode interface. This field should only be checked if the
/// `dialect` field is non-None.
const BytecodeDialectInterface *interface = nullptr;
/// The name of the dialect.
StringRef name;
};
@@ -397,7 +408,8 @@ class AttrTypeReader {
using TypeEntry = Entry<Type>;
public:
AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
: stringReader(stringReader), fileLoc(fileLoc) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -456,6 +468,10 @@ private:
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
StringRef entryType);
/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
StringSectionReader &stringReader;
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -463,6 +479,47 @@ private:
/// A location used for error emission.
Location fileLoc;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader, EncodingReader &reader)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
reader(reader) {}
InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
}
//===--------------------------------------------------------------------===//
// IR
//===--------------------------------------------------------------------===//
LogicalResult readAttribute(Attribute &result) override {
return attrTypeReader.parseAttribute(reader, result);
}
LogicalResult readType(Type &result) override {
return attrTypeReader.parseType(reader, result);
}
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
LogicalResult readVarInt(uint64_t &result) override {
return reader.parseVarInt(result);
}
LogicalResult readString(StringRef &result) override {
return stringReader.parseString(reader, result);
}
private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
EncodingReader &reader;
};
} // namespace
LogicalResult
@@ -486,7 +543,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
size_t currentIndex = 0, endIndex = range.size();
// Parse an individual entry.
auto parseEntryFn = [&](BytecodeDialect *dialect) {
auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
auto &entry = range[currentIndex++];
uint64_t entrySize;
@@ -548,8 +605,7 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
}
if (!reader.empty()) {
(void)reader.emitError("unexpected trailing bytes after " + entryType +
" entry");
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
return T();
}
return entry.entry;
@@ -584,8 +640,22 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
// FIXME: Add support for reading custom attribute/type encodings.
return reader.emitError("unexpected Attribute encoding");
if (failed(entry.dialect->load(reader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
// Ask the dialect to parse the entry.
DialectReader dialectReader(*this, stringReader, reader);
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
return success(!!entry.entry);
}
//===----------------------------------------------------------------------===//
@@ -597,7 +667,7 @@ namespace {
class BytecodeReader {
public:
BytecodeReader(Location fileLoc, const ParserConfig &config)
: config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
: config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),