Files
clang-p2996/flang/lib/Lower/IterationSpace.cpp
Slava Zakharin be5747e516 [flang] Fixed global name creation for literal constants.
The global names were created using a hash based on the address
returned by std::vector::data. Since the memory may be reused
by different std::vectors, this may cause non-equivalent
constant expressions to map to the same name. This is what is happening
in the modified flang/test/Lower/constant-literal-mangling.f90 test.

I changed the name creation to use a map between the constant expressions
and corresponding unique names. The uniquing is done using a name counter
in FirConverter. The effect of this change is that the equivalent
constant expressions are now mapped to the same global, and the naming
is "stable" (i.e. it does not change from compilation to compilation).

Although the issue is not HLFIR-specific, it was affecting several tests
when using HLFIR lowering.

Differential Revision: https://reviews.llvm.org/D150380
2023-05-12 13:40:22 -07:00

403 lines
12 KiB
C++

//===-- IterationSpace.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/IterationSpace.h"
#include "flang/Evaluate/expression.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Support/Utils.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "flang-lower-iteration-space"
/// Hash an array base by delegating to the structural expression hasher for
/// whichever front-end node pointer the variant currently holds.
unsigned Fortran::lower::getHashValue(
    const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
  auto hashNode = [](const auto *node) {
    return HashEvaluateExpr::getHashValue(*node);
  };
  return std::visit(hashNode, x);
}
/// Two array bases are equal when they hold the same kind of front-end node
/// and those nodes compare equal. Symbol pointers are compared by identity;
/// expression subtrees are compared structurally via IsEqualEvaluateExpr.
bool Fortran::lower::isEqual(
    const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
    const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
  // Fortran::semantics::Symbol * are the exception here. These pointers have
  // identity: equal pointers denote the same logical symbol and distinct
  // pointers denote distinct symbols.
  auto sameSymbol = [](Fortran::lower::FrontEndSymbol lhs,
                       Fortran::lower::FrontEndSymbol rhs) {
    return lhs == rhs;
  };
  auto sameSubtree = [](const auto *lhs, const auto *rhs) {
    using LhsT = decltype(lhs);
    using RhsT = decltype(rhs);
    if constexpr (!std::is_same_v<LhsT, RhsT>) {
      // Different subtree types are never equal.
      return false;
    } else {
      LLVM_DEBUG(llvm::dbgs()
                 << "is equal: " << lhs << ' ' << rhs << ' '
                 << IsEqualEvaluateExpr::isEqual(*lhs, *rhs) << '\n');
      return IsEqualEvaluateExpr::isEqual(*lhs, *rhs);
    }
  };
  return std::visit(Fortran::common::visitors{sameSymbol, sameSubtree}, x, y);
}
namespace {
/// Recovers the base array(s) of an expression that may reference the control
/// variables of an explicit iteration space. Most of the class is boilerplate
/// Fortran::evaluate::Expr traversal; the interesting cases are Symbol,
/// Component and ArrayRef.
///
/// Bases are appended to `bases` in the order they are discovered:
///   - a Symbol is a base when it has non-zero rank;
///   - a Component is a base when no base was found in its parent, the parent
///     has rank 0, and the component itself has non-zero rank;
///   - an ArrayRef whose base is a bare symbol is a base when it has non-zero
///     rank or its subscripts mention a control variable;
///   - any other ArrayRef is a base when no base was found in its parent and
///     either (parent rank 0 and own rank > 0) or its subscripts mention a
///     control variable.
/// Subscripts are traversed before the reference they index, so bases found
/// inside subscripts precede the indexed reference in `bases`.
///
/// The `RT` (bool) results only carry information along a data-reference
/// chain: they tell an enclosing Component/ArrayRef whether a base was already
/// recorded for its parent. Traversal through general expressions, operations
/// and procedure references always yields false.
class ArrayBaseFinder {
public:
  using RT = bool;

