[flang][OpenMP] Add reduction clause support to loop directive (#128849)

Extends `loop` directive transformation by adding support for the
`reduction` clause.
This commit is contained in:
Kareem Ergawy
2025-02-28 05:46:03 +01:00
committed by GitHub
parent 55f254726e
commit e0c690990d
3 changed files with 98 additions and 28 deletions

View File

@@ -15,6 +15,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
#include <optional>
#include <type_traits>
namespace flangomp {
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ public:
if (teamsLoopCanBeParallelFor(loopOp))
rewriteToDistributeParallelDo(loopOp, rewriter);
else
rewriteToDistrbute(loopOp, rewriter);
rewriteToDistribute(loopOp, rewriter);
break;
}
@@ -77,9 +79,6 @@ public:
if (loopOp.getOrder())
return todo("order");
if (!loopOp.getReductionVars().empty())
return todo("reduction");
return mlir::success();
}
@@ -168,7 +167,7 @@ private:
case ClauseBindKind::Parallel:
return rewriteToWsloop(loopOp, rewriter);
case ClauseBindKind::Teams:
return rewriteToDistrbute(loopOp, rewriter);
return rewriteToDistribute(loopOp, rewriter);
case ClauseBindKind::Thread:
return rewriteToSimdLoop(loopOp, rewriter);
}
@@ -211,8 +210,9 @@ private:
loopOp, rewriter);
}
void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
mlir::ConversionPatternRewriter &rewriter) const {
void rewriteToDistribute(mlir::omp::LoopOp loopOp,
mlir::ConversionPatternRewriter &rewriter) const {
assert(loopOp.getReductionVars().empty());
rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
mlir::omp::DistributeOperands>(loopOp, rewriter);
}
@@ -246,6 +246,12 @@ private:
Fortran::common::openmp::EntryBlockArgs args;
args.priv.vars = clauseOps.privateVars;
if constexpr (!std::is_same_v<OpOperandsTy,
mlir::omp::DistributeOperands>) {
populateReductionClauseOps(loopOp, clauseOps);
args.reduction.vars = clauseOps.reductionVars;
}
auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
@@ -275,8 +281,7 @@ private:
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
parallelClauseOps);
mlir::Block *parallelBlock =
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
parallelOp.setComposite(true);
rewriter.setInsertionPoint(
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
@@ -288,20 +293,54 @@ private:
rewriter.createBlock(&distributeOp.getRegion());
mlir::omp::WsloopOperands wsloopClauseOps;
populateReductionClauseOps(loopOp, wsloopClauseOps);
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
auto wsloopOp =
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
wsloopOp.setComposite(true);
rewriter.createBlock(&wsloopOp.getRegion());
genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());
mlir::IRMapping mapper;
mlir::Block &loopBlock = *loopOp.getRegion().begin();
for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
loopBlock.getArguments(), parallelBlock->getArguments()))
auto loopBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
auto parallelBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
auto wsloopBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
for (auto [loopOpArg, parallelOpArg] :
llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
parallelBlockInterface.getPrivateBlockArgs()))
mapper.map(loopOpArg, parallelOpArg);
for (auto [loopOpArg, wsloopOpArg] :
llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
wsloopBlockInterface.getReductionBlockArgs()))
mapper.map(loopOpArg, wsloopOpArg);
rewriter.clone(*loopOp.begin(), mapper);
}
void
populateReductionClauseOps(mlir::omp::LoopOp loopOp,
mlir::omp::ReductionClauseOps &clauseOps) const {
clauseOps.reductionMod = loopOp.getReductionModAttr();
clauseOps.reductionVars = loopOp.getReductionVars();
std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
if (reductionSyms)
clauseOps.reductionSyms.assign(reductionSyms->begin(),
reductionSyms->end());
std::optional<llvm::ArrayRef<bool>> reductionByref =
loopOp.getReductionByref();
if (reductionByref)
clauseOps.reductionByref.assign(reductionByref->begin(),
reductionByref->end());
}
};
class GenericLoopConversionPass

View File

@@ -75,7 +75,7 @@ end subroutine
subroutine test_reduction()
integer :: i, dummy = 1
! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
! CHECK-SAME: (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
! CHECK: %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
@@ -294,3 +294,46 @@ subroutine teams_loop_cannot_be_parallel_for_4
!$omp end parallel
END DO
end subroutine
! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
subroutine loop_parallel_bind_reduction
implicit none
integer :: x, i
! CHECK: omp.wsloop
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
! CHECK-NEXT: omp.loop_nest {{.*}} {
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
! CHECK: }
! CHECK: }
!$omp loop bind(parallel) reduction(+: x)
do i = 0, 10
x = x + i
end do
end subroutine
! CHECK-LABEL: func.func @_QPloop_teams_loop_reduction
subroutine loop_teams_loop_reduction
implicit none
integer :: x, i
! CHECK: omp.teams {
! CHECK: omp.parallel
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
! CHECK: omp.distribute {
! CHECK: omp.wsloop
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
! CHECK-NEXT: omp.loop_nest {{.*}} {
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
! CHECK: }
! CHECK: }
! CHECK: }
! CHECK: }
! CHECK: }
!$omp teams loop reduction(+: x)
do i = 0, 10
x = x + i
end do
end subroutine

View File

@@ -1,24 +1,12 @@
// RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s
omp.declare_reduction @add_reduction_i32 : i32 init {
^bb0(%arg0: i32):
%c0_i32 = arith.constant 0 : i32
omp.yield(%c0_i32 : i32)
} combiner {
^bb0(%arg0: i32, %arg1: i32):
%0 = arith.addi %arg0, %arg1 : i32
omp.yield(%0 : i32)
}
func.func @_QPloop_order() {
omp.teams {
%c0 = arith.constant 0 : i32
%c10 = arith.constant 10 : i32
%c1 = arith.constant 1 : i32
%sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}
// expected-error@below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
// expected-error@below {{not yet implemented: Unhandled clause order in omp.loop operation}}
omp.loop order(reproducible:concurrent) {
omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
omp.yield
}