Introduces a new conversion pass that rewrites `omp.loop` ops to their semantically equivalent op nests bases on the surrounding/binding context of the `loop` op. Not all forms of `omp.loop` are supported yet. See `isLoopConversionSupported` for more info on which forms are supported.
182 lines
6.2 KiB
C++
182 lines
6.2 KiB
C++
//===- GenericLoopConversion.cpp ------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Common/OpenMP-utils.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include <memory>
|
|
|
|
namespace flangomp {
|
|
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
|
|
#include "flang/Optimizer/OpenMP/Passes.h.inc"
|
|
} // namespace flangomp
|
|
|
|
namespace {
|
|
|
|
/// A conversion pattern to handle various combined forms of `omp.loop`. For how
|
|
/// combined/composite directive are handled see:
|
|
/// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
|
|
class GenericLoopConversionPattern
|
|
: public mlir::OpConversionPattern<mlir::omp::LoopOp> {
|
|
public:
|
|
enum class GenericLoopCombinedInfo {
|
|
None,
|
|
TargetTeamsLoop,
|
|
TargetParallelLoop
|
|
};
|
|
|
|
using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp)));
|
|
|
|
rewriteToDistributeParallelDo(loopOp, rewriter);
|
|
rewriter.eraseOp(loopOp);
|
|
return mlir::success();
|
|
}
|
|
|
|
static mlir::LogicalResult
|
|
checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) {
|
|
GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
|
|
|
|
switch (combinedInfo) {
|
|
case GenericLoopCombinedInfo::None:
|
|
return loopOp.emitError(
|
|
"not yet implemented: Standalone `omp loop` directive");
|
|
case GenericLoopCombinedInfo::TargetParallelLoop:
|
|
return loopOp.emitError(
|
|
"not yet implemented: Combined `omp target parallel loop` directive");
|
|
case GenericLoopCombinedInfo::TargetTeamsLoop:
|
|
break;
|
|
}
|
|
|
|
auto todo = [&loopOp](mlir::StringRef clauseName) {
|
|
return loopOp.emitError()
|
|
<< "not yet implemented: Unhandled clause " << clauseName << " in "
|
|
<< loopOp->getName() << " operation";
|
|
};
|
|
|
|
if (loopOp.getBindKind())
|
|
return todo("bind");
|
|
|
|
if (loopOp.getOrder())
|
|
return todo("order");
|
|
|
|
if (!loopOp.getReductionVars().empty())
|
|
return todo("reduction");
|
|
|
|
// TODO For `target teams loop`, check similar constrains to what is checked
|
|
// by `TeamsLoopChecker` in SemaOpenMP.cpp.
|
|
return mlir::success();
|
|
}
|
|
|
|
private:
|
|
static GenericLoopCombinedInfo
|
|
findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
|
|
mlir::Operation *parentOp = loopOp->getParentOp();
|
|
GenericLoopCombinedInfo result = GenericLoopCombinedInfo::None;
|
|
|
|
if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
|
|
if (mlir::isa_and_present<mlir::omp::TargetOp>(teamsOp->getParentOp()))
|
|
result = GenericLoopCombinedInfo::TargetTeamsLoop;
|
|
|
|
if (auto parallelOp =
|
|
mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
|
|
if (mlir::isa_and_present<mlir::omp::TargetOp>(parallelOp->getParentOp()))
|
|
result = GenericLoopCombinedInfo::TargetParallelLoop;
|
|
|
|
return result;
|
|
}
|
|
|
|
void rewriteToDistributeParallelDo(
|
|
mlir::omp::LoopOp loopOp,
|
|
mlir::ConversionPatternRewriter &rewriter) const {
|
|
mlir::omp::ParallelOperands parallelClauseOps;
|
|
parallelClauseOps.privateVars = loopOp.getPrivateVars();
|
|
|
|
auto privateSyms = loopOp.getPrivateSyms();
|
|
if (privateSyms)
|
|
parallelClauseOps.privateSyms.assign(privateSyms->begin(),
|
|
privateSyms->end());
|
|
|
|
Fortran::common::openmp::EntryBlockArgs parallelArgs;
|
|
parallelArgs.priv.vars = parallelClauseOps.privateVars;
|
|
|
|
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
|
|
parallelClauseOps);
|
|
mlir::Block *parallelBlock =
|
|
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
|
|
parallelOp.setComposite(true);
|
|
rewriter.setInsertionPoint(
|
|
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
|
|
|
|
mlir::omp::DistributeOperands distributeClauseOps;
|
|
auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
|
|
loopOp.getLoc(), distributeClauseOps);
|
|
distributeOp.setComposite(true);
|
|
rewriter.createBlock(&distributeOp.getRegion());
|
|
|
|
mlir::omp::WsloopOperands wsloopClauseOps;
|
|
auto wsloopOp =
|
|
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
|
|
wsloopOp.setComposite(true);
|
|
rewriter.createBlock(&wsloopOp.getRegion());
|
|
|
|
mlir::IRMapping mapper;
|
|
mlir::Block &loopBlock = *loopOp.getRegion().begin();
|
|
|
|
for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
|
|
loopBlock.getArguments(), parallelBlock->getArguments()))
|
|
mapper.map(loopOpArg, parallelOpArg);
|
|
|
|
rewriter.clone(*loopOp.begin(), mapper);
|
|
}
|
|
};
|
|
|
|
class GenericLoopConversionPass
|
|
: public flangomp::impl::GenericLoopConversionPassBase<
|
|
GenericLoopConversionPass> {
|
|
public:
|
|
GenericLoopConversionPass() = default;
|
|
|
|
void runOnOperation() override {
|
|
mlir::func::FuncOp func = getOperation();
|
|
|
|
if (func.isDeclaration())
|
|
return;
|
|
|
|
mlir::MLIRContext *context = &getContext();
|
|
mlir::RewritePatternSet patterns(context);
|
|
patterns.insert<GenericLoopConversionPattern>(context);
|
|
mlir::ConversionTarget target(*context);
|
|
|
|
target.markUnknownOpDynamicallyLegal(
|
|
[](mlir::Operation *) { return true; });
|
|
target.addDynamicallyLegalOp<mlir::omp::LoopOp>(
|
|
[](mlir::omp::LoopOp loopOp) {
|
|
return mlir::failed(
|
|
GenericLoopConversionPattern::checkLoopConversionSupportStatus(
|
|
loopOp));
|
|
});
|
|
|
|
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
|
|
std::move(patterns)))) {
|
|
mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|