Introduce a conversion pass from SCF parallel loops to OpenMP dialect constructs: a parallel region and a workshare loop. Loops with reductions are not supported because the OpenMP dialect cannot model them yet. The conversion currently targets only one level of parallelism, i.e. only one top-level `omp.parallel` operation is produced, even if there are nested `scf.parallel` operations that could be mapped to `omp.wsloop`. Support for nested parallelism is left for future work.

Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D91982

114 lines · 4.4 KiB · C++
//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to convert scf.parallel operations into OpenMP
|
|
// parallel loops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Converts SCF parallel operation into an OpenMP workshare loop construct.
|
|
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
|
|
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// TODO: add support for reductions when OpenMP loops have them.
|
|
if (parallelOp.getNumResults() != 0)
|
|
return rewriter.notifyMatchFailure(
|
|
parallelOp,
|
|
"OpenMP dialect does not yet support loops with reductions");
|
|
|
|
// Replace SCF yield with OpenMP yield.
|
|
{
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToEnd(parallelOp.getBody());
|
|
assert(llvm::hasSingleElement(parallelOp.region()) &&
|
|
"expected scf.parallel to have one block");
|
|
rewriter.replaceOpWithNewOp<omp::YieldOp>(
|
|
parallelOp.getBody()->getTerminator(), ValueRange());
|
|
}
|
|
|
|
// Replace the loop.
|
|
auto loop = rewriter.create<omp::WsLoopOp>(
|
|
parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
|
|
parallelOp.step());
|
|
rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
|
|
loop.region().begin());
|
|
rewriter.eraseOp(parallelOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
|
|
/// operations in the given function. This is implemented as a direct IR
|
|
/// modification rather than as a conversion pattern because it does not
|
|
/// modify the top-level operation it matches, which is a requirement for
|
|
/// rewrite patterns.
|
|
//
|
|
// TODO: consider creating nested parallel operations when necessary.
|
|
static void insertOpenMPParallel(FuncOp func) {
|
|
// Collect top-level SCF "parallel" ops.
|
|
SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
|
|
func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
|
|
// Ignore ops that are already within OpenMP parallel construct.
|
|
if (!parallelOp.getParentOfType<scf::ParallelOp>())
|
|
topLevelParallelOps.push_back(parallelOp);
|
|
});
|
|
|
|
// Wrap SCF ops into OpenMP "parallel" ops.
|
|
for (scf::ParallelOp parallelOp : topLevelParallelOps) {
|
|
OpBuilder builder(parallelOp);
|
|
auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
|
|
Block *block = builder.createBlock(&omp.getRegion());
|
|
builder.create<omp::TerminatorOp>(parallelOp.getLoc());
|
|
block->getOperations().splice(
|
|
block->begin(), parallelOp.getOperation()->getBlock()->getOperations(),
|
|
parallelOp.getOperation());
|
|
}
|
|
}
|
|
|
|
/// Applies the conversion patterns in the given function.
|
|
static LogicalResult applyPatterns(FuncOp func) {
|
|
ConversionTarget target(*func.getContext());
|
|
target.addIllegalOp<scf::ParallelOp>();
|
|
target.addDynamicallyLegalOp<scf::YieldOp>(
|
|
[](scf::YieldOp op) { return !isa<scf::ParallelOp>(op.getParentOp()); });
|
|
target.addLegalDialect<omp::OpenMPDialect>();
|
|
|
|
OwningRewritePatternList patterns;
|
|
patterns.insert<ParallelOpLowering>(func.getContext());
|
|
FrozenRewritePatternList frozen(std::move(patterns));
|
|
return applyPartialConversion(func, target, frozen);
|
|
}
|
|
|
|
/// A pass converting SCF operations to OpenMP operations.
|
|
struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
|
|
/// Pass entry point.
|
|
void runOnFunction() override {
|
|
insertOpenMPParallel(getFunction());
|
|
if (failed(applyPatterns(getFunction())))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // end namespace
|
|
|
|
/// Creates a new instance of the SCF-to-OpenMP conversion pass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<SCFToOpenMPPass>();
  return pass;
}
|