404 lines
15 KiB
C++
404 lines
15 KiB
C++
//===- Var.h ----------------------------------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
|
|
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
|
|
|
|
#include "TemplateExtras.h"
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "llvm/ADT/EnumeratedArray.h"
|
|
#include "llvm/ADT/STLForwardCompat.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/StringMap.h"
|
|
|
|
namespace mlir {
|
|
namespace sparse_tensor {
|
|
namespace ir_detail {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// The three kinds of variables that `Var` can be.
|
|
///
|
|
/// NOTE: The numerical values used to represent this enum should be
|
|
/// treated as an implementation detail, not as part of the API. In the
|
|
/// API below we use the canonical ordering `{Symbol,Dimension,Level}` even
|
|
/// though that does not agree with the numerical ordering of the numerical
|
|
/// representation.
|
|
enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
|
|
|
|
[[nodiscard]] constexpr bool isWF(VarKind vk) {
|
|
const auto vk_ = llvm::to_underlying(vk);
|
|
return 0 <= vk_ && vk_ <= 2;
|
|
}
|
|
|
|
/// Gets the ASCII character used as the prefix when printing `Var`.
|
|
constexpr char toChar(VarKind vk) {
|
|
// If `isWF(vk)` then this computation's intermediate results are always
|
|
// in the range [-44..126] (where that lower bound is under worst-case
|
|
// rearranging of the expression); and `int_fast8_t` is the fastest type
|
|
// which can support that range without over-/underflow.
|
|
const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(vk));
|
|
return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
|
|
}
|
|
static_assert(toChar(VarKind::Symbol) == 's' &&
|
|
toChar(VarKind::Dimension) == 'd' &&
|
|
toChar(VarKind::Level) == 'l');
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// The type of arrays indexed by `VarKind`.
|
|
template <typename T>
|
|
using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// A concrete variable, to be used in our variant of `AffineExpr`.
|
|
/// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
|
|
/// support for subclasses with a fixed `VarKind`.
|
|
class Var {
|
|
public:
|
|
/// Typedef for the type of variable numbers.
|
|
using Num = unsigned;
|
|
|
|
private:
|
|
/// Typedef for the underlying storage of `Var::Impl`.
|
|
using Storage = unsigned;
|
|
|
|
/// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
|
|
/// Two low-order bits are reserved for storing the `VarKind`,
|
|
/// and one high-order bit is reserved for future use (e.g., to support
|
|
/// `DenseMapInfo<Var>` while maintaining the usual numeric values for
|
|
/// "empty" and "tombstone").
|
|
static constexpr Num kMaxNum =
|
|
static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);
|
|
|
|
public:
|
|
/// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
|
|
//
|
|
// This must be public for `VarInfo` to use it (whereas we don't want
|
|
// to expose the `impl` field via friendship).
|
|
[[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
|
|
|
|
protected:
|
|
/// The underlying implementation of `Var`. Note that this must be kept
|
|
/// distinct from `Var` itself, since we want to ensure that the RTTI
|
|
/// methods will select the `U(Var::Impl)` ctor rather than selecting
|
|
/// the `U(Var::Num)` ctor.
|
|
class Impl final {
|
|
Storage data;
|
|
|
|
public:
|
|
constexpr Impl(VarKind vk, Num n)
|
|
: data((static_cast<Storage>(n) << 2) |
|
|
static_cast<Storage>(llvm::to_underlying(vk))) {
|
|
assert(isWF(vk) && "unknown VarKind");
|
|
assert(isWF_Num(n) && "Var::Num is too large");
|
|
}
|
|
constexpr bool operator==(Impl other) const { return data == other.data; }
|
|
constexpr bool operator!=(Impl other) const { return !(*this == other); }
|
|
constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
|
|
constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
|
|
};
|
|
static_assert(IsZeroCostAbstraction<Impl>);
|
|
|
|
private:
|
|
Impl impl;
|
|
|
|
protected:
|
|
/// Protected ctor for the RTTI methods to use.
|
|
constexpr explicit Var(Impl impl) : impl(impl) {}
|
|
|
|
public:
|
|
constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
|
|
Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
|
|
Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
|
|
assert(vk != VarKind::Symbol);
|
|
}
|
|
|
|
constexpr bool operator==(Var other) const { return impl == other.impl; }
|
|
constexpr bool operator!=(Var other) const { return !(*this == other); }
|
|
|
|
constexpr VarKind getKind() const { return impl.getKind(); }
|
|
constexpr Num getNum() const { return impl.getNum(); }
|
|
|
|
template <typename U>
|
|
constexpr bool isa() const;
|
|
template <typename U>
|
|
constexpr U cast() const;
|
|
template <typename U>
|
|
constexpr std::optional<U> dyn_cast() const;
|
|
|
|
std::string str() const;
|
|
void print(llvm::raw_ostream &os) const;
|
|
void print(AsmPrinter &printer) const;
|
|
void dump() const;
|
|
};
|
|
static_assert(IsZeroCostAbstraction<Var>);
|
|
|
|
class SymVar final : public Var {
|
|
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
|
|
public:
|
|
static constexpr VarKind Kind = VarKind::Symbol;
|
|
static constexpr bool classof(Var const *var) {
|
|
return var->getKind() == Kind;
|
|
}
|
|
constexpr SymVar(Num sym) : Var(Kind, sym) {}
|
|
SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
|
|
};
|
|
static_assert(IsZeroCostAbstraction<SymVar>);
|
|
|
|
class DimVar final : public Var {
|
|
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
|
|
public:
|
|
static constexpr VarKind Kind = VarKind::Dimension;
|
|
static constexpr bool classof(Var const *var) {
|
|
return var->getKind() == Kind;
|
|
}
|
|
constexpr DimVar(Num dim) : Var(Kind, dim) {}
|
|
DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
|
|
};
|
|
static_assert(IsZeroCostAbstraction<DimVar>);
|
|
|
|
class LvlVar final : public Var {
|
|
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
|
|
public:
|
|
static constexpr VarKind Kind = VarKind::Level;
|
|
static constexpr bool classof(Var const *var) {
|
|
return var->getKind() == Kind;
|
|
}
|
|
constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
|
|
LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
|
|
};
|
|
static_assert(IsZeroCostAbstraction<LvlVar>);
|
|
|
|
template <typename U>
|
|
constexpr bool Var::isa() const {
|
|
if constexpr (std::is_same_v<U, SymVar>)
|
|
return getKind() == VarKind::Symbol;
|
|
if constexpr (std::is_same_v<U, DimVar>)
|
|
return getKind() == VarKind::Dimension;
|
|
if constexpr (std::is_same_v<U, LvlVar>)
|
|
return getKind() == VarKind::Level;
|
|
}
|
|
|
|
template <typename U>
|
|
constexpr U Var::cast() const {
|
|
assert(isa<U>());
|
|
// NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
|
|
return U(impl);
|
|
}
|
|
|
|
template <typename U>
|
|
constexpr std::optional<U> Var::dyn_cast() const {
|
|
// NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
|
|
return isa<U>() ? std::make_optional(U(impl)) : std::nullopt;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
|
|
class DimLvlExpr;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
class Ranks final {
|
|
// Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
|
|
unsigned impl[3];
|
|
|
|
static constexpr unsigned to_index(VarKind vk) {
|
|
assert(isWF(vk) && "unknown VarKind");
|
|
return static_cast<unsigned>(llvm::to_underlying(vk));
|
|
}
|
|
|
|
public:
|
|
constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
|
|
: impl() {
|
|
impl[to_index(VarKind::Symbol)] = symRank;
|
|
impl[to_index(VarKind::Dimension)] = dimRank;
|
|
impl[to_index(VarKind::Level)] = lvlRank;
|
|
}
|
|
Ranks(VarKindArray<unsigned> const &ranks)
|
|
: Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
|
|
ranks[VarKind::Level]) {}
|
|
|
|
bool operator==(Ranks const &other) const;
|
|
bool operator!=(Ranks const &other) const { return !(*this == other); }
|
|
|
|
constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
|
|
constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
|
|
constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
|
|
constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }
|
|
|
|
[[nodiscard]] constexpr bool isValid(Var var) const {
|
|
return var.getNum() < getRank(var.getKind());
|
|
}
|
|
[[nodiscard]] bool isValid(DimLvlExpr expr) const;
|
|
};
|
|
static_assert(IsZeroCostAbstraction<Ranks>);
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Efficient representation of a set of `Var`.
|
|
class VarSet final {
|
|
VarKindArray<llvm::SmallBitVector> impl;
|
|
|
|
public:
|
|
explicit VarSet(Ranks const &ranks);
|
|
|
|
unsigned getRank(VarKind vk) const { return impl[vk].size(); }
|
|
unsigned getSymRank() const { return getRank(VarKind::Symbol); }
|
|
unsigned getDimRank() const { return getRank(VarKind::Dimension); }
|
|
unsigned getLvlRank() const { return getRank(VarKind::Level); }
|
|
Ranks getRanks() const {
|
|
return Ranks(getSymRank(), getDimRank(), getLvlRank());
|
|
}
|
|
/// For the `contains` method: if variables occurring in
|
|
/// the method parameter are OOB for the `VarSet`, then these methods will
|
|
/// always return false.
|
|
bool contains(Var var) const;
|
|
|
|
/// For the `add` methods: OOB parameters cause undefined behavior.
|
|
/// Currently the `add` methods will raise an assertion error.
|
|
void add(Var var);
|
|
void add(VarSet const &vars);
|
|
void add(DimLvlExpr expr);
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// A record of metadata for/about a variable, used by `VarEnv`.
|
|
/// The principal goal of this record is to enable `VarEnv` to be used for
|
|
/// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
|
|
/// remain unknown, since each record is instead identified by `VarInfo::ID`.
|
|
/// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
|
|
/// order it likes, irrespective of the binding order (`Var::Num`) of the
|
|
/// associated variable.
|
|
class VarInfo final {
|
|
public:
|
|
/// Newtype for unique identifiers of `VarInfo` records, to ensure
|
|
/// they aren't confused with `Var::Num`.
|
|
enum class ID : unsigned {};
|
|
|
|
private:
|
|
StringRef name; // The bare-id used in the MLIR source.
|
|
llvm::SMLoc loc; // The location of the first occurence.
|
|
ID id; // The unique `VarInfo`-identifier.
|
|
std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
|
|
VarKind kind; // The kind of variable.
|
|
|
|
public:
|
|
constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
|
|
std::optional<Var::Num> n = {})
|
|
: name(name), loc(loc), id(id), num(n), kind(vk) {
|
|
assert(!name.empty() && "null StringRef");
|
|
assert(loc.isValid() && "null SMLoc");
|
|
assert(isWF(vk) && "unknown VarKind");
|
|
assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
|
|
}
|
|
|
|
constexpr StringRef getName() const { return name; }
|
|
constexpr llvm::SMLoc getLoc() const { return loc; }
|
|
Location getLocation(AsmParser &parser) const {
|
|
return parser.getEncodedSourceLoc(loc);
|
|
}
|
|
constexpr ID getID() const { return id; }
|
|
constexpr VarKind getKind() const { return kind; }
|
|
constexpr std::optional<Var::Num> getNum() const { return num; }
|
|
constexpr bool hasNum() const { return num.has_value(); }
|
|
void setNum(Var::Num n);
|
|
constexpr Var getVar() const {
|
|
assert(hasNum());
|
|
return Var(kind, *num);
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
enum class Policy { MustNot, May, Must };
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
class VarEnv final {
|
|
/// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
|
|
VarKindArray<Var::Num> nextNum;
|
|
/// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
|
|
SmallVector<VarInfo> vars;
|
|
/// Map from variable names to their `VarInfo::ID`.
|
|
llvm::StringMap<VarInfo::ID> ids;
|
|
|
|
VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
|
|
|
|
public:
|
|
VarEnv() : nextNum(0) {}
|
|
|
|
/// Gets the underlying storage for the `VarInfo` identified by
|
|
/// the `VarInfo::ID`.
|
|
///
|
|
/// NOTE: The returned reference can become dangling if the `VarEnv`
|
|
/// object is mutated during the lifetime of the pointer. Therefore,
|
|
/// client code should not store the reference nor otherwise allow it
|
|
/// to live too long.
|
|
VarInfo const &access(VarInfo::ID id) const {
|
|
// `SmallVector::operator[]` already asserts the index is in-bounds.
|
|
return vars[llvm::to_underlying(id)];
|
|
}
|
|
VarInfo const *access(std::optional<VarInfo::ID> oid) const {
|
|
return oid ? &access(*oid) : nullptr;
|
|
}
|
|
|
|
private:
|
|
VarInfo &access(VarInfo::ID id) {
|
|
return const_cast<VarInfo &>(std::as_const(*this).access(id));
|
|
}
|
|
VarInfo *access(std::optional<VarInfo::ID> oid) {
|
|
return const_cast<VarInfo *>(std::as_const(*this).access(oid));
|
|
}
|
|
|
|
public:
|
|
/// Looks up the variable with the given name.
|
|
std::optional<VarInfo::ID> lookup(StringRef name) const;
|
|
|
|
/// Creates a new currently-unbound variable. When a variable
|
|
/// of that name already exists: if `verifyUsage` is true, then will assert
|
|
/// that the variable has the same kind and a consistent location; otherwise,
|
|
/// when `verifyUsage` is false, this is a noop. Returns the identifier
|
|
/// for the variable with the given name, and a bool indicating whether
|
|
/// a new variable was created.
|
|
std::optional<std::pair<VarInfo::ID, bool>>
|
|
create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
|
|
|
|
/// Looks up or creates a variable according to the given
|
|
/// `Policy`. Returns nullopt in one of two circumstances:
|
|
/// (1) the policy says we `Must` create, yet the variable already exists;
|
|
/// (2) the policy says we `MustNot` create, yet no such variable exists.
|
|
/// Otherwise, if the variable already exists then it is validated against
|
|
/// the given kind and location to ensure consistency.
|
|
std::optional<std::pair<VarInfo::ID, bool>>
|
|
lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
|
|
VarKind vk);
|
|
|
|
/// Binds the given variable to the next free `Var::Num` for its `VarKind`.
|
|
Var bindVar(VarInfo::ID id);
|
|
|
|
/// Creates a new variable of the given kind and immediately binds it.
|
|
/// This should only be used whenever the variable is known to be unused
|
|
/// and therefore does not have a name.
|
|
Var bindUnusedVar(VarKind vk);
|
|
|
|
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
|
|
|
|
/// Returns the current ranks of bound variables. This method should
|
|
/// only be used after the environment is "finished", since binding new
|
|
/// variables will (semantically) invalidate any previously returned `Ranks`.
|
|
Ranks getRanks() const { return Ranks(nextNum); }
|
|
|
|
/// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
|
|
/// failure if the variable is not bound.
|
|
Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
} // namespace ir_detail
|
|
} // namespace sparse_tensor
|
|
} // namespace mlir
|
|
|
|
#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
|