Add support for lazy-loading to the MLIR bytecode

Isolated-from-above regions are emitted in their own sections so that the
reader is able to skip over them. A new class is exposed to manage the state
and allow readers to load these regions on demand.

Differential Revision: https://reviews.llvm.org/D149515
This commit is contained in:
Mehdi Amini
2023-04-29 02:36:45 -07:00
parent e8cc0d310c
commit 3128b3105d
10 changed files with 465 additions and 59 deletions

View File

@@ -17,6 +17,9 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
@@ -24,6 +27,8 @@
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include <list>
#include <memory>
#include <optional>
#define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1092,25 +1097,93 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
// Bytecode Reader
//===----------------------------------------------------------------------===//
namespace {
/// This class is used to read a bytecode buffer and translate it into MLIR.
class BytecodeReader {
class mlir::BytecodeReader::Impl {
struct RegionReadState;
using LazyLoadableOpsInfo =
std::list<std::pair<Operation *, RegionReadState>>;
using LazyLoadableOpsMap =
DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
public:
BytecodeReader(Location fileLoc, const ParserConfig &config,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc),
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
"builtin.unrealized_conversion_cast", ValueRange(),
NoneType::get(config.getContext())),
bufferOwnerRef(bufferOwnerRef) {}
buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
/// Read the bytecode defined within `buffer` into the given block.
LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
LogicalResult read(Block *block,
llvm::function_ref<bool(Operation *)> lazyOps);
/// Number of lazy-loadable operations that are still tracked, i.e. whose
/// regions have not been materialized yet.
int64_t getNumOpsToMaterialize() const {
  return static_cast<int64_t>(lazyLoadableOpsMap.size());
}
/// Return true if `op` is currently tracked as a lazy-loadable operation
/// (it has an entry in `lazyLoadableOpsMap`).
bool isMaterializable(Operation *op) {
  return lazyLoadableOpsMap.count(op) != 0;
}
/// Materialize the regions of `op`, which must currently be tracked as
/// lazy-loadable. For the duration of the call `lazyOpsCallback` is installed
/// as the active callback, so it is invoked for each lazy-loadable operation
/// newly discovered while parsing; the stored callback is cleared again when
/// this function returns.
LogicalResult
materialize(Operation *op,
            llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  this->lazyOpsCallback = lazyOpsCallback;
  auto clearCallbackOnExit =
      llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });

