306 lines
13 KiB
C++
306 lines
13 KiB
C++
//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
|
|
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
|
|
|
|
#include "Clauses.h"
|
|
#include "DirectivesCommon.h"
|
|
#include "ReductionProcessor.h"
|
|
#include "Utils.h"
|
|
#include "flang/Lower/AbstractConverter.h"
|
|
#include "flang/Lower/Bridge.h"
|
|
#include "flang/Optimizer/Builder/Todo.h"
|
|
#include "flang/Parser/dump-parse-tree.h"
|
|
#include "flang/Parser/parse-tree.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
|
|
namespace fir {
|
|
class FirOpBuilder;
|
|
} // namespace fir
|
|
|
|
namespace Fortran {
|
|
namespace lower {
|
|
namespace omp {
|
|
|
|
/// Class that handles the processing of OpenMP clauses.
|
|
///
|
|
/// Its `process<ClauseName>()` methods perform MLIR code generation for their
|
|
/// corresponding clause if it is present in the clause list. Otherwise, they
|
|
/// will return `false` to signal that the clause was not found.
|
|
///
|
|
/// The intended use is of this class is to move clause processing outside of
|
|
/// construct processing, since the same clauses can appear attached to
|
|
/// different constructs and constructs can be combined, so that code
|
|
/// duplication is minimized.
|
|
///
|
|
/// Each construct-lowering function only calls the `process<ClauseName>()`
|
|
/// methods that relate to clauses that can impact the lowering of that
|
|
/// construct.
|
|
class ClauseProcessor {
|
|
public:
|
|
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
|
|
Fortran::semantics::SemanticsContext &semaCtx,
|
|
const Fortran::parser::OmpClauseList &clauses)
|
|
: converter(converter), semaCtx(semaCtx),
|
|
clauses(makeClauses(clauses, semaCtx)) {}
|
|
|
|
// 'Unique' clauses: They can appear at most once in the clause list.
|
|
bool processCollapse(
|
|
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
|
|
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
|
|
llvm::SmallVectorImpl<mlir::Value> &upperBound,
|
|
llvm::SmallVectorImpl<mlir::Value> &step,
|
|
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
|
|
bool processDefault() const;
|
|
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
|
|
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processHint(mlir::IntegerAttr &result) const;
|
|
bool processMergeable(mlir::UnitAttr &result) const;
|
|
bool processNowait(mlir::UnitAttr &result) const;
|
|
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processOrdered(mlir::IntegerAttr &result) const;
|
|
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
|
|
bool processSafelen(mlir::IntegerAttr &result) const;
|
|
bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
|
|
mlir::omp::ScheduleModifierAttr &modifierAttr,
|
|
mlir::UnitAttr &simdModifierAttr) const;
|
|
bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processSimdlen(mlir::IntegerAttr &result) const;
|
|
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
|
|
mlir::Value &result) const;
|
|
bool processUntied(mlir::UnitAttr &result) const;
|
|
|
|
// 'Repeatable' clauses: They can appear multiple times in the clause list.
|
|
bool
|
|
processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
|
|
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
|
|
bool processCopyin() const;
|
|
bool processCopyPrivate(
|
|
mlir::Location currentLocation,
|
|
llvm::SmallVectorImpl<mlir::Value> ©PrivateVars,
|
|
llvm::SmallVectorImpl<mlir::Attribute> ©PrivateFuncs) const;
|
|
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
|
|
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
|
|
bool
|
|
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
|
|
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
|
|
mlir::Value &result) const;
|
|
bool
|
|
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
|
|
|
|
// This method is used to process a map clause.
|
|
// The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
|
|
// store the original type, location and Fortran symbol for the map operands.
|
|
// They may be used later on to create the block_arguments for some of the
|
|
// target directives that require it.
|
|
bool processMap(mlir::Location currentLocation,
|
|
const llvm::omp::Directive &directive,
|
|
Fortran::lower::StatementContext &stmtCtx,
|
|
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
|
|
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
|
|
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
|
|
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
|
*mapSymbols = nullptr) const;
|
|
bool
|
|
processReduction(mlir::Location currentLocation,
|
|
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
|
|
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
|
|
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
|
|
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
|
*reductionSymbols = nullptr) const;
|
|
bool processSectionsReduction(mlir::Location currentLocation) const;
|
|
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
|
|
bool
|
|
processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
|
|
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
|
|
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
|
|
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
|
&useDeviceSymbols) const;
|
|
bool
|
|
processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
|
|
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
|
|
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
|
|
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
|
|
&useDeviceSymbols) const;
|
|
|
|
template <typename T>
|
|
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
|
|
llvm::SmallVectorImpl<mlir::Value> &mapOperands);
|
|
|
|
// Call this method for these clauses that should be supported but are not
|
|
// implemented yet. It triggers a compilation error if any of the given
|
|
// clauses is found.
|
|
template <typename... Ts>
|
|
void processTODO(mlir::Location currentLocation,
|
|
llvm::omp::Directive directive) const;
|
|
|
|
private:
|
|
using ClauseIterator = List<Clause>::const_iterator;
|
|
|
|
/// Utility to find a clause within a range in the clause list.
|
|
template <typename T>
|
|
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
|
|
|
|
/// Return the first instance of the given clause found in the clause list or
|
|
/// `nullptr` if not present. If more than one instance is expected, use
|
|
/// `findRepeatableClause` instead.
|
|
template <typename T>
|
|
const T *
|
|
findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const;
|
|
|
|
/// Call `callbackFn` for each occurrence of the given clause. Return `true`
|
|
/// if at least one instance was found.
|
|
template <typename T>
|
|
bool findRepeatableClause(
|
|
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
|
|
callbackFn) const;
|
|
|
|
/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
|
|
template <typename T>
|
|
bool markClauseOccurrence(mlir::UnitAttr &result) const;
|
|
|
|
Fortran::lower::AbstractConverter &converter;
|
|
Fortran::semantics::SemanticsContext &semaCtx;
|
|
List<Clause> clauses;
|
|
};
|
|
|
|
template <typename T>
|
|
bool ClauseProcessor::processMotionClauses(
|
|
Fortran::lower::StatementContext &stmtCtx,
|
|
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
|
|
return findRepeatableClause<T>(
|
|
[&](const T &clause, const Fortran::parser::CharBlock &source) {
|
|
mlir::Location clauseLocation = converter.genLocation(source);
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
|
|
static_assert(std::is_same_v<T, omp::clause::To> ||
|
|
std::is_same_v<T, omp::clause::From>);
|
|
|
|
// TODO Support motion modifiers: present, mapper, iterator.
|
|
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
|
|
std::is_same_v<T, omp::clause::To>
|
|
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
|
|
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
|
|
|
|
auto &objects = std::get<ObjectList>(clause.t);
|
|
for (const omp::Object &object : objects) {
|
|
llvm::SmallVector<mlir::Value> bounds;
|
|
std::stringstream asFortran;
|
|
Fortran::lower::AddrAndBoundsInfo info =
|
|
Fortran::lower::gatherDataOperandAddrAndBounds<
|
|
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
|
|
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
|
|
object.ref(), clauseLocation, asFortran, bounds,
|
|
treatIndexAsSection);
|
|
|
|
auto origSymbol = converter.getSymbolAddress(*object.id());
|
|
mlir::Value symAddr = info.addr;
|
|
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
|
|
symAddr = origSymbol;
|
|
|
|
// Explicit map captures are captured ByRef by default,
|
|
// optimisation passes may alter this to ByCopy or other capture
|
|
// types to optimise
|
|
mlir::Value mapOp = createMapInfoOp(
|
|
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
|
|
asFortran.str(), bounds, {},
|
|
static_cast<
|
|
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
|
|
mapTypeBits),
|
|
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
|
|
|
|
mapOperands.push_back(mapOp);
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename... Ts>
|
|
void ClauseProcessor::processTODO(mlir::Location currentLocation,
|
|
llvm::omp::Directive directive) const {
|
|
auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
|
|
if (!x)
|
|
return;
|
|
TODO(currentLocation,
|
|
"Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
|
|
" in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
|
|
" construct");
|
|
};
|
|
|
|
for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
|
|
(checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
|
|
}
|
|
|
|
template <typename T>
|
|
ClauseProcessor::ClauseIterator
|
|
ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
|
|
for (ClauseIterator it = begin; it != end; ++it) {
|
|
if (std::get_if<T>(&it->u))
|
|
return it;
|
|
}
|
|
|
|
return end;
|
|
}
|
|
|
|
template <typename T>
|
|
const T *ClauseProcessor::findUniqueClause(
|
|
const Fortran::parser::CharBlock **source) const {
|
|
ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
|
|
if (it != clauses.end()) {
|
|
if (source)
|
|
*source = &it->source;
|
|
return &std::get<T>(it->u);
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
template <typename T>
|
|
bool ClauseProcessor::findRepeatableClause(
|
|
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
|
|
callbackFn) const {
|
|
bool found = false;
|
|
ClauseIterator nextIt, endIt = clauses.end();
|
|
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
|
|
nextIt = findClause<T>(it, endIt);
|
|
|
|
if (nextIt != endIt) {
|
|
callbackFn(std::get<T>(nextIt->u), nextIt->source);
|
|
found = true;
|
|
++nextIt;
|
|
}
|
|
}
|
|
return found;
|
|
}
|
|
|
|
template <typename T>
|
|
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
|
|
if (findUniqueClause<T>()) {
|
|
result = converter.getFirOpBuilder().getUnitAttr();
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
} // namespace omp
|
|
} // namespace lower
|
|
} // namespace Fortran
|
|
|
|
#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H
|