Files
clang-p2996/flang/lib/Optimizer/Transforms/OMPMarkDeclareTarget.cpp
Andrew Gozillon 062fce6f4d [Flang][OpenMP][MLIR] An mlir transformation pass for marking FuncOp's implicitly called from TargetOp's and declare target marked FuncOp's as implicitly declare target
This pass will mark functions called from TargetOp's
and declare target functions as implicitly declare
target by adding the MLIR declare target attribute
directly to the function.

This pass executes after the initial lowering of Fortran's PFT
to MLIR (FIR/OMP+Arith etc.) and is one of a series of passes
that aim to clean up the MLIR for offloading (separate passes
in different patches, one for early outlining, another for declare
target function filtering).

Reviewers: jsjodin, skatrak, kiaranchandramohan

Differential Revision: https://reviews.llvm.org/D154247
2023-07-17 08:32:26 -05:00

98 lines
3.7 KiB
C++

#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace fir {
// Defining GEN_PASS_DEF_OMPMARKDECLARETARGETPASS before including
// Passes.h.inc selects only this pass's definition section of that file.
// NOTE(review): that section presumably provides
// fir::impl::OMPMarkDeclareTargetPassBase, which the pass class in this file
// derives from -- confirm against the generated Passes.h.inc.
#define GEN_PASS_DEF_OMPMARKDECLARETARGETPASS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
namespace {
/// Module pass that propagates OpenMP declare target markings along call
/// edges: every func::FuncOp called (directly or transitively) from a declare
/// target function or from inside an omp::TargetOp gets the declare target
/// attribute set through mlir::omp::DeclareTargetInterface.
class OMPMarkDeclareTargetPass
    : public fir::impl::OMPMarkDeclareTargetPassBase<OMPMarkDeclareTargetPass> {

  /// Walks \p currOp and, for each nested call whose callee is a symbol that
  /// resolves to a func::FuncOp in the module, marks that callee as declare
  /// target and then recurses into the callee's body.
  ///
  /// \param parentDevTy     device type propagated from the traversal root.
  /// \param parentCapClause capture clause propagated from the traversal root.
  /// \param currOp          operation whose nested calls are inspected.
  /// \param visited         operations already walked in this traversal. It is
  ///                        taken by reference so that one set is shared by
  ///                        the whole traversal: this avoids copying the set
  ///                        on every recursive call and avoids re-walking a
  ///                        function that is reachable through several call
  ///                        paths. Call sites are still processed every time
  ///                        they are encountered; only the recursion into an
  ///                        already-walked body is skipped.
  void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
                       mlir::omp::DeclareTargetCaptureClause parentCapClause,
                       mlir::Operation *currOp,
                       llvm::SmallPtrSetImpl<mlir::Operation *> &visited) {
    // insert() reports whether the element was newly added; if it was already
    // present this operation has been walked (this also breaks call cycles).
    if (!visited.insert(currOp).second)
      return;

    currOp->walk([&](mlir::Operation *op) {
      auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op);
      if (!callOp)
        return;

      // Only calls whose callee is a symbol reference are followed.
      auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
          callOp.getCallableForCallee());
      if (!symRef)
        return;

      auto calleeFunc =
          getOperation().lookupSymbol<mlir::func::FuncOp>(symRef);
      if (!calleeFunc)
        return;

      // dyn_cast yields a null interface when the callee does not provide
      // DeclareTargetInterface; skip it instead of dereferencing null.
      auto calleeDeclTarget =
          llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
              calleeFunc.getOperation());
      if (!calleeDeclTarget)
        return;

      if (!calleeDeclTarget.isDeclareTarget()) {
        calleeDeclTarget.setDeclareTarget(parentDevTy, parentCapClause);
      } else {
        auto calleeDevTy = calleeDeclTarget.getDeclareTargetDeviceType();
        // Found the same function twice, with different device_types,
        // mark as Any as it belongs to both
        if (calleeDevTy != parentDevTy &&
            calleeDevTy != mlir::omp::DeclareTargetDeviceType::any) {
          calleeDeclTarget.setDeclareTarget(
              mlir::omp::DeclareTargetDeviceType::any,
              calleeDeclTarget.getDeclareTargetCaptureClause());
        }
      }

      markNestedFuncs(parentDevTy, parentCapClause, calleeFunc, visited);
    });
  }

  // This pass executes on mlir::ModuleOp's marking functions contained within
  // as implicitly declare target if they are called from within an explicitly
  // marked declare target function or a target region (TargetOp)
  void runOnOperation() override {
    // Step 1: start a traversal from every function in the module that is
    // already marked declare target, propagating its own device type and
    // capture clause to its callees.
    for (auto functionOp : getOperation().getOps<mlir::func::FuncOp>()) {
      auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
          functionOp.getOperation());
      // A null interface means the function cannot carry the declare target
      // attribute; there is nothing to propagate from it.
      if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
        continue;

      llvm::SmallPtrSet<mlir::Operation *, 16> visited;
      markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
                      declareTargetOp.getDeclareTargetCaptureClause(),
                      functionOp, visited);
    }

    // Step 2: start a traversal from every target region; functions called
    // from it are marked with device type nohost and capture clause to.
    //
    // TODO: Extend to work with reverse-offloading, this shouldn't
    // require too much effort, just need to check the device clause
    // when its lowering has been implemented and change the
    // DeclareTargetDeviceType argument from nohost to host depending on
    // the contents of the device clause
    getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
      llvm::SmallPtrSet<mlir::Operation *, 16> visited;
      markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost,
                      mlir::omp::DeclareTargetCaptureClause::to, tarOp,
                      visited);
    });
  }
};
} // namespace
namespace fir {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createOMPMarkDeclareTargetPass() {
return std::make_unique<OMPMarkDeclareTargetPass>();
}
} // namespace fir