We currently only support lazy loading for regions that statically implement the IsolatedFromAbove trait, but that limits the amount of operations that can be lazily loaded. This review lifts that restriction by computing which operations have isolated regions when numbering, allowing any operation to be lazily loaded as long as it doesn't use values defined above. Differential Revision: https://reviews.llvm.org/D156199
283 lines
9.8 KiB
C++
283 lines
9.8 KiB
C++
//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file contains various utilities that number IR structures in preparation
|
|
// for bytecode emission.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
|
|
#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/StringMap.h"
|
|
#include <cstdint>
|
|
|
|
namespace mlir {
|
|
class BytecodeDialectInterface;
|
|
class BytecodeWriterConfig;
|
|
|
|
namespace bytecode {
|
|
namespace detail {
|
|
struct DialectNumbering;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Attribute and Type Numbering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class represents a numbering entry for an Attribute or Type.
|
|
struct AttrTypeNumbering {
|
|
AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
|
|
|
|
/// The concrete value.
|
|
PointerUnion<Attribute, Type> value;
|
|
|
|
/// The number assigned to this value.
|
|
unsigned number = 0;
|
|
|
|
/// The number of references to this value.
|
|
unsigned refCount = 1;
|
|
|
|
/// The dialect of this value.
|
|
DialectNumbering *dialect = nullptr;
|
|
};
|
|
struct AttributeNumbering : public AttrTypeNumbering {
|
|
AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
|
|
Attribute getValue() const { return value.get<Attribute>(); }
|
|
};
|
|
struct TypeNumbering : public AttrTypeNumbering {
|
|
TypeNumbering(Type value) : AttrTypeNumbering(value) {}
|
|
Type getValue() const { return value.get<Type>(); }
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpName Numbering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class represents the numbering entry of an operation name.
|
|
struct OpNameNumbering {
|
|
OpNameNumbering(DialectNumbering *dialect, OperationName name)
|
|
: dialect(dialect), name(name) {}
|
|
|
|
/// The dialect of this value.
|
|
DialectNumbering *dialect;
|
|
|
|
/// The concrete name.
|
|
OperationName name;
|
|
|
|
/// The number assigned to this name.
|
|
unsigned number = 0;
|
|
|
|
/// The number of references to this name.
|
|
unsigned refCount = 1;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Resource Numbering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class represents a numbering entry for a dialect resource.
|
|
struct DialectResourceNumbering {
|
|
DialectResourceNumbering(std::string key) : key(std::move(key)) {}
|
|
|
|
/// The key used to reference this resource.
|
|
std::string key;
|
|
|
|
/// The number assigned to this resource.
|
|
unsigned number = 0;
|
|
|
|
/// A flag indicating if this resource is only a declaration, not a full
|
|
/// definition.
|
|
bool isDeclaration = true;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dialect Numbering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class represents a numbering entry for an Dialect.
|
|
struct DialectNumbering {
|
|
DialectNumbering(StringRef name, unsigned number)
|
|
: name(name), number(number) {}
|
|
|
|
/// The namespace of the dialect.
|
|
StringRef name;
|
|
|
|
/// The number assigned to the dialect.
|
|
unsigned number;
|
|
|
|
/// The bytecode dialect interface of the dialect if defined.
|
|
const BytecodeDialectInterface *interface = nullptr;
|
|
|
|
/// The asm dialect interface of the dialect if defined.
|
|
const OpAsmDialectInterface *asmInterface = nullptr;
|
|
|
|
/// The referenced resources of this dialect.
|
|
SetVector<AsmDialectResourceHandle> resources;
|
|
|
|
/// A mapping from resource key to the corresponding resource numbering entry.
|
|
llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation Numbering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class represents the numbering entry of an operation.
|
|
struct OperationNumbering {
|
|
OperationNumbering(unsigned number) : number(number) {}
|
|
|
|
/// The number assigned to this operation.
|
|
unsigned number;
|
|
|
|
/// A flag indicating if this operation's regions are isolated. If unset, the
|
|
/// operation isn't yet known to be isolated.
|
|
std::optional<bool> isIsolatedFromAbove;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IRNumberingState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class manages numbering IR entities in preparation of bytecode
|
|
/// emission.
|
|
class IRNumberingState {
|
|
public:
|
|
IRNumberingState(Operation *op, const BytecodeWriterConfig &config);
|
|
|
|
/// Return the numbered dialects.
|
|
auto getDialects() {
|
|
return llvm::make_pointee_range(llvm::make_second_range(dialects));
|
|
}
|
|
auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); }
|
|
auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); }
|
|
auto getTypes() { return llvm::make_pointee_range(orderedTypes); }
|
|
|
|
/// Return the number for the given IR unit.
|
|
unsigned getNumber(Attribute attr) {
|
|
assert(attrs.count(attr) && "attribute not numbered");
|
|
return attrs[attr]->number;
|
|
}
|
|
unsigned getNumber(Block *block) {
|
|
assert(blockIDs.count(block) && "block not numbered");
|
|
return blockIDs[block];
|
|
}
|
|
unsigned getNumber(Operation *op) {
|
|
assert(operations.count(op) && "operation not numbered");
|
|
return operations[op]->number;
|
|
}
|
|
unsigned getNumber(OperationName opName) {
|
|
assert(opNames.count(opName) && "opName not numbered");
|
|
return opNames[opName]->number;
|
|
}
|
|
unsigned getNumber(Type type) {
|
|
assert(types.count(type) && "type not numbered");
|
|
return types[type]->number;
|
|
}
|
|
unsigned getNumber(Value value) {
|
|
assert(valueIDs.count(value) && "value not numbered");
|
|
return valueIDs[value];
|
|
}
|
|
unsigned getNumber(const AsmDialectResourceHandle &resource) {
|
|
assert(dialectResources.count(resource) && "resource not numbered");
|
|
return dialectResources[resource]->number;
|
|
}
|
|
|
|
/// Return the block and value counts of the given region.
|
|
std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
|
|
assert(regionBlockValueCounts.count(region) && "value not numbered");
|
|
return regionBlockValueCounts[region];
|
|
}
|
|
|
|
/// Return the number of operations in the given block.
|
|
unsigned getOperationCount(Block *block) {
|
|
assert(blockOperationCounts.count(block) && "block not numbered");
|
|
return blockOperationCounts[block];
|
|
}
|
|
|
|
/// Return if the given operation is isolated from above.
|
|
bool isIsolatedFromAbove(Operation *op) {
|
|
assert(operations.count(op) && "operation not numbered");
|
|
return operations[op]->isIsolatedFromAbove.value_or(false);
|
|
}
|
|
|
|
/// Get the set desired bytecode version to emit.
|
|
int64_t getDesiredBytecodeVersion() const;
|
|
|
|
private:
|
|
/// This class is used to provide a fake dialect writer for numbering nested
|
|
/// attributes and types.
|
|
struct NumberingDialectWriter;
|
|
|
|
/// Compute the global numbering state for the given root operation.
|
|
void computeGlobalNumberingState(Operation *rootOp);
|
|
|
|
/// Number the given IR unit for bytecode emission.
|
|
void number(Attribute attr);
|
|
void number(Block &block);
|
|
DialectNumbering &numberDialect(Dialect *dialect);
|
|
DialectNumbering &numberDialect(StringRef dialect);
|
|
void number(Operation &op);
|
|
void number(OperationName opName);
|
|
void number(Region ®ion);
|
|
void number(Type type);
|
|
|
|
/// Number the given dialect resources.
|
|
void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);
|
|
|
|
/// Finalize the numberings of any dialect resources.
|
|
void finalizeDialectResourceNumberings(Operation *rootOp);
|
|
|
|
/// Mapping from IR to the respective numbering entries.
|
|
DenseMap<Attribute, AttributeNumbering *> attrs;
|
|
DenseMap<Operation *, OperationNumbering *> operations;
|
|
DenseMap<OperationName, OpNameNumbering *> opNames;
|
|
DenseMap<Type, TypeNumbering *> types;
|
|
DenseMap<Dialect *, DialectNumbering *> registeredDialects;
|
|
llvm::MapVector<StringRef, DialectNumbering *> dialects;
|
|
std::vector<AttributeNumbering *> orderedAttrs;
|
|
std::vector<OpNameNumbering *> orderedOpNames;
|
|
std::vector<TypeNumbering *> orderedTypes;
|
|
|
|
/// A mapping from dialect resource handle to the numbering for the referenced
|
|
/// resource.
|
|
llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
|
|
dialectResources;
|
|
|
|
/// Allocators used for the various numbering entries.
|
|
llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
|
|
llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
|
|
llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
|
|
llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
|
|
llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
|
|
llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
|
|
|
|
/// The value ID for each Block and Value.
|
|
DenseMap<Block *, unsigned> blockIDs;
|
|
DenseMap<Value, unsigned> valueIDs;
|
|
|
|
/// The number of operations in each block.
|
|
DenseMap<Block *, unsigned> blockOperationCounts;
|
|
|
|
/// A map from region to the number of blocks and values within that region.
|
|
DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;
|
|
|
|
/// The next value ID to assign when numbering.
|
|
unsigned nextValueID = 0;
|
|
|
|
// Configuration: useful to query the required version to emit.
|
|
const BytecodeWriterConfig &config;
|
|
};
|
|
} // namespace detail
|
|
} // namespace bytecode
|
|
} // namespace mlir
|
|
|
|
#endif
|