//===-- IterationSpace.cpp ------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Lower/IterationSpace.h" #include "flang/Evaluate/expression.h" #include "flang/Lower/AbstractConverter.h" #include "flang/Lower/Support/Utils.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "flang-lower-iteration-space" namespace { // Fortran::evaluate::Expr are functional values organized like an AST. A // Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end // tools can often cause copies and extra wrapper classes to be added to any // Fortran::evalute::Expr. These values should not be assumed or relied upon to // have an *object* identity. They are deeply recursive, irregular structures // built from a large number of classes which do not use inheritance and // necessitate a large volume of boilerplate code as a result. // // Contrastingly, LLVM data structures make ubiquitous assumptions about an // object's identity via pointers to the object. An object's location in memory // is thus very often an identifying relation. // This class defines a hash computation of a Fortran::evaluate::Expr tree value // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not // have the same address. class HashEvaluateExpr { public: // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an // identity property. static unsigned getHashValue(const Fortran::semantics::Symbol &x) { return static_cast(reinterpret_cast(&x)); } template static unsigned getHashValue(const Fortran::common::Indirection &x) { return getHashValue(x.value()); } template static unsigned getHashValue(const std::optional &x) { if (x.has_value()) return getHashValue(x.value()); return 0u; } static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - getHashValue(x.stride()) * 11u; } static unsigned getHashValue(const Fortran::evaluate::Component &x) { return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); } static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { unsigned subs = 1u; for (const Fortran::evaluate::Subscript &v : x.subscript()) subs -= getHashValue(v); return getHashValue(x.base()) * 89u - subs; } static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { unsigned subs = 1u; for (const Fortran::evaluate::Subscript &v : x.subscript()) subs -= getHashValue(v); unsigned cosubs = 3u; for (const Fortran::evaluate::Expr &v : x.cosubscript()) cosubs -= getHashValue(v); unsigned syms = 7u; for (const Fortran::evaluate::SymbolRef &v : x.base()) syms += getHashValue(v); return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u + getHashValue(x.team()); } static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { if (x.IsSymbol()) return getHashValue(x.GetFirstSymbol()) * 11u; return getHashValue(x.GetComponent()) * 13u; } static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { return getHashValue(x.complex()) - static_cast(x.part()); } template static unsigned getHashValue( const Fortran::evaluate::Convert, TC2> &x) { return getHashValue(x.left()) - (static_cast(TC1) + 2u) - (static_cast(KIND) + 5u); } template static unsigned getHashValue(const Fortran::evaluate::ComplexComponent &x) { return getHashValue(x.left()) - (static_cast(x.isImaginaryPart) + 1u) * 3u; } template static unsigned getHashValue(const Fortran::evaluate::Parentheses &x) { return getHashValue(x.left()) * 17u; } template static unsigned getHashValue( const Fortran::evaluate::Negate> &x) { return getHashValue(x.left()) - (static_cast(TC) + 5u) - (static_cast(KIND) + 7u); } template static unsigned getHashValue( const Fortran::evaluate::Add> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Subtract> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Multiply> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Divide> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Power> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Extremum> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + static_cast(TC) + static_cast(KIND) + static_cast(x.ordering) * 7u; } template static unsigned getHashValue( const Fortran::evaluate::RealToIntPower> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::ComplexConstructor &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 47u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::Concat &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::SetLength &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + static_cast(KIND); } static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { return getHashValue(sym.get()); } static unsigned getHashValue(const Fortran::evaluate::Substring &x) { return 61u * std::visit([&](const auto &p) { return getHashValue(p); }, x.parent()) - getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); } static unsigned getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { return llvm::hash_value(x->name()); } static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { return llvm::hash_value(x.name); } template static unsigned getHashValue(const Fortran::evaluate::Constant &x) { // FIXME: Should hash the content. return 103u; } static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) return getHashValue(*sym); return getHashValue(*x.UnwrapExpr()); } static unsigned getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { unsigned args = 13u; for (const std::optional &v : x.arguments()) args -= getHashValue(v); return getHashValue(x.proc()) * 101u - args; } template static unsigned getHashValue(const Fortran::evaluate::ArrayConstructor &x) { // FIXME: hash the contents. return 127u; } static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { return llvm::hash_value(toStringRef(x.name).str()) * 131u; } static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; } static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { return getHashValue(x.base()) * 139u - static_cast(x.field()) * 13u + static_cast(x.dimension()); } static unsigned getHashValue(const Fortran::evaluate::StructureConstructor &x) { // FIXME: hash the contents. return 149u; } template static unsigned getHashValue(const Fortran::evaluate::Not &x) { return getHashValue(x.left()) * 61u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::LogicalOperation &x) { unsigned result = getHashValue(x.left()) + getHashValue(x.right()); return result * 67u + static_cast(x.logicalOperator) * 5u; } template static unsigned getHashValue( const Fortran::evaluate::Relational> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + static_cast(TC) + static_cast(KIND) + static_cast(x.opr) * 11u; } template static unsigned getHashValue(const Fortran::evaluate::Expr &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue( const Fortran::evaluate::Relational &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } template static unsigned getHashValue(const Fortran::evaluate::Designator &x) { return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); } template static unsigned getHashValue(const Fortran::evaluate::value::Integer &x) { return static_cast(x.ToSInt()); } static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { return ~179u; } }; } // namespace unsigned Fortran::lower::getHashValue( const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { return std::visit( [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); } unsigned Fortran::lower::getHashValue(Fortran::lower::FrontEndExpr x) { return HashEvaluateExpr::getHashValue(*x); } namespace { // Define the is equals test for using Fortran::evaluate::Expr values with // llvm::DenseMap. class IsEqualEvaluateExpr { public: // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an // identity property. static bool isEqual(const Fortran::semantics::Symbol &x, const Fortran::semantics::Symbol &y) { return isEqual(&x, &y); } static bool isEqual(const Fortran::semantics::Symbol *x, const Fortran::semantics::Symbol *y) { return x == y; } template static bool isEqual(const Fortran::common::Indirection &x, const Fortran::common::Indirection &y) { return isEqual(x.value(), y.value()); } template static bool isEqual(const std::optional &x, const std::optional &y) { if (x.has_value() && y.has_value()) return isEqual(x.value(), y.value()); return !x.has_value() && !y.has_value(); } template static bool isEqual(const std::vector &x, const std::vector &y) { if (x.size() != y.size()) return false; const std::size_t size = x.size(); for (std::remove_const_t i = 0; i < size; ++i) if (!isEqual(x[i], y[i])) return false; return true; } static bool isEqual(const Fortran::evaluate::Subscript &x, const Fortran::evaluate::Subscript &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::Triplet &x, const Fortran::evaluate::Triplet &y) { return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && isEqual(x.stride(), y.stride()); } static bool isEqual(const Fortran::evaluate::Component &x, const Fortran::evaluate::Component &y) { return isEqual(x.base(), y.base()) && isEqual(x.GetLastSymbol(), y.GetLastSymbol()); } static bool isEqual(const Fortran::evaluate::ArrayRef &x, const Fortran::evaluate::ArrayRef &y) { return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); } static bool isEqual(const Fortran::evaluate::CoarrayRef &x, const Fortran::evaluate::CoarrayRef &y) { return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()) && isEqual(x.cosubscript(), y.cosubscript()) && isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); } static bool isEqual(const Fortran::evaluate::NamedEntity &x, const Fortran::evaluate::NamedEntity &y) { if (x.IsSymbol() && y.IsSymbol()) return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); return !x.IsSymbol() && !y.IsSymbol() && isEqual(x.GetComponent(), y.GetComponent()); } static bool isEqual(const Fortran::evaluate::DataRef &x, const Fortran::evaluate::DataRef &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::ComplexPart &x, const Fortran::evaluate::ComplexPart &y) { return isEqual(x.complex(), y.complex()) && x.part() == y.part(); } template static bool isEqual(const Fortran::evaluate::Convert &x, const Fortran::evaluate::Convert &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::ComplexComponent &x, const Fortran::evaluate::ComplexComponent &y) { return isEqual(x.left(), y.left()) && x.isImaginaryPart == y.isImaginaryPart; } template static bool isEqual(const Fortran::evaluate::Parentheses &x, const Fortran::evaluate::Parentheses &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::Negate &x, const Fortran::evaluate::Negate &y) { return isEqual(x.left(), y.left()); } template static bool isBinaryEqual(const A &x, const A &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); } template static bool isEqual(const Fortran::evaluate::Add &x, const Fortran::evaluate::Add &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Subtract &x, const Fortran::evaluate::Subtract &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Multiply &x, const Fortran::evaluate::Multiply &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Divide &x, const Fortran::evaluate::Divide &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Power &x, const Fortran::evaluate::Power &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Extremum &x, const Fortran::evaluate::Extremum &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::RealToIntPower &x, const Fortran::evaluate::RealToIntPower &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::ComplexConstructor &x, const Fortran::evaluate::ComplexConstructor &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Concat &x, const Fortran::evaluate::Concat &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::SetLength &x, const Fortran::evaluate::SetLength &y) { return isBinaryEqual(x, y); } static bool isEqual(const Fortran::semantics::SymbolRef &x, const Fortran::semantics::SymbolRef &y) { return isEqual(x.get(), y.get()); } static bool isEqual(const Fortran::evaluate::Substring &x, const Fortran::evaluate::Substring &y) { return std::visit( [&](const auto &p, const auto &q) { return isEqual(p, q); }, x.parent(), y.parent()) && isEqual(x.lower(), y.lower()) && isEqual(x.lower(), y.lower()); } static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, const Fortran::evaluate::StaticDataObject::Pointer &y) { return x->name() == y->name(); } static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, const Fortran::evaluate::SpecificIntrinsic &y) { return x.name == y.name; } template static bool isEqual(const Fortran::evaluate::Constant &x, const Fortran::evaluate::Constant &y) { return x == y; } static bool isEqual(const Fortran::evaluate::ActualArgument &x, const Fortran::evaluate::ActualArgument &y) { if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) return isEqual(*xs, *ys); return false; } return !y.GetAssumedTypeDummy() && isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); } static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, const Fortran::evaluate::ProcedureDesignator &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::ProcedureRef &x, const Fortran::evaluate::ProcedureRef &y) { return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); } template static bool isEqual(const Fortran::evaluate::ArrayConstructor &x, const Fortran::evaluate::ArrayConstructor &y) { llvm::report_fatal_error("not implemented"); } static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, const Fortran::evaluate::ImpliedDoIndex &y) { return toStringRef(x.name) == toStringRef(y.name); } static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, const Fortran::evaluate::TypeParamInquiry &y) { return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); } static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, const Fortran::evaluate::DescriptorInquiry &y) { return isEqual(x.base(), y.base()) && x.field() == y.field() && x.dimension() == y.dimension(); } static bool isEqual(const Fortran::evaluate::StructureConstructor &x, const Fortran::evaluate::StructureConstructor &y) { llvm::report_fatal_error("not implemented"); } template static bool isEqual(const Fortran::evaluate::Not &x, const Fortran::evaluate::Not &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::LogicalOperation &x, const Fortran::evaluate::LogicalOperation &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), x.right()); } template static bool isEqual(const Fortran::evaluate::Relational &x, const Fortran::evaluate::Relational &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); } template static bool isEqual(const Fortran::evaluate::Expr &x, const Fortran::evaluate::Expr &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::Relational &x, const Fortran::evaluate::Relational &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } template static bool isEqual(const Fortran::evaluate::Designator &x, const Fortran::evaluate::Designator &y) { return std::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } template static bool isEqual(const Fortran::evaluate::value::Integer &x, const Fortran::evaluate::value::Integer &y) { return x == y; } static bool isEqual(const Fortran::evaluate::NullPointer &x, const Fortran::evaluate::NullPointer &y) { return true; } template , bool> = true> static bool isEqual(const A &, const B &) { return false; } }; } // namespace bool Fortran::lower::isEqual( const Fortran::lower::ExplicitIterSpace::ArrayBases &x, const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { return std::visit( Fortran::common::visitors{ // Fortran::semantics::Symbol * are the exception here. These pointers // have identity; if two Symbol * values are the same (different) then // they are the same (different) logical symbol. [&](Fortran::lower::FrontEndSymbol p, Fortran::lower::FrontEndSymbol q) { return p == q; }, [&](const auto *p, const auto *q) { if constexpr (std::is_same_v) { LLVM_DEBUG(llvm::dbgs() << "is equal: " << p << ' ' << q << ' ' << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n'); return IsEqualEvaluateExpr::isEqual(*p, *q); } else { // Different subtree types are never equal. return false; } }}, x, y); } bool Fortran::lower::isEqual(Fortran::lower::FrontEndExpr x, Fortran::lower::FrontEndExpr y) { auto empty = llvm::DenseMapInfo::getEmptyKey(); auto tombstone = llvm::DenseMapInfo::getTombstoneKey(); if (x == empty || y == empty || x == tombstone || y == tombstone) return x == y; return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); } namespace { /// This class can recover the base array in an expression that contains /// explicit iteration space symbols. Most of the class can be ignored as it is /// boilerplate Fortran::evaluate::Expr traversal. class ArrayBaseFinder { public: using RT = bool; ArrayBaseFinder(llvm::ArrayRef syms) : controlVars(syms.begin(), syms.end()) {} template void operator()(const T &x) { (void)find(x); } /// Get the list of bases. llvm::ArrayRef getBases() const { LLVM_DEBUG(llvm::dbgs() << "number of array bases found: " << bases.size() << '\n'); return bases; } private: // First, the cases that are of interest. RT find(const Fortran::semantics::Symbol &symbol) { if (symbol.Rank() > 0) { bases.push_back(&symbol); return true; } return {}; } RT find(const Fortran::evaluate::Component &x) { auto found = find(x.base()); if (!found && x.base().Rank() == 0 && x.Rank() > 0) { bases.push_back(&x); return true; } return found; } RT find(const Fortran::evaluate::ArrayRef &x) { for (const auto &sub : x.subscript()) (void)find(sub); if (x.base().IsSymbol()) { if (x.Rank() > 0 || intersection(x.subscript())) { bases.push_back(&x); return true; } return {}; } auto found = find(x.base()); if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) || intersection(x.subscript()))) { bases.push_back(&x); return true; } return found; } RT find(const Fortran::evaluate::Triplet &x) { if (const auto *lower = x.GetLower()) (void)find(*lower); if (const auto *upper = x.GetUpper()) (void)find(*upper); return find(x.GetStride()); } RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) { return find(x.value()); } RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); } RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); } RT find(const Fortran::evaluate::CoarrayRef &x) { assert(false && "coarray reference"); return {}; } template bool intersection(const A &subscripts) { return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts); } // The rest is traversal boilerplate and can be ignored. RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); } template RT find(const Fortran::semantics::SymbolRef x) { return find(*x); } RT find(const Fortran::evaluate::NamedEntity &x) { if (x.IsSymbol()) return find(x.GetFirstSymbol()); return find(x.GetComponent()); } template RT find(const Fortran::common::Indirection &x) { return find(x.value()); } template RT find(const std::unique_ptr &x) { return find(x.get()); } template RT find(const std::shared_ptr &x) { return find(x.get()); } template RT find(const A *x) { if (x) return find(*x); return {}; } template RT find(const std::optional &x) { if (x) return find(*x); return {}; } template RT find(const std::variant &u) { return std::visit([&](const auto &v) { return find(v); }, u); } template RT find(const std::vector &x) { for (auto &v : x) (void)find(v); return {}; } RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; } RT find(const Fortran::evaluate::NullPointer &) { return {}; } template RT find(const Fortran::evaluate::Constant &x) { return {}; } RT find(const Fortran::evaluate::StaticDataObject &) { return {}; } RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; } RT find(const Fortran::evaluate::BaseObject &x) { (void)find(x.u); return {}; } RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; } RT find(const Fortran::evaluate::ComplexPart &x) { return {}; } template RT find(const Fortran::evaluate::Designator &x) { return find(x.u); } template RT find(const Fortran::evaluate::Variable &x) { return find(x.u); } RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; } RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; } RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; } RT find(const Fortran::evaluate::ProcedureRef &x) { (void)find(x.proc()); if (x.IsElemental()) (void)find(x.arguments()); return {}; } RT find(const Fortran::evaluate::ActualArgument &x) { if (const auto *sym = x.GetAssumedTypeDummy()) (void)find(*sym); else (void)find(x.UnwrapExpr()); return {}; } template RT find(const Fortran::evaluate::FunctionRef &x) { (void)find(static_cast(x)); return {}; } template RT find(const Fortran::evaluate::ArrayConstructorValue &) { return {}; } template RT find(const Fortran::evaluate::ArrayConstructorValues &) { return {}; } template RT find(const Fortran::evaluate::ImpliedDo &) { return {}; } RT find(const Fortran::semantics::ParamValue &) { return {}; } RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; } RT find(const Fortran::evaluate::StructureConstructor &) { return {}; } template RT find(const Fortran::evaluate::Operation &op) { (void)find(op.left()); return false; } template RT find(const Fortran::evaluate::Operation &op) { (void)find(op.left()); (void)find(op.right()); return false; } RT find(const Fortran::evaluate::Relational &x) { (void)find(x.u); return {}; } template RT find(const Fortran::evaluate::Expr &x) { (void)find(x.u); return {}; } llvm::SmallVector bases; llvm::SmallVector controlVars; }; } // namespace void Fortran::lower::ExplicitIterSpace::leave() { ccLoopNest.pop_back(); --forallContextOpen; conditionalCleanup(); } void Fortran::lower::ExplicitIterSpace::addSymbol( Fortran::lower::FrontEndSymbol sym) { assert(!symbolStack.empty()); symbolStack.back().push_back(sym); } void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x, bool lhs) { ArrayBaseFinder finder(collectAllSymbols()); finder(*x); llvm::ArrayRef bases = finder.getBases(); if (rhsBases.empty()) endAssign(); if (lhs) { if (bases.empty()) { lhsBases.push_back(std::nullopt); return; } assert(bases.size() >= 1 && "must detect an array reference on lhs"); if (bases.size() > 1) rhsBases.back().append(bases.begin(), bases.end() - 1); lhsBases.push_back(bases.back()); return; } rhsBases.back().append(bases.begin(), bases.end()); } void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); } void Fortran::lower::ExplicitIterSpace::pushLevel() { symbolStack.push_back(llvm::SmallVector{}); } void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); } void Fortran::lower::ExplicitIterSpace::conditionalCleanup() { if (forallContextOpen == 0) { // Exiting the outermost FORALL context. // Cleanup any residual mask buffers. outermostContext().finalize(); // Clear and reset all the cached information. symbolStack.clear(); lhsBases.clear(); rhsBases.clear(); loadBindings.clear(); ccLoopNest.clear(); innerArgs.clear(); outerLoop = std::nullopt; clearLoops(); counter = 0; } } std::optional Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) { if (lhsBases[counter]) { auto ld = loadBindings.find(*lhsBases[counter]); std::optional optPos; if (ld != loadBindings.end() && ld->second == load) optPos = static_cast(0u); assert(optPos.has_value() && "load does not correspond to lhs"); return optPos; } return std::nullopt; } llvm::SmallVector Fortran::lower::ExplicitIterSpace::collectAllSymbols() { llvm::SmallVector result; for (llvm::SmallVector vec : symbolStack) result.append(vec.begin(), vec.end()); return result; } llvm::raw_ostream & Fortran::lower::operator<<(llvm::raw_ostream &s, const Fortran::lower::ImplicitIterSpace &e) { for (const llvm::SmallVector< Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs : e.getMasks()) { s << "{ "; for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs) x->AsFortran(s << '(') << "), "; s << "}\n"; } return s; } llvm::raw_ostream & Fortran::lower::operator<<(llvm::raw_ostream &s, const Fortran::lower::ExplicitIterSpace &e) { auto dump = [&](const auto &u) { std::visit(Fortran::common::visitors{ [&](const Fortran::semantics::Symbol *y) { s << " " << *y << '\n'; }, [&](const Fortran::evaluate::ArrayRef *y) { s << " "; if (y->base().IsSymbol()) s << y->base().GetFirstSymbol(); else s << y->base().GetComponent().GetLastSymbol(); s << '\n'; }, [&](const Fortran::evaluate::Component *y) { s << " " << y->GetLastSymbol() << '\n'; }}, u); }; s << "LHS bases:\n"; for (const std::optional &u : e.lhsBases) if (u) dump(*u); s << "RHS bases:\n"; for (const llvm::SmallVector &bases : e.rhsBases) { for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases) dump(u); s << '\n'; } return s; } void Fortran::lower::ImplicitIterSpace::dump() const { llvm::errs() << *this << '\n'; } void Fortran::lower::ExplicitIterSpace::dump() const { llvm::errs() << *this << '\n'; }