[mlir:Bytecode] Add initial support for dialect defined attribute/type encodings
Dialects can opt-in to providing custom encodings by implementing the `BytecodeDialectInterface`. This interface provides hooks, namely `readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used by the bytecode reader and writer. These hooks are provided a reader and writer implementation that can be used to encode various constructs in the underlying bytecode format. A unique feature of this interface is that dialects may choose to only encode a subset of their attributes and types in a custom bytecode format, which can simplify adding new or experimental components that aren't fully baked. Differential Revision: https://reviews.llvm.org/D132498
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "mlir/Bytecode/BytecodeReader.h"
|
||||
#include "../Encoding.h"
|
||||
#include "mlir/AsmParser/AsmParser.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
@@ -66,7 +67,7 @@ public:
|
||||
|
||||
/// Emit an error using the given arguments.
|
||||
template <typename... Args>
|
||||
LogicalResult emitError(Args &&...args) const {
|
||||
InFlightDiagnostic emitError(Args &&...args) const {
|
||||
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
@@ -326,6 +327,11 @@ struct BytecodeDialect {
|
||||
"-allow-unregistered-dialect with the MLIR tool used.");
|
||||
}
|
||||
dialect = loadedDialect;
|
||||
|
||||
// If the dialect was actually loaded, check to see if it has a bytecode
|
||||
// interface.
|
||||
if (loadedDialect)
|
||||
interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -333,6 +339,11 @@ struct BytecodeDialect {
|
||||
/// load, nullptr if we failed to load, otherwise the loaded dialect.
|
||||
Optional<Dialect *> dialect;
|
||||
|
||||
/// The bytecode interface of the dialect, or nullptr if the dialect does not
|
||||
/// implement the bytecode interface. This field should only be checked if the
|
||||
/// `dialect` field is non-None.
|
||||
const BytecodeDialectInterface *interface = nullptr;
|
||||
|
||||
/// The name of the dialect.
|
||||
StringRef name;
|
||||
};
|
||||
@@ -397,7 +408,8 @@ class AttrTypeReader {
|
||||
using TypeEntry = Entry<Type>;
|
||||
|
||||
public:
|
||||
AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
|
||||
AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
|
||||
: stringReader(stringReader), fileLoc(fileLoc) {}
|
||||
|
||||
/// Initialize the attribute and type information within the reader.
|
||||
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
@@ -456,6 +468,10 @@ private:
|
||||
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
|
||||
StringRef entryType);
|
||||
|
||||
/// The string section reader used to resolve string references when parsing
|
||||
/// custom encoded attribute/type entries.
|
||||
StringSectionReader &stringReader;
|
||||
|
||||
/// The set of attribute and type entries.
|
||||
SmallVector<AttrEntry> attributes;
|
||||
SmallVector<TypeEntry> types;
|
||||
@@ -463,6 +479,47 @@ private:
|
||||
/// A location used for error emission.
|
||||
Location fileLoc;
|
||||
};
|
||||
|
||||
class DialectReader : public DialectBytecodeReader {
|
||||
public:
|
||||
DialectReader(AttrTypeReader &attrTypeReader,
|
||||
StringSectionReader &stringReader, EncodingReader &reader)
|
||||
: attrTypeReader(attrTypeReader), stringReader(stringReader),
|
||||
reader(reader) {}
|
||||
|
||||
InFlightDiagnostic emitError(const Twine &msg) override {
|
||||
return reader.emitError(msg);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// IR
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult readAttribute(Attribute &result) override {
|
||||
return attrTypeReader.parseAttribute(reader, result);
|
||||
}
|
||||
|
||||
LogicalResult readType(Type &result) override {
|
||||
return attrTypeReader.parseType(reader, result);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Primitives
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult readVarInt(uint64_t &result) override {
|
||||
return reader.parseVarInt(result);
|
||||
}
|
||||
|
||||
LogicalResult readString(StringRef &result) override {
|
||||
return stringReader.parseString(reader, result);
|
||||
}
|
||||
|
||||
private:
|
||||
AttrTypeReader &attrTypeReader;
|
||||
StringSectionReader &stringReader;
|
||||
EncodingReader &reader;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
@@ -486,7 +543,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
size_t currentIndex = 0, endIndex = range.size();
|
||||
|
||||
// Parse an individual entry.
|
||||
auto parseEntryFn = [&](BytecodeDialect *dialect) {
|
||||
auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
|
||||
auto &entry = range[currentIndex++];
|
||||
|
||||
uint64_t entrySize;
|
||||
@@ -548,8 +605,7 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
|
||||
}
|
||||
|
||||
if (!reader.empty()) {
|
||||
(void)reader.emitError("unexpected trailing bytes after " + entryType +
|
||||
" entry");
|
||||
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
|
||||
return T();
|
||||
}
|
||||
return entry.entry;
|
||||
@@ -584,8 +640,22 @@ template <typename T>
|
||||
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
|
||||
EncodingReader &reader,
|
||||
StringRef entryType) {
|
||||
// FIXME: Add support for reading custom attribute/type encodings.
|
||||
return reader.emitError("unexpected Attribute encoding");
|
||||
if (failed(entry.dialect->load(reader, fileLoc.getContext())))
|
||||
return failure();
|
||||
|
||||
// Ensure that the dialect implements the bytecode interface.
|
||||
if (!entry.dialect->interface) {
|
||||
return reader.emitError("dialect '", entry.dialect->name,
|
||||
"' does not implement the bytecode interface");
|
||||
}
|
||||
|
||||
// Ask the dialect to parse the entry.
|
||||
DialectReader dialectReader(*this, stringReader, reader);
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
return success(!!entry.entry);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -597,7 +667,7 @@ namespace {
|
||||
class BytecodeReader {
|
||||
public:
|
||||
BytecodeReader(Location fileLoc, const ParserConfig &config)
|
||||
: config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
|
||||
: config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
|
||||
// Use the builtin unrealized conversion cast operation to represent
|
||||
// forward references to values that aren't yet defined.
|
||||
forwardRefOpState(UnknownLoc::get(config.getContext()),
|
||||
|
||||
Reference in New Issue
Block a user