[flang][OpenMP] Add reduction clause support to loop directive (#128849)
Extends `loop` directive transformation by adding support for the `reduction` clause.
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace flangomp {
|
||||
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
|
||||
@@ -58,7 +60,7 @@ public:
|
||||
if (teamsLoopCanBeParallelFor(loopOp))
|
||||
rewriteToDistributeParallelDo(loopOp, rewriter);
|
||||
else
|
||||
rewriteToDistrbute(loopOp, rewriter);
|
||||
rewriteToDistribute(loopOp, rewriter);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -77,9 +79,6 @@ public:
|
||||
if (loopOp.getOrder())
|
||||
return todo("order");
|
||||
|
||||
if (!loopOp.getReductionVars().empty())
|
||||
return todo("reduction");
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -168,7 +167,7 @@ private:
|
||||
case ClauseBindKind::Parallel:
|
||||
return rewriteToWsloop(loopOp, rewriter);
|
||||
case ClauseBindKind::Teams:
|
||||
return rewriteToDistrbute(loopOp, rewriter);
|
||||
return rewriteToDistribute(loopOp, rewriter);
|
||||
case ClauseBindKind::Thread:
|
||||
return rewriteToSimdLoop(loopOp, rewriter);
|
||||
}
|
||||
@@ -211,8 +210,9 @@ private:
|
||||
loopOp, rewriter);
|
||||
}
|
||||
|
||||
void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
void rewriteToDistribute(mlir::omp::LoopOp loopOp,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
assert(loopOp.getReductionVars().empty());
|
||||
rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
|
||||
mlir::omp::DistributeOperands>(loopOp, rewriter);
|
||||
}
|
||||
@@ -246,6 +246,12 @@ private:
|
||||
Fortran::common::openmp::EntryBlockArgs args;
|
||||
args.priv.vars = clauseOps.privateVars;
|
||||
|
||||
if constexpr (!std::is_same_v<OpOperandsTy,
|
||||
mlir::omp::DistributeOperands>) {
|
||||
populateReductionClauseOps(loopOp, clauseOps);
|
||||
args.reduction.vars = clauseOps.reductionVars;
|
||||
}
|
||||
|
||||
auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
|
||||
mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
|
||||
|
||||
@@ -275,8 +281,7 @@ private:
|
||||
|
||||
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
|
||||
parallelClauseOps);
|
||||
mlir::Block *parallelBlock =
|
||||
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
|
||||
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
|
||||
parallelOp.setComposite(true);
|
||||
rewriter.setInsertionPoint(
|
||||
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
|
||||
@@ -288,20 +293,54 @@ private:
|
||||
rewriter.createBlock(&distributeOp.getRegion());
|
||||
|
||||
mlir::omp::WsloopOperands wsloopClauseOps;
|
||||
populateReductionClauseOps(loopOp, wsloopClauseOps);
|
||||
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
|
||||
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
|
||||
|
||||
auto wsloopOp =
|
||||
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
|
||||
wsloopOp.setComposite(true);
|
||||
rewriter.createBlock(&wsloopOp.getRegion());
|
||||
genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());
|
||||
|
||||
mlir::IRMapping mapper;
|
||||
mlir::Block &loopBlock = *loopOp.getRegion().begin();
|
||||
|
||||
for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
|
||||
loopBlock.getArguments(), parallelBlock->getArguments()))
|
||||
auto loopBlockInterface =
|
||||
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
|
||||
auto parallelBlockInterface =
|
||||
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
|
||||
auto wsloopBlockInterface =
|
||||
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
|
||||
|
||||
for (auto [loopOpArg, parallelOpArg] :
|
||||
llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
|
||||
parallelBlockInterface.getPrivateBlockArgs()))
|
||||
mapper.map(loopOpArg, parallelOpArg);
|
||||
|
||||
for (auto [loopOpArg, wsloopOpArg] :
|
||||
llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
|
||||
wsloopBlockInterface.getReductionBlockArgs()))
|
||||
mapper.map(loopOpArg, wsloopOpArg);
|
||||
|
||||
rewriter.clone(*loopOp.begin(), mapper);
|
||||
}
|
||||
|
||||
void
|
||||
populateReductionClauseOps(mlir::omp::LoopOp loopOp,
|
||||
mlir::omp::ReductionClauseOps &clauseOps) const {
|
||||
clauseOps.reductionMod = loopOp.getReductionModAttr();
|
||||
clauseOps.reductionVars = loopOp.getReductionVars();
|
||||
|
||||
std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
|
||||
if (reductionSyms)
|
||||
clauseOps.reductionSyms.assign(reductionSyms->begin(),
|
||||
reductionSyms->end());
|
||||
|
||||
std::optional<llvm::ArrayRef<bool>> reductionByref =
|
||||
loopOp.getReductionByref();
|
||||
if (reductionByref)
|
||||
clauseOps.reductionByref.assign(reductionByref->begin(),
|
||||
reductionByref->end());
|
||||
}
|
||||
};
|
||||
|
||||
class GenericLoopConversionPass
|
||||
|
||||
@@ -75,7 +75,7 @@ end subroutine
|
||||
subroutine test_reduction()
|
||||
integer :: i, dummy = 1
|
||||
|
||||
! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
|
||||
! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
|
||||
! CHECK-SAME: (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
|
||||
! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
|
||||
! CHECK: %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
|
||||
@@ -294,3 +294,46 @@ subroutine teams_loop_cannot_be_parallel_for_4
|
||||
!$omp end parallel
|
||||
END DO
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
|
||||
subroutine loop_parallel_bind_reduction
|
||||
implicit none
|
||||
integer :: x, i
|
||||
|
||||
! CHECK: omp.wsloop
|
||||
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
|
||||
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
|
||||
! CHECK-NEXT: omp.loop_nest {{.*}} {
|
||||
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
|
||||
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
!$omp loop bind(parallel) reduction(+: x)
|
||||
do i = 0, 10
|
||||
x = x + i
|
||||
end do
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPloop_teams_loop_reduction
|
||||
subroutine loop_teams_loop_reduction
|
||||
implicit none
|
||||
integer :: x, i
|
||||
! CHECK: omp.teams {
|
||||
! CHECK: omp.parallel
|
||||
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
|
||||
! CHECK: omp.distribute {
|
||||
! CHECK: omp.wsloop
|
||||
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
|
||||
! CHECK-NEXT: omp.loop_nest {{.*}} {
|
||||
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
|
||||
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
! CHECK: }
|
||||
!$omp teams loop reduction(+: x)
|
||||
do i = 0, 10
|
||||
x = x + i
|
||||
end do
|
||||
end subroutine
|
||||
|
||||
@@ -1,24 +1,12 @@
|
||||
// RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s
|
||||
|
||||
omp.declare_reduction @add_reduction_i32 : i32 init {
|
||||
^bb0(%arg0: i32):
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
omp.yield(%c0_i32 : i32)
|
||||
} combiner {
|
||||
^bb0(%arg0: i32, %arg1: i32):
|
||||
%0 = arith.addi %arg0, %arg1 : i32
|
||||
omp.yield(%0 : i32)
|
||||
}
|
||||
|
||||
func.func @_QPloop_order() {
|
||||
omp.teams {
|
||||
%c0 = arith.constant 0 : i32
|
||||
%c10 = arith.constant 10 : i32
|
||||
%c1 = arith.constant 1 : i32
|
||||
%sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}
|
||||
|
||||
// expected-error@below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
|
||||
omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
|
||||
// expected-error@below {{not yet implemented: Unhandled clause order in omp.loop operation}}
|
||||
omp.loop order(reproducible:concurrent) {
|
||||
omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
|
||||
omp.yield
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user