[MLIR] Add native Bytecode support for properties

This adds a new interface (`BytecodeOpInterface`) that lets operations opt in
to serializing their properties directly to native bytecode, skipping the
conversion to an attribute.

The scheme relies on a new section where properties are stored in sequence

  { size, serialize_properties }, ...

Operations store the index of their properties; a table of offsets is
built the first time the properties section is loaded.

This is a re-commit of 837d1ce0dc which conflicted with another patch upgrading
the bytecode and the collision wasn't properly resolved before.

Differential Revision: https://reviews.llvm.org/D151065
This commit is contained in:
Mehdi Amini
2023-05-25 21:04:35 -07:00
parent f354e971b0
commit 660f714e26
62 changed files with 792 additions and 48 deletions

View File

@@ -11,6 +11,7 @@
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -20,6 +21,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
@@ -28,6 +30,7 @@
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
#include <list>
#include <memory>
#include <numeric>
@@ -56,13 +59,15 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "ResourceOffset (6)";
case bytecode::Section::kDialectVersions:
return "DialectVersions (7)";
case bytecode::Section::kProperties:
return "Properties (8)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
}
/// Returns true if the given top-level section ID is optional.
static bool isSectionOptional(bytecode::Section::ID sectionID) {
static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
switch (sectionID) {
case bytecode::Section::kString:
case bytecode::Section::kDialect:
@@ -74,6 +79,8 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) {
case bytecode::Section::kResourceOffset:
case bytecode::Section::kDialectVersions:
return true;
case bytecode::Section::kProperties:
return version < 5;
default:
llvm_unreachable("unknown section ID");
}
@@ -362,6 +369,17 @@ public:
return parseEntry(reader, strings, result, "string");
}
/// Parse a shared string reference whose index is encoded together with a
/// flag. The index/flag pair is decoded from `reader`; the index is then
/// resolved against the string section into `result`, and the decoded flag is
/// returned through `flag`.
LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
                                  bool &flag) {
  uint64_t stringIdx = 0;
  // Only look the string up once the index/flag pair decoded cleanly.
  if (!failed(reader.parseVarIntWithFlag(stringIdx, flag)))
    return parseStringAtIndex(reader, stringIdx, result);
  return failure();
}
/// Parse a shared string from the string section. The shared string is
/// encoded using an index to a corresponding string in the string section.
LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
@@ -459,8 +477,9 @@ struct BytecodeDialect {
/// This struct represents an operation name entry within the bytecode.
struct BytecodeOperationName {
BytecodeOperationName(BytecodeDialect *dialect, StringRef name)
: dialect(dialect), name(name) {}
BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
std::optional<bool> wasRegistered)
: dialect(dialect), name(name), wasRegistered(wasRegistered) {}
/// The loaded operation name, or std::nullopt if it hasn't been processed
/// yet.
@@ -471,6 +490,10 @@ struct BytecodeOperationName {
/// The name of the operation, without the dialect prefix.
StringRef name;
/// Whether this operation was registered when the bytecode was produced.
/// This flag is populated when bytecode version >=5.
std::optional<bool> wasRegistered;
};
} // namespace
@@ -791,6 +814,18 @@ public:
result = resolveAttribute(attrIdx);
return success(!!result);
}
/// Parse an attribute index that is encoded together with a flag. When the
/// flag is unset the call succeeds without touching `result`. When it is set,
/// the index is resolved into `result` and the call fails if resolution
/// produced a null attribute.
LogicalResult parseOptionalAttribute(EncodingReader &reader,
                                     Attribute &result) {
  uint64_t encodedIdx;
  bool isPresent;
  if (failed(reader.parseVarIntWithFlag(encodedIdx, isPresent)))
    return failure();
  if (isPresent) {
    result = resolveAttribute(encodedIdx);
    if (!result)
      return failure();
  }
  return success();
}
LogicalResult parseType(EncodingReader &reader, Type &result) {
uint64_t typeIdx;
if (failed(reader.parseVarInt(typeIdx)))
@@ -870,7 +905,9 @@ public:
/// Forwards to `attrTypeReader.parseAttribute`, reading from this dialect
/// reader's encoding reader.
LogicalResult readAttribute(Attribute &result) override {
  LogicalResult status = attrTypeReader.parseAttribute(reader, result);
  return status;
}
/// Forwards to `attrTypeReader.parseOptionalAttribute`, reading from this
/// dialect reader's encoding reader.
LogicalResult readOptionalAttribute(Attribute &result) override {
  LogicalResult status = attrTypeReader.parseOptionalAttribute(reader, result);
  return status;
}
/// Forwards to `attrTypeReader.parseType`, reading from this dialect reader's
/// encoding reader.
LogicalResult readType(Type &result) override {
  LogicalResult status = attrTypeReader.parseType(reader, result);
  return status;
}
@@ -957,6 +994,87 @@ private:
ResourceSectionReader &resourceReader;
EncodingReader &reader;
};
/// Wraps the properties section and handles reading properties out of it.
///
/// The section payload is a var-int entry count followed by the entries, each
/// stored as `{ size, raw bytes }`. `initialize` walks the payload once and
/// records the offset at which every entry starts; `read` uses that table to
/// seek directly to the entry an operation refers to.
class PropertiesSectionReader {
public:
  /// Initialize the properties section reader with the given section data.
  ///
  /// An empty section is accepted and leaves the reader without any entry.
  /// Fails when the count or one of the entries cannot be parsed, or when
  /// bytes remain after the last announced entry.
  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
    if (sectionData.empty())
      return success();
    EncodingReader propReader(sectionData, fileLoc);
    // Use `uint64_t` like the other `parseVarInt` callers in this file rather
    // than `size_t`, whose width is platform dependent.
    uint64_t count;
    if (failed(propReader.parseVarInt(count)))
      return failure();
    // Parse the raw properties buffer.
    if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
      return failure();

