Files
clang-p2996/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
2023-07-07 14:35:56 -07:00

312 lines
11 KiB
C++

//===- Var.cpp ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Var.h"
#include "DimLvlMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
//===----------------------------------------------------------------------===//
// `Var` implementation.
//===----------------------------------------------------------------------===//
/// Prints this variable to the stream underlying the given `AsmPrinter`.
void Var::print(AsmPrinter &printer) const {
  llvm::raw_ostream &os = printer.getStream();
  print(os);
}
/// Writes this variable to `os` as the character for its kind (as given
/// by `toChar`) immediately followed by its number.
void Var::print(llvm::raw_ostream &os) const {
  const auto kindChar = toChar(getKind());
  const auto number = getNum();
  os << kindChar << number;
}
void Var::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
//===----------------------------------------------------------------------===//
// `Ranks` implementation.
//===----------------------------------------------------------------------===//
/// Returns true iff every symbol and every non-symbol variable occurring
/// in `expr` is within the ranks recorded by this object. Expressions
/// that have no underlying `AffineExpr` are accepted unconditionally.
bool Ranks::isValid(DimLvlExpr expr) const {
  // FIXME(wrengr): we have cases without affine expr at an early point
  const auto affine = expr.getAffineExpr();
  if (!affine)
    return true;

  // Each `DimLvlExpr` only allows one kind of non-symbol variable.
  // Both maxima start at -1, i.e., "no variable of that sort was seen".
  int64_t maxSym = -1, maxVar = -1;
  // TODO(wrengr): If we run into ASan issues, that may be due to the
  // "`{{...}}`" syntax; so we may want to try using local-variables instead.
  mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{affine}}, maxVar, maxSym);

  // TODO(wrengr): We may want to add a call to `LLVM_DEBUG` like
  // `willBeValidAffineMap` does.  And/or should return `InFlightDiagnostic`
  // instead of bool.
  if (maxSym >= getSymRank())
    return false;
  return maxVar < getRank(expr.getAllowedVarKind());
}
//===----------------------------------------------------------------------===//
// `VarSet` implementation.
//===----------------------------------------------------------------------===//
/// The `VarKind` values that the `VarSet` methods below iterate over.
static constexpr VarKind everyVarKind[] = {
    VarKind::Dimension,
    VarKind::Symbol,
    VarKind::Level,
};
/// Constructs an empty set able to hold every variable permitted by
/// `ranks`: for each `VarKind`, the underlying bitvector gets one
/// (initially cleared) bit per variable of that kind.
VarSet::VarSet(Ranks const &ranks) {
  // NOTE: We must not use `reserve` here, since that only pre-allocates
  // capacity and doesn't change the `size` of the bitvectors.  With a
  // size of zero, `VarSet::add(Var)` would index out of bounds and
  // `VarSet::contains` would always return false.
  for (const auto vk : everyVarKind)
    impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
}
/// Returns true iff `var` is a member of this set. Variables whose
/// number is beyond the size of the bitvector for their kind are
/// reported as absent.
bool VarSet::contains(Var var) const {
  // NOTE: We make sure to return false on OOB, for consistency with
  // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
  // However beware that, as always with silencing OOB, this can hide
  // bugs in client code.
  const llvm::SmallBitVector &members = impl[var.getKind()];
  const auto idx = var.getNum();
  // FIXME(wrengr): If we `assert(idx < members.size())` then
  // "roundtrip_encoding.mlir" will fail.  So we need to figure out
  // where exactly the OOB `var` is coming from, to determine whether
  // that's a logic bug or not.
  if (idx >= members.size())
    return false;
  return members[idx];
}
/// Returns true iff this set and `other` have at least one variable in
/// common, for any of the variable kinds.
bool VarSet::occursIn(VarSet const &other) const {
  bool sharesVar = false;
  for (const auto vk : everyVarKind) {
    if (impl[vk].anyCommon(other.impl[vk])) {
      sharesVar = true;
      break;
    }
  }
  return sharesVar;
}
/// Returns true iff any variable of this set occurs somewhere within
/// `expr`. A null expression contains no variables.
bool VarSet::occursIn(DimLvlExpr expr) const {
  if (!expr)
    return false;
  switch (expr.getAffineKind()) {
  // Leaves: constants mention no variables; variables are looked up.
  case AffineExprKind::Constant:
    return false;
  case AffineExprKind::SymbolId:
    return contains(expr.castSymVar());
  case AffineExprKind::DimId:
    return contains(expr.castDimLvlVar());
  // Binary operators: recurse into the left operand, then the right one.
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Mod:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    const auto [left, kind, right] = expr.unpackBinop();
    (void)kind; // The operator itself is irrelevant here.
    if (occursIn(left))
      return true;
    return occursIn(right);
  }
  }
  llvm_unreachable("unknown AffineExprKind");
}
/// Inserts `var` into this set.
void VarSet::add(Var var) {
  llvm::SmallBitVector &members = impl[var.getKind()];
  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
  members[var.getNum()] = true;
}
/// Inserts every variable of `other` into this set (i.e., set union).
/// Asserts that, for each kind, `other` isn't sized larger than this set.
void VarSet::add(VarSet const &other) {
  // NOTE: `SmallBitVector::operator|=` will implicitly resize
  // the bitvector (unlike `BitVector::operator|=`), so we add an
  // assertion against OOB for consistency with the implementation
  // of `VarSet::add(Var)`.
  for (const auto vk : everyVarKind) {
    assert(impl[vk].size() >= other.impl[vk].size());
    // Adding is a union; `&=` would instead intersect the two sets and
    // thereby remove members from this set.
    impl[vk] |= other.impl[vk];
  }
}
/// Inserts every variable occurring within `expr` into this set.
/// A null expression contributes nothing.
void VarSet::add(DimLvlExpr expr) {
  if (!expr)
    return;
  switch (expr.getAffineKind()) {
  // Leaves: constants add nothing; variables are added directly.
  case AffineExprKind::Constant:
    return;
  case AffineExprKind::SymbolId: {
    add(expr.castSymVar());
    return;
  }
  case AffineExprKind::DimId: {
    add(expr.castDimLvlVar());
    return;
  }
  // Binary operators: add the variables of the left operand, then those
  // of the right operand.
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Mod:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    const auto [left, kind, right] = expr.unpackBinop();
    (void)kind; // The operator itself is irrelevant here.
    add(left);
    add(right);
    return;
  }
  }
  llvm_unreachable("unknown AffineExprKind");
}
//===----------------------------------------------------------------------===//
// `VarInfo` implementation.
//===----------------------------------------------------------------------===//
/// Records `n` as the number of this variable. Asserts that no number
/// has been set before, and that `n` passes `Var::isWF_Num`.
void VarInfo::setNum(Var::Num n) {
  // The number may only be assigned once.
  assert(!hasNum() && "Var::Num is already set");
  // Reject numbers that aren't wellformed.
  assert(Var::isWF_Num(n) && "Var::Num is too large");
  this->num = n;
}
//===----------------------------------------------------------------------===//
// `VarEnv` implementation.
//===----------------------------------------------------------------------===//
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.  Returns whichever of `sm1` and `sm2` comes first by
/// (line, column), preferring `sm1` on ties; or a default-constructed
/// `SMLoc` when the two locations have different filenames.
// TODO(wrengr): If we switch to the `LocatedVar` design, then there's
// no need for anything like `minSMLoc` since `assertUsageConsistency`
// won't need to do anything about locations.
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
  const auto first = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
  assert(first && "Could not get `FileLineColLoc` for first `SMLoc`");
  const auto second =
      parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
  assert(second && "Could not get `FileLineColLoc` for second `SMLoc`");

  // Locations from different files are incomparable.
  if (first.getFilename() != second.getFilename())
    return SMLoc();

  // Lexicographic comparison: lines first, columns as the tie-breaker.
  const auto line1 = first.getLine();
  const auto line2 = second.getLine();
  if (line1 != line2)
    return line1 < line2 ? sm1 : sm2;
  return first.getColumn() <= second.getColumn() ? sm1 : sm2;
}
/// Asserts that the `VarInfo` stored in `env` under `id` has the given
/// `name` and reports `id` as its own identifier.  Does nothing when
/// `NDEBUG` is defined.
LLVM_ATTRIBUTE_UNUSED static void
assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
#ifndef NDEBUG
  const auto &info = env.access(id);
  assert(info.getName() == name && "found inconsistent name");
  assert(info.getID() == id && "found inconsistent VarInfo::ID");