  /// \p syms are the control variables to test subscripts against. They are
  /// copied, so the caller's storage need not outlive the finder.
  explicit ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
      : controlVars(syms.begin(), syms.end()) {}

  /// Traverse \p x, recording every array base found in it.
  template <typename T>
  void operator()(const T &x) {
    (void)find(x);
  }

  /// Get the list of bases, in discovery order.
  llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
  getBases() const {
    LLVM_DEBUG(llvm::dbgs()
               << "number of array bases found: " << bases.size() << '\n');
    return bases;
  }

private:
  // First, the cases that are of interest.
  RT find(const Fortran::semantics::Symbol &symbol) {
    if (symbol.Rank() > 0) {
      bases.push_back(&symbol);
      return true;
    }
    return {};
  }
  RT find(const Fortran::evaluate::Component &x) {
    auto found = find(x.base());
    if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
      bases.push_back(&x);
      return true;
    }
    return found;
  }
  RT find(const Fortran::evaluate::ArrayRef &x) {
    // Subscripts may contain array bases of their own; record those first.
    for (const auto &sub : x.subscript())
      (void)find(sub);
    if (x.base().IsSymbol()) {
      // Bare array name: the reference itself is the base when it is
      // array-valued or indexed by an iteration-space control variable.
      if (x.Rank() > 0 || intersection(x.subscript())) {
        bases.push_back(&x);
        return true;
      }
      return {};
    }
    auto found = find(x.base());
    if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
                   intersection(x.subscript()))) {
      bases.push_back(&x);
      return true;
    }
    return found;
  }
  RT find(const Fortran::evaluate::Triplet &x) {
    if (const auto *lower = x.GetLower())
      (void)find(*lower);
    if (const auto *upper = x.GetUpper())
      (void)find(*upper);
    return find(x.GetStride());
  }
  RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
    return find(x.value());
  }
  RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
  RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
  RT find(const Fortran::evaluate::CoarrayRef &) {
    assert(false && "coarray reference");
    return {};
  }

  /// True when the control variables intersect \p subscripts, as decided by
  /// Fortran::lower::symbolsIntersectSubscripts.
  template <typename A>
  bool intersection(const A &subscripts) {
    return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
  }

  // The rest is traversal boilerplate and can be ignored.
  RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
  // Deliberately not a template: a template parameter that does not appear in
  // the signature can never be deduced, which would leave this overload
  // unreachable.
  RT find(const Fortran::semantics::SymbolRef x) { return find(*x); }
  RT find(const Fortran::evaluate::NamedEntity &x) {
    if (x.IsSymbol())
      return find(x.GetFirstSymbol());
    return find(x.GetComponent());
  }
  template <typename A, bool C>
  RT find(const Fortran::common::Indirection<A, C> &x) {
    return find(x.value());
  }
  template <typename A>
  RT find(const std::unique_ptr<A> &x) {
    return find(x.get());
  }
  template <typename A>
  RT find(const std::shared_ptr<A> &x) {
    return find(x.get());
  }
  template <typename A>
  RT find(const A *x) {
    if (x)
      return find(*x);
    return {};
  }
  template <typename A>
  RT find(const std::optional<A> &x) {
    if (x)
      return find(*x);
    return {};
  }
  template <typename... A>
  RT find(const std::variant<A...> &u) {
    return std::visit([&](const auto &v) { return find(v); }, u);
  }
  template <typename A>
  RT find(const std::vector<A> &x) {
    for (const auto &v : x)
      (void)find(v);
    return {};
  }
  RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
  RT find(const Fortran::evaluate::NullPointer &) { return {}; }
  template <typename T>
  RT find(const Fortran::evaluate::Constant<T> &) {
    return {};
  }
  RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
  RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
  RT find(const Fortran::evaluate::BaseObject &x) {
    (void)find(x.u);
    return {};
  }
  RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
  // NOTE(review): the complex part's parent DataRef is not traversed; confirm
  // that is intended.
  RT find(const Fortran::evaluate::ComplexPart &) { return {}; }
  template <typename T>
  RT find(const Fortran::evaluate::Designator<T> &x) {
    return find(x.u);
  }
  template <typename T>
  RT find(const Fortran::evaluate::Variable<T> &x) {
    return find(x.u);
  }
  RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
  RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
  RT find(const Fortran::evaluate::ProcedureDesignator &) { return {}; }
  RT find(const Fortran::evaluate::ProcedureRef &x) {
    (void)find(x.proc());
    // Only elemental calls propagate array bases from their arguments.
    if (x.IsElemental())
      (void)find(x.arguments());
    return {};
  }
  RT find(const Fortran::evaluate::ActualArgument &x) {
    if (const auto *sym = x.GetAssumedTypeDummy())
      (void)find(*sym);
    else
      (void)find(x.UnwrapExpr());
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::FunctionRef<T> &x) {
    (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::ImpliedDo<T> &) {
    return {};
  }
  RT find(const Fortran::semantics::ParamValue &) { return {}; }
  RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
  RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
  template <typename D, typename R, typename O>
  RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
    (void)find(op.left());
    return {};
  }
  template <typename D, typename R, typename LO, typename RO>
  RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
    (void)find(op.left());
    (void)find(op.right());
    return {};
  }
  RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
    (void)find(x.u);
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::Expr<T> &x) {
    (void)find(x.u);
    return {};
  }

  /// Array bases discovered so far, in discovery order.
  llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
  /// Control variables that subscripts are tested against.
  llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
};
} // namespace
/// Leave one explicit iteration space context: close it (decrement the open
/// context count), pop the innermost ccLoopNest entry, and run
/// conditionalCleanup(), which resets all state once no context remains open.
void Fortran::lower::ExplicitIterSpace::leave() {
  forallContextOpen -= 1;
  ccLoopNest.pop_back();
  conditionalCleanup();
}
/// Append \p sym to the innermost symbol level. A level must already exist
/// (see pushLevel()); this is asserted.
void Fortran::lower::ExplicitIterSpace::addSymbol(
    Fortran::lower::FrontEndSymbol sym) {
  assert(!symbolStack.empty());
  auto &innermostLevel = symbolStack.back();
  innermostLevel.push_back(sym);
}
/// Record the array bases referenced by expression \p x for the current
/// assignment, using all currently registered symbols as control variables.
///
/// For a right-hand side (\p lhs false), every base found is appended to the
/// current RHS list. For a left-hand side, the last base found is recorded as
/// the LHS base and any earlier ones are appended to the current RHS list; if
/// no base is found, std::nullopt is recorded as the LHS base.
void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
                                                 bool lhs) {
  ArrayBaseFinder finder(collectAllSymbols());
  finder(*x);
  llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> found =
      finder.getBases();
  // Make sure an RHS list exists for the current assignment.
  if (rhsBases.empty())
    endAssign();
  if (!lhs) {
    rhsBases.back().append(found.begin(), found.end());
    return;
  }
  if (found.empty()) {
    lhsBases.push_back(std::nullopt);
    return;
  }
  assert(found.size() >= 1 && "must detect an array reference on lhs");
  // All bases but the last one are only read, so they belong to the RHS.
  if (found.size() > 1)
    rhsBases.back().append(found.begin(), found.end() - 1);
  lhsBases.push_back(found.back());
}
/// Start a new, empty RHS base list for the next assignment.
void Fortran::lower::ExplicitIterSpace::endAssign() {
  rhsBases.emplace_back();
}
/// Open a new, empty innermost level on the symbol stack.
void Fortran::lower::ExplicitIterSpace::pushLevel() {
  symbolStack.emplace_back();
}
/// Discard the innermost level of the symbol stack.
void Fortran::lower::ExplicitIterSpace::popLevel() {
  symbolStack.pop_back();
}
/// If no FORALL context remains open, finalize and reset the outermost context
/// and clear every piece of cached iteration-space state. Otherwise this is a
/// no-op.
void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
  // Still inside some FORALL context: nothing to clean up yet.
  if (forallContextOpen != 0)
    return;

  // Exiting the outermost FORALL context; cleanup any residual mask buffers
  // before the cached state is dropped.
  outermostContext().finalizeAndReset();

  // Clear and reset all the cached information.
  symbolStack.clear();
  lhsBases.clear();
  rhsBases.clear();
  loadBindings.clear();
  ccLoopNest.clear();
  innerArgs.clear();
  outerLoop = std::nullopt;
  clearLoops();
  counter = 0;
}
/// Look up the argument position of \p load for the assignment selected by
/// `counter`.
///
/// If that assignment has no LHS base, returns std::nullopt. Otherwise returns
/// 0 when \p load is the load bound to the LHS base in `loadBindings`. If it is
/// not, the result is empty and debug builds assert.
std::optional<size_t>
Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
  const auto &lhsBase = lhsBases[counter];
  if (!lhsBase)
    return std::nullopt;

  std::optional<size_t> position;
  auto binding = loadBindings.find(*lhsBase);
  if (binding != loadBindings.end() && binding->second == load)
    position = static_cast<size_t>(0u);
  assert(position.has_value() && "load does not correspond to lhs");
  return position;
}
/// Flatten the symbols of every level of the symbol stack into one list, in
/// stack order (the level pushed first comes first).
llvm::SmallVector<Fortran::lower::FrontEndSymbol>
Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
  llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
  // Bind each level by const reference so its vector is not copied per
  // iteration.
  for (const auto &level : symbolStack)
    result.append(level.begin(), level.end());
  return result;
}
/// Print the masks of \p e: one `{ ... }` line per entry of getMasks(), with
/// each mask expression rendered as Fortran, parenthesized, and followed by
/// ", ".
llvm::raw_ostream &
Fortran::lower::operator<<(llvm::raw_ostream &s,
                           const Fortran::lower::ImplicitIterSpace &e) {
  for (const auto &maskGroup : e.getMasks()) {
    s << "{ ";
    for (const auto &mask : maskGroup) {
      s << '(';
      mask->AsFortran(s);
      s << "), ";
    }
    s << "}\n";
  }
  return s;
}
/// Debug printer for \p e.
///
/// Prints "LHS bases:" followed by each present LHS base, then "RHS bases:"
/// followed by each assignment's RHS bases, with every assignment's group
/// terminated by an extra newline. Each base is printed on its own line as the
/// symbol it refers to.
llvm::raw_ostream &
Fortran::lower::operator<<(llvm::raw_ostream &s,
                           const Fortran::lower::ExplicitIterSpace &e) {
  auto printBase = [&](const auto &base) {
    std::visit(
        Fortran::common::visitors{
            [&](const Fortran::semantics::Symbol *sym) {
              s << " " << *sym << '\n';
            },
            [&](const Fortran::evaluate::ArrayRef *ref) {
              const auto &parent = ref->base();
              s << " ";
              if (parent.IsSymbol())
                s << parent.GetFirstSymbol();
              else
                s << parent.GetComponent().GetLastSymbol();
              s << '\n';
            },
            [&](const Fortran::evaluate::Component *comp) {
              s << " " << comp->GetLastSymbol() << '\n';
            }},
        base);
  };

  s << "LHS bases:\n";
  for (const auto &lhsBase : e.lhsBases) {
    if (lhsBase)
      printBase(*lhsBase);
  }

  s << "RHS bases:\n";
  for (const auto &group : e.rhsBases) {
    for (const auto &rhsBase : group)
      printBase(rhsBase);
    s << '\n';
  }
  return s;
}
void Fortran::lower::ImplicitIterSpace::dump() const {
llvm::errs() << *this << '\n';
}
void Fortran::lower::ExplicitIterSpace::dump() const {
llvm::errs() << *this << '\n';
}