    EncodingReader offsetsReader(propertiesBuffers, fileLoc);
    // `count` is read from the file and cannot be trusted. Each entry consumes
    // at least one byte of the buffer, so the buffer size is an upper bound on
    // the number of entries that can parse; never reserve more than that.
    const uint64_t maxEntries = propertiesBuffers.size();
    offsetTable.reserve(count < maxEntries ? count : maxEntries);
    for (uint64_t idx = 0; idx < count; ++idx) {
      // The entry starts where the offsets reader currently stands.
      offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
      // Skip over the entry: only its position is of interest here.
      ArrayRef<uint8_t> rawProperties;
      uint64_t dataSize;
      if (failed(offsetsReader.parseVarInt(dataSize)) ||
          failed(offsetsReader.parseBytes(dataSize, rawProperties)))
        return failure();
    }
    if (!offsetsReader.empty())
      return offsetsReader.emitError()
             << "Broken properties section: didn't exhaust the offsets table";
    return success();
  }

  /// Read the properties of an operation into `opState`.
  ///
  /// A properties index is read from `dialectReader` and mapped to an entry of
  /// the section through the offset table. If `opName` provides
  /// `BytecodeOpInterface`, the interface deserializes the entry; a registered
  /// operation without the interface is an error; otherwise the entry is read
  /// as an attribute into `opState.propertiesAttr`. `opName` is dereferenced
  /// unconditionally and must not be null.
  LogicalResult read(Location fileLoc, DialectReader &dialectReader,
                     OperationName *opName, OperationState &opState) const {
    uint64_t propertiesIdx;
    if (failed(dialectReader.readVarInt(propertiesIdx)))
      return failure();
    if (propertiesIdx >= offsetTable.size())
      return dialectReader.emitError("Properties idx out-of-bound for ")
             << opName->getStringRef();
    size_t propertiesOffset = offsetTable[propertiesIdx];
    // Validate the offset itself (not the index) against the buffer before
    // seeking to it.
    if (propertiesOffset >= propertiesBuffers.size())
      return dialectReader.emitError("Properties offset out-of-bound for ")
             << opName->getStringRef();

    // Acquire the sub-buffer that represents the requested properties.
    ArrayRef<char> rawProperties;
    {
      // "Seek" to the requested offset by getting a new reader with the right
      // sub-buffer.
      EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
                            fileLoc);
      // Properties are stored as a sequence of {size + raw_data}.
      if (failed(
              dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
        return failure();
    }
    // Setup a new reader to read from the `rawProperties` sub-buffer.
    EncodingReader reader(
        StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
    DialectReader propReader = dialectReader.withEncodingReader(reader);

    if (auto *iface = opName->getInterface<BytecodeOpInterface>())
      return iface->readProperties(propReader, opState);
    if (opName->isRegistered())
      return propReader.emitError(
                 "has properties but missing BytecodeOpInterface for ")
             << opName->getStringRef();
    // Unregistered ops store their properties as an attribute.
    return propReader.readAttribute(opState.propertiesAttr);
  }

private:
  /// The properties buffer referenced within the bytecode file.
  ArrayRef<uint8_t> propertiesBuffers;

  /// Table of offsets into the buffer above, one per properties entry.
  SmallVector<int64_t> offsetTable;
};
} // namespace
LogicalResult
@@ -1194,7 +1312,9 @@ private:
lazyLoadableOps.erase(it->getSecond());
lazyLoadableOpsMap.erase(it);
auto result = parseRegions(regionStack, regionStack.back());
assert(regionStack.empty());
assert((regionStack.empty() || failed(result)) &&
"broken invariant: regionStack should be empty when parseRegions "
"succeeds");
return result;
}
@@ -1209,8 +1329,11 @@ private:
LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
/// Parse an operation name reference using the given reader.
FailureOr<OperationName> parseOpName(EncodingReader &reader);
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
FailureOr<OperationName> parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1398,6 +1521,9 @@ private:
/// The table of strings referenced within the bytecode file.
StringSectionReader stringReader;
/// The table of properties referenced by the operation in the bytecode file.
PropertiesSectionReader propertiesReader;
/// The current set of available IR value scopes.
std::vector<ValueScope> valueScopes;
@@ -1466,7 +1592,7 @@ LogicalResult BytecodeReader::Impl::read(
// Check that all of the required sections were found.
for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
return reader.emitError("missing data for top-level section: ",
::toString(sectionID));
}
@@ -1477,6 +1603,12 @@ LogicalResult BytecodeReader::Impl::read(
fileLoc, *sectionDatas[bytecode::Section::kString])))
return failure();
// Process the properties section.
if (sectionDatas[bytecode::Section::kProperties] &&
failed(propertiesReader.initialize(
fileLoc, *sectionDatas[bytecode::Section::kProperties])))
return failure();
// Process the dialect section.
if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
return failure();
@@ -1598,9 +1730,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse the operation names, which are grouped by dialect.
auto parseOpName = [&](BytecodeDialect *dialect) {
StringRef opName;
if (failed(stringReader.parseString(sectionReader, opName)))
return failure();
opNames.emplace_back(dialect, opName);
std::optional<bool> wasRegistered;
// Prior to version 5, the information about whether an op was registered or
// not wasn't encoded.
if (version < 5) {
if (failed(stringReader.parseString(sectionReader, opName)))
return failure();
} else {
bool wasRegisteredFlag;
if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
wasRegisteredFlag)))
return failure();
wasRegistered = wasRegisteredFlag;
}
opNames.emplace_back(dialect, opName, wasRegistered);
return success();
};
// Avoid re-allocation in bytecode version > 3 where the number of ops are
@@ -1618,11 +1761,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
}
FailureOr<OperationName>
BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
@@ -1994,7 +2138,8 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
// Parse the name of the operation.
FailureOr<OperationName> opName = parseOpName(reader);
std::optional<bool> wasRegistered;
FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
if (failed(opName))
return failure();
@@ -2021,6 +2166,31 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
opState.attributes = dictAttr;
}
if (opMask & bytecode::OpEncodingMask::kHasProperties) {
// kHasProperties wasn't emitted in older bytecode, we should never get
// there without also having the `wasRegistered` flag available.
if (!wasRegistered)
return emitError(fileLoc,
"Unexpected missing `wasRegistered` opname flag at "
"bytecode version ")
<< version << " with properties.";
// When an operation is emitted without being registered, the properties are
// stored as an attribute. Otherwise the op must implement the bytecode
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
reader);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();
} else {
// If the operation wasn't registered when it was emitted, the properties
// was serialized as an attribute.
if (failed(parseAttribute(reader, opState.propertiesAttr)))
return failure();
}
}
/// Parse the results of the operation.
if (opMask & bytecode::OpEncodingMask::kHasResults) {
uint64_t numResults;