#endif // NDEBUG
}
/// Asserts that the `VarInfo` stored in `env` under `id` has kind `vk`.
/// The `loc` argument is currently unused.  Does nothing when `NDEBUG`
/// is defined.
// NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
// (or find some other way to convert SMLoc to FileLineColLoc), then this
// would no longer be `const VarEnv` (and couldn't be a free-function either).
LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
                                                         VarInfo::ID id,
                                                         llvm::SMLoc loc,
                                                         VarKind vk) {
#ifndef NDEBUG
  const auto &info = env.access(id);
  assert(info.getKind() == vk &&
         "a variable of that name already exists with a different VarKind");
  // Since the same variable can occur at several locations,
  // it would not be appropriate to do `assert(info.getLoc() == loc)`.
  /* TODO(wrengr):
  const auto minLoc = minSMLoc(_, info.getLoc(), loc);
  assert(minLoc && "Location mismatch/incompatibility");
  info.loc = minLoc;
  // */
#endif // NDEBUG
}
/// Returns the identifier registered under `name`, or `std::nullopt`
/// when no variable of that name is known.
std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
  // NOTE: `StringMap::lookup` will return a default-constructed value if
  // the key isn't found; which for enums means zero, and therefore makes
  // it impossible to distinguish between actual zero-VarInfo::ID vs not-found.
  // Whereas `StringMap::at` asserts that the key is found, which we don't
  // want either.
  const auto entry = ids.find(name);
  if (entry != ids.end()) {
    const auto id = entry->second;
#ifndef NDEBUG
    assertInternalConsistency(*this, id, name);
#endif // NDEBUG
    return id;
  }
  return std::nullopt;
}
/// Registers a variable called `name` unless one already exists.
/// Returns the identifier associated with `name`, paired with a flag
/// saying whether a new variable was created by this call.  When the
/// name was already registered, debug builds assert the environment's
/// internal consistency and, if `verifyUsage` is set, that the existing
/// variable has kind `vk`.
std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
                                            VarKind vk, bool verifyUsage) {
  // The candidate identifier is only stored when `name` is a new key.
  const auto insertion = ids.try_emplace(name, nextID());
  const auto id = insertion.first->second;
  const bool isNew = insertion.second;
  if (!isNew) {
#ifndef NDEBUG
    assertInternalConsistency(*this, id, name);
    if (verifyUsage)
      assertUsageConsistency(*this, id, loc, vk);
#endif // NDEBUG
  } else {
    vars.emplace_back(id, name, loc, vk);
  }
  return {id, isNew};
}
/// Looks up and/or creates the variable called `name`, as directed by
/// `policy`:
///   * `MustNot`: only looks up; yields `std::nullopt` if not found.
///   * `May`: creates the variable unless it already exists.
///   * `Must`: creates the variable; yields `std::nullopt` if it
///     already existed.
/// On success, returns the identifier paired with a flag saying whether
/// the variable was created by this call.
std::optional<std::pair<VarInfo::ID, bool>>
VarEnv::lookupOrCreate(CreationPolicy policy, StringRef name, llvm::SMLoc loc,
                       VarKind vk) {
  switch (policy) {
  case CreationPolicy::MustNot: {
    if (const auto existing = lookup(name)) {
#ifndef NDEBUG
      assertUsageConsistency(*this, *existing, loc, vk);
#endif // NDEBUG
      return std::make_pair(*existing, false);
    }
    return std::nullopt; // Doesn't exist, but must not create.
  }
  case CreationPolicy::May:
    return create(name, loc, vk, /*verifyUsage=*/true);
  case CreationPolicy::Must: {
    const auto created = create(name, loc, vk, /*verifyUsage=*/false);
    if (created.second)
      return created;
    return std::nullopt; // Already exists, but must create.
  }
  }
  llvm_unreachable("unknown CreationPolicy");
}
/// Returns a variable of kind `vk` carrying the next number for that
/// kind, and advances that kind's counter.
Var VarEnv::bindUnusedVar(VarKind vk) {
  const auto fresh = nextNum[vk]++;
  return Var(vk, fresh);
}
/// Binds the variable identified by `id` to the next unused number of
/// its kind, records that number in its `VarInfo`, and returns the
/// resulting `Var`.
Var VarEnv::bindVar(VarInfo::ID id) {
  auto &entry = access(id);
  const Var bound = bindUnusedVar(entry.getKind());
  // NOTE: `setNum` already checks wellformedness of the `Var::Num`.
  entry.setNum(bound.getNum());
  return bound;
}
/// Emits an error via `parser` for the first variable (in creation
/// order) that has no number yet; returns a default-constructed
/// `InFlightDiagnostic` when every variable has one.
// TODO(wrengr): Alternatively there's `mlir::emitError(Location, Twine const&)`
// which is what `Operation::emitError` uses; though I'm not sure if
// that's appropriate to use here...  But if it is, then that means
// we can have `VarInfo` store `Location` rather than `SMLoc`, which
// means we can use `FusedLoc` to handle the combination issue in
// `VarEnv::lookupOrCreate`.
//
// TODO(wrengr): is there any way to combine multiple IFDs, so that
// we can report all unbound variables instead of just the first one
// encountered?
//
InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
  for (const auto &info : vars) {
    if (info.hasNum())
      continue;
    return parser.emitError(info.getLoc(),
                            "Unbound variable: " + info.getName());
  }
  return {};
}
/// Appends to `dimsAndSymbols` one (name, expression) pair for every
/// variable of kind `vk`, in creation order.  Each expression is the
/// affine dimension expression for the variable's number.  Asserts that
/// every such variable has a number.
void VarEnv::addVars(
    SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
    VarKind vk, MLIRContext *context) const {
  for (const auto &info : vars) {
    if (info.getKind() != vk)
      continue;
    assert(info.hasNum());
    // NOTE(review): this builds a *dimension* expression regardless of
    // `vk`; presumably callers never pass `VarKind::Symbol` -- confirm.
    dimsAndSymbols.emplace_back(info.getName(),
                                getAffineDimExpr(*info.getNum(), context));
  }
}
//===----------------------------------------------------------------------===//