[flang][cuda] Fix lowering of cuf kernel with unstructured nested construct (#107149)
Lowering was crashing when a cuf kernel has an unstructured construct. Blocks created by the PFT need to be re-created inside of the operation, as is done for OpenACC constructs.
This commit is contained in:
committed by
GitHub
parent
fe454b2044
commit
c81b43074a
@@ -11,6 +11,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Lower/Bridge.h"
|
||||
#include "DirectivesCommon.h"
|
||||
#include "flang/Common/Version.h"
|
||||
#include "flang/Lower/Allocatable.h"
|
||||
#include "flang/Lower/CallInterface.h"
|
||||
@@ -2999,6 +3000,12 @@ private:
|
||||
mlir::Block &b = op.getRegion().back();
|
||||
builder->setInsertionPointToStart(&b);
|
||||
|
||||
Fortran::lower::pft::Evaluation *crtEval = &getEval();
|
||||
if (crtEval->lowerAsUnstructured())
|
||||
Fortran::lower::createEmptyRegionBlocks<fir::FirEndOp>(
|
||||
*builder, crtEval->getNestedEvaluations());
|
||||
builder->setInsertionPointToStart(&b);
|
||||
|
||||
for (auto [arg, value] : llvm::zip(
|
||||
op.getLoopRegions().front()->front().getArguments(), ivValues)) {
|
||||
mlir::Value convArg =
|
||||
@@ -3006,7 +3013,6 @@ private:
|
||||
builder->create<fir::StoreOp>(loc, convArg, value);
|
||||
}
|
||||
|
||||
Fortran::lower::pft::Evaluation *crtEval = &getEval();
|
||||
if (crtEval->lowerAsStructured()) {
|
||||
crtEval = &crtEval->getFirstNestedEvaluation();
|
||||
for (int64_t i = 1; i < nestedLoops; i++)
|
||||
|
||||
@@ -78,3 +78,23 @@ end
|
||||
! CHECK: %[[STREAM_LOAD:.*]] = fir.load %[[STREAM]]#0 : !fir.ref<i64>
|
||||
! CHECK: %[[STREAM_I32:.*]] = fir.convert %[[STREAM_LOAD]] : (i64) -> i32
|
||||
! CHECK: cuf.kernel<<<*, *, stream = %[[STREAM_I32]]>>>
|
||||
|
||||
|
||||
! Test lowering with unstructured construct inside.
|
||||
subroutine sub2(m,a,b)
|
||||
integer :: m
|
||||
real, device :: a(m,m), b(m)
|
||||
integer :: i,j
|
||||
!$cuf kernel do<<<*,*>>>
|
||||
|
||||
do j = 1, m
|
||||
i = 1
|
||||
do while (a(i,j).eq.0)
|
||||
i = i + 1
|
||||
end do
|
||||
b(j) = i
|
||||
end do
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPsub2
|
||||
! CHECK: cuf.kernel
|
||||
|
||||
Reference in New Issue
Block a user