  auto pendingIt = lazyLoadableOpsMap.find(op);
  assert(pendingIt != lazyLoadableOpsMap.end() &&
         "materialize called on non-materializable op");
  return materialize(pendingIt);
}
/// Materialize all operations that are still pending.
///
/// Operations are processed in the order in which they were recorded in
/// `lazyLoadableOps` (the same order `finalize` uses) rather than in the
/// iteration order of the pointer-keyed `lazyLoadableOpsMap`, so that the
/// materialization order does not depend on pointer values. Operations that
/// become lazy-loadable while materializing are appended to the list and are
/// picked up by later iterations.
///
/// Returns failure as soon as one materialization fails.
LogicalResult materializeAll() {
  while (!lazyLoadableOps.empty()) {
    Operation *nextOp = lazyLoadableOps.front().first;
    if (failed(materialize(lazyLoadableOpsMap.find(nextOp))))
      return failure();
  }
  return success();
}
/// Finalize the lazy-loading by calling back with every op that hasn't been
/// materialized to let the client decide if the op should be deleted or
/// materialized. The op is materialized if the callback returns true, deleted
/// otherwise.
///
/// Returns failure as soon as a requested materialization fails; operations
/// not yet visited at that point stay tracked as lazy-loadable.
LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
  while (!lazyLoadableOps.empty()) {
    Operation *op = lazyLoadableOps.front().first;
    if (shouldMaterialize(op)) {
      // `materialize` removes the op from both tracking containers.
      if (failed(materialize(lazyLoadableOpsMap.find(op))))
        return failure();
      continue;
    }
    // Remove the bookkeeping entries before destroying the operation so that
    // the pointer is never used (even as a lookup key) after the op is gone.
    lazyLoadableOpsMap.erase(op);
    lazyLoadableOps.pop_front();
    op->dropAllReferences();
    op->erase();
  }
  return success();
}
private:
/// Materialize the regions of the lazy-loadable operation referenced by `it`.
///
/// The saved `RegionReadState` is moved out of the tracking list and the op
/// is removed from both `lazyLoadableOps` and `lazyLoadableOpsMap` before
/// parsing starts, so the op is no longer reported as materializable even if
/// parsing fails.
///
/// `parseRegions` returns early each time it pushes a nested region onto the
/// stack, so it has to be re-invoked on the top of the stack until the stack
/// is drained. Any parse failure is propagated to the caller.
LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
  assert(it != lazyLoadableOpsMap.end() &&
         "materialize called on non-materializable op");
  // Fresh value scope for the regions about to be parsed; `parseRegions` pops
  // a scope when it finishes a read state marked isolated-from-above.
  valueScopes.emplace_back();
  std::vector<RegionReadState> regionStack;
  regionStack.push_back(std::move(it->getSecond()->second));
  lazyLoadableOps.erase(it->getSecond());
  lazyLoadableOpsMap.erase(it);

  // Iteratively parse regions until everything has been resolved.
  while (!regionStack.empty())
    if (failed(parseRegions(regionStack, regionStack.back())))
      return failure();
  return success();
}
/// Return the context for this config.
MLIRContext *getContext() const { return config.getContext(); }
@@ -1151,14 +1224,22 @@ private:
/// This struct represents the current read state of a range of regions. This
/// struct is used to enable iterative parsing of regions.
struct RegionReadState {
RegionReadState(Operation *op, bool isIsolatedFromAbove)
: RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
: curRegion(regions.begin()), endRegion(regions.end()),
RegionReadState(Operation *op, EncodingReader *reader,
bool isIsolatedFromAbove)
: RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
bool isIsolatedFromAbove)
: curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
isIsolatedFromAbove(isIsolatedFromAbove) {}
/// The current regions being read.
MutableArrayRef<Region>::iterator curRegion, endRegion;
/// This is the reader to use for this region, this pointer is pointing to
/// the parent region reader unless the current region is IsolatedFromAbove,
/// in which case the pointer is pointing to the `owningReader` which is a
/// section dedicated to the current region.
EncodingReader *reader;
std::unique_ptr<EncodingReader> owningReader;
/// The number of values defined immediately within this region.
unsigned numValues = 0;
@@ -1176,15 +1257,15 @@ private:
};
LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
LogicalResult parseRegions(EncodingReader &reader,
std::vector<RegionReadState> &regionStack,
LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove);
LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
RegionReadState &readState);
LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
//===--------------------------------------------------------------------===//
@@ -1234,6 +1315,16 @@ private:
/// A location to use when emitting errors.
Location fileLoc;
/// Flag that indicates if lazyloading is enabled.
bool lazyLoading;
/// Keep track of operations that have been lazy loaded (their regions haven't
/// been materialized), along with the `RegionReadState` that allows
/// lazy-loading the regions nested under the operation.
LazyLoadableOpsInfo lazyLoadableOps;
LazyLoadableOpsMap lazyLoadableOpsMap;
llvm::function_ref<bool(Operation *)> lazyOpsCallback;
/// The reader used to process attribute and types within the bytecode.
AttrTypeReader attrTypeReader;
@@ -1264,14 +1355,20 @@ private:
/// An operation state used when instantiating forward references.
OperationState forwardRefOpState;
/// Reference to the input buffer.
llvm::MemoryBufferRef buffer;
/// The optional owning source manager, which when present may be used to
/// extend the lifetime of the input buffer.
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
} // namespace
LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
LogicalResult BytecodeReader::Impl::read(
Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
EncodingReader reader(buffer.getBuffer(), fileLoc);
this->lazyOpsCallback = lazyOpsCallback;
auto resetlazyOpsCallback =
llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
// Skip over the bytecode header, this should have already been checked.
if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
@@ -1302,7 +1399,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
// Check for duplicate sections, we only expect one instance of each.
if (sectionDatas[sectionID]) {
return reader.emitError("duplicate top-level section: ",
toString(sectionID));
::toString(sectionID));
}
sectionDatas[sectionID] = sectionData;
}
@@ -1311,7 +1408,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
return reader.emitError("missing data for top-level section: ",
toString(sectionID));
::toString(sectionID));
}
}
@@ -1340,7 +1437,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}
LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
if (failed(reader.parseVarInt(version)))
return failure();
@@ -1357,6 +1454,9 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
" is newer than the current version ",
currentVersion);
}
// Override any request to lazy-load if the bytecode version is too old.
if (version < 2)
lazyLoading = false;
return success();
}
@@ -1396,7 +1496,7 @@ LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
}
LogicalResult
BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
EncodingReader sectionReader(sectionData, fileLoc);
// Parse the number of dialects in the section.
@@ -1449,7 +1549,8 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}
FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
FailureOr<OperationName>
BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
@@ -1471,7 +1572,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Resource Section
LogicalResult BytecodeReader::parseResourceSection(
LogicalResult BytecodeReader::Impl::parseResourceSection(
EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Ensure both sections are either present or not.
@@ -1499,8 +1600,9 @@ LogicalResult BytecodeReader::parseResourceSection(
//===----------------------------------------------------------------------===//
// IR Section
LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
Block *block) {
LogicalResult
BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
Block *block) {
EncodingReader reader(sectionData, fileLoc);
// A stack of operation regions currently being read from the bytecode.
@@ -1508,17 +1610,17 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
// Parse the top-level block using a temporary module operation.
OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
regionStack.back().curBlocks.push_back(moduleOp->getBody());
regionStack.back().curBlock = regionStack.back().curRegion->begin();
if (failed(parseBlock(reader, regionStack.back())))
if (failed(parseBlockHeader(reader, regionStack.back())))
return failure();
valueScopes.emplace_back();
valueScopes.back().push(regionStack.back());
// Iteratively parse regions until everything has been resolved.
while (!regionStack.empty())
if (failed(parseRegions(reader, regionStack, regionStack.back())))
if (failed(parseRegions(regionStack, regionStack.back())))
return failure();
if (!forwardRefOps.empty()) {
return reader.emitError(
@@ -1549,15 +1651,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
}
LogicalResult
BytecodeReader::parseRegions(EncodingReader &reader,
std::vector<RegionReadState> &regionStack,
RegionReadState &readState) {
// Read the regions of this operation.
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
RegionReadState &readState) {
// Process regions, blocks, and operations until the end or if a nested
// region is encountered. In this case we push a new state in regionStack and
// return, the processing of the current region will resume afterward.
for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
// If the current block hasn't been setup yet, parse the header for this
// region.
// region. The current block is already setup when this function was
// interrupted to recurse down in a nested region and we resume the current
// block after processing the nested region.
if (readState.curBlock == Region::iterator()) {
if (failed(parseRegion(reader, readState)))
if (failed(parseRegion(readState)))
return failure();
// If the region is empty, there is nothing more to do.
@@ -1566,6 +1671,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
}
// Parse the blocks within the region.
EncodingReader &reader = *readState.reader;
do {
while (readState.numOpsRemaining--) {
// Read in the next operation. We don't read its regions directly, we
@@ -1576,9 +1682,38 @@ BytecodeReader::parseRegions(EncodingReader &reader,
if (failed(op))
return failure();
// If the op has regions, add it to the stack for processing.
// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
// inner one is completed. Unless LazyLoading is activated in which case
// nested region parsing is delayed.
if ((*op)->getNumRegions()) {
regionStack.emplace_back(*op, isIsolatedFromAbove);
RegionReadState childState(*op, &reader, isIsolatedFromAbove);
// Isolated regions are encoded as a section in version 2 and above.
if (version >= 2 && isIsolatedFromAbove) {
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
if (failed(reader.parseSection(sectionID, sectionData)))
return failure();
if (sectionID != bytecode::Section::kIR)
return emitError(fileLoc, "expected IR section for region");
childState.owningReader =
std::make_unique<EncodingReader>(sectionData, fileLoc);
childState.reader = childState.owningReader.get();
}
if (lazyLoading) {
// If the user has a callback set, they have the opportunity
// to control lazyloading as we go.
if (!lazyOpsCallback || !lazyOpsCallback(*op)) {
lazyLoadableOps.push_back(
std::make_pair(*op, std::move(childState)));
lazyLoadableOpsMap.try_emplace(*op,
std::prev(lazyLoadableOps.end()));
continue;
}
}
regionStack.push_back(std::move(childState));
// If the op is isolated from above, push a new value scope.
if (isIsolatedFromAbove)
@@ -1590,7 +1725,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
// Move to the next block of the region.
if (++readState.curBlock == readState.curRegion->end())
break;
if (failed(parseBlock(reader, readState)))
if (failed(parseBlockHeader(reader, readState)))
return failure();
} while (true);
@@ -1601,16 +1736,19 @@ BytecodeReader::parseRegions(EncodingReader &reader,
// When the regions have been fully parsed, pop them off of the read stack. If
// the regions were isolated from above, we also pop the last value scope.
if (readState.isIsolatedFromAbove)
if (readState.isIsolatedFromAbove) {
assert(!valueScopes.empty() && "Expect a valueScope after reading region");
valueScopes.pop_back();
}
assert(!regionStack.empty() && "Expect a regionStack after reading region");
regionStack.pop_back();
return success();
}
FailureOr<Operation *>
BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
// Parse the name of the operation.
FailureOr<OperationName> opName = parseOpName(reader);
if (failed(opName))
@@ -1696,8 +1834,9 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
return op;
}
LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
RegionReadState &readState) {
LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
EncodingReader &reader = *readState.reader;
// Parse the number of blocks in the region.
uint64_t numBlocks;
if (failed(reader.parseVarInt(numBlocks)))
@@ -1727,11 +1866,12 @@ LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
// Parse the entry block of the region.
readState.curBlock = readState.curRegion->begin();
return parseBlock(reader, readState);
return parseBlockHeader(reader, readState);
}
LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
RegionReadState &readState) {
LogicalResult
BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
RegionReadState &readState) {
bool hasArgs;
if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
return failure();
@@ -1744,8 +1884,8 @@ LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
return success();
}
LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
Block *block) {
LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
Block *block) {
// Parse the value ID for the first argument, and the number of arguments.
uint64_t numArgs;
if (failed(reader.parseVarInt(numArgs)))
@@ -1773,7 +1913,7 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
//===----------------------------------------------------------------------===//
// Value Processing
Value BytecodeReader::parseOperand(EncodingReader &reader) {
Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
std::vector<Value> &values = valueScopes.back().values;
Value *value = nullptr;
if (failed(parseEntry(reader, values, value, "value")))
@@ -1785,8 +1925,8 @@ Value BytecodeReader::parseOperand(EncodingReader &reader) {
return *value;
}
LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
ValueRange newValues) {
LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
ValueRange newValues) {
ValueScope &valueScope = valueScopes.back();
std::vector<Value> &values = valueScope.values;
@@ -1821,7 +1961,7 @@ LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
return success();
}
Value BytecodeReader::createForwardRef() {
Value BytecodeReader::Impl::createForwardRef() {
// Check for an available existing operation to use. Otherwise, create a new
// fake operation to use for the reference.
if (!openForwardRefOps.empty()) {
@@ -1837,6 +1977,41 @@ Value BytecodeReader::createForwardRef() {
// Entry Points
//===----------------------------------------------------------------------===//
/// Destroying the reader while operations are still pending materialization
/// is treated as a programming error (checked in assert-enabled builds).
BytecodeReader::~BytecodeReader() {
  assert(getNumOpsToMaterialize() == 0 &&
         "BytecodeReader destroyed with operations left to materialize");
}
/// Build a reader over `buffer`. The location used for diagnostics is a
/// file location named after the buffer identifier, at line 0 / column 0.
///
/// NOTE(review): `Impl` stores `bufferOwnerRef` as a reference member, so the
/// shared_ptr passed here presumably has to outlive this reader; confirm
/// against callers.
BytecodeReader::BytecodeReader(
    llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
  MLIRContext *context = config.getContext();
  Location sourceFileLoc = FileLineColLoc::get(
      context, buffer.getBufferIdentifier(), /*line=*/0, /*column=*/0);
  impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
                                bufferOwnerRef);
}
/// Read the bytecode into `block` by forwarding to `Impl::read`.
/// `lazyOpsCallback` is passed through unchanged to the implementation.
LogicalResult BytecodeReader::readTopLevel(
Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
return impl->read(block, lazyOpsCallback);
}
/// Return the number of operations that haven't been materialized yet, as
/// reported by the implementation object.
int64_t BytecodeReader::getNumOpsToMaterialize() const {
return impl->getNumOpsToMaterialize();
}
/// Return true if the implementation still tracks `op` as lazy-loadable.
bool BytecodeReader::isMaterializable(Operation *op) {
return impl->isMaterializable(op);
}
/// Materialize `op` by forwarding to `Impl::materialize`; `lazyOpsCallback`
/// is passed through unchanged.
LogicalResult BytecodeReader::materialize(
Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
return impl->materialize(op, lazyOpsCallback);
}
/// Finalize lazy-loading by forwarding to `Impl::finalize`, which materializes
/// each remaining op for which `shouldMaterialize` returns true and deletes
/// the others.
LogicalResult
BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
return impl->finalize(shouldMaterialize);
}
/// Return true if the contents of `buffer` begin with the bytecode magic
/// bytes "ML\xefR".
bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
  llvm::StringRef contents = buffer.getBuffer();
  return contents.startswith("ML\xefR");
}
@@ -1856,8 +2031,9 @@ readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
"input buffer is not an MLIR bytecode file");
}
BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef);
return reader.read(buffer, block);
BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
buffer, bufferOwnerRef);
return reader.read(block, /*lazyOpsCallback=*/nullptr);
}
LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,