[mlir][OpenMP] convert wsloop cancellation to LLVMIR (#137194)
Taskloop support will follow in a later patch.
This commit is contained in:
@@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
|
||||
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
|
||||
omp::ClauseCancellationConstructType cancelledDirective =
|
||||
op.getCancelDirective();
|
||||
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
|
||||
cancelledDirective != omp::ClauseCancellationConstructType::Sections)
|
||||
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
|
||||
result = todo("cancel directive construct type not yet supported");
|
||||
};
|
||||
auto checkDepend = [&todo](auto op, LogicalResult &result) {
|
||||
@@ -2345,6 +2344,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
|
||||
: llvm::omp::WorksharingLoopType::ForStaticLoop;
|
||||
|
||||
SmallVector<llvm::BranchInst *> cancelTerminators;
|
||||
// This callback is invoked only if there is cancellation inside of the wsloop
|
||||
// body.
|
||||
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
|
||||
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
|
||||
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
|
||||
|
||||
// ip is currently in the block branched to if cancellation occured.
|
||||
// We need to create a branch to terminate that block.
|
||||
llvmBuilder.restoreIP(ip);
|
||||
|
||||
// We must still clean up the wsloop after cancelling it, so we need to
|
||||
// branch to the block that finalizes the wsloop.
|
||||
// That block has not been created yet so use this block as a dummy for now
|
||||
// and fix this after creating the wsloop.
|
||||
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
|
||||
return llvm::Error::success();
|
||||
};
|
||||
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
|
||||
// created in case the body contains omp.cancel (which will then expect to be
|
||||
// able to find this cleanup callback).
|
||||
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
|
||||
constructIsCancellable(wsloopOp)});
|
||||
|
||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
|
||||
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
|
||||
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
|
||||
@@ -2366,6 +2389,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
if (failed(handleError(wsloopIP, opInst)))
|
||||
return failure();
|
||||
|
||||
ompBuilder->popFinalizationCB();
|
||||
if (!cancelTerminators.empty()) {
|
||||
// If we cancelled the loop, we should branch to the finalization block of
|
||||
// the wsloop (which is always immediately before the loop continuation
|
||||
// block). Now the finalization has been created, we can fix the branch.
|
||||
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
|
||||
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
|
||||
assert(cancelBranch->getNumSuccessors() == 1 &&
|
||||
"cancel branch should have one target");
|
||||
cancelBranch->setSuccessor(0, wsloopFini);
|
||||
}
|
||||
}
|
||||
|
||||
// Process the reductions if required.
|
||||
if (failed(createReductionsAndCleanup(
|
||||
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||
|
||||
llvm.func @cancel_distribute_parallel_do(%lb : i32, %ub : i32, %step : i32) {
|
||||
omp.teams {
|
||||
omp.parallel {
|
||||
omp.distribute {
|
||||
omp.wsloop {
|
||||
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
|
||||
omp.cancel cancellation_construct_type(loop)
|
||||
omp.yield
|
||||
}
|
||||
} {omp.composite}
|
||||
} {omp.composite}
|
||||
omp.terminator
|
||||
} {omp.composite}
|
||||
omp.terminator
|
||||
}
|
||||
llvm.return
|
||||
}
|
||||
// CHECK-LABEL: define internal void @cancel_distribute_parallel_do..omp_par
|
||||
// [...]
|
||||
// CHECK: omp_loop.cond:
|
||||
// CHECK: %[[VAL_102:.*]] = icmp ult i32 %{{.*}}, %{{.*}}
|
||||
// CHECK: br i1 %[[VAL_102]], label %omp_loop.body, label %omp_loop.exit
|
||||
// CHECK: omp_loop.exit:
|
||||
// CHECK: call void @__kmpc_for_static_fini(
|
||||
// CHECK: %[[VAL_106:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
|
||||
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_106]])
|
||||
// CHECK: br label %omp_loop.after
|
||||
// CHECK: omp_loop.after:
|
||||
// CHECK: br label %omp.region.cont6
|
||||
// CHECK: omp.region.cont6:
|
||||
// CHECK: br label %omp.region.cont4
|
||||
// CHECK: omp.region.cont4:
|
||||
// CHECK: br label %distribute.exit.exitStub
|
||||
// CHECK: omp_loop.body:
|
||||
// CHECK: %[[VAL_111:.*]] = add i32 %{{.*}}, %{{.*}}
|
||||
// CHECK: %[[VAL_112:.*]] = mul i32 %[[VAL_111]], %{{.*}}
|
||||
// CHECK: %[[VAL_113:.*]] = add i32 %[[VAL_112]], %{{.*}}
|
||||
// CHECK: br label %omp.loop_nest.region
|
||||
// CHECK: omp.loop_nest.region:
|
||||
// CHECK: %[[VAL_115:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
|
||||
// CHECK: %[[VAL_116:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_115]], i32 2)
|
||||
// CHECK: %[[VAL_117:.*]] = icmp eq i32 %[[VAL_116]], 0
|
||||
// CHECK: br i1 %[[VAL_117]], label %omp.loop_nest.region.split, label %omp.loop_nest.region.cncl
|
||||
// CHECK: omp.loop_nest.region.cncl:
|
||||
// CHECK: br label %omp_loop.exit
|
||||
// CHECK: omp.loop_nest.region.split:
|
||||
// CHECK: br label %omp.region.cont7
|
||||
// CHECK: omp.region.cont7:
|
||||
// CHECK: br label %omp_loop.inc
|
||||
// CHECK: omp_loop.inc:
|
||||
// CHECK: %[[VAL_100:.*]] = add nuw i32 %{{.*}}, 1
|
||||
// CHECK: br label %omp_loop.header
|
||||
// CHECK: distribute.exit.exitStub:
|
||||
// CHECK: ret void
|
||||
|
||||
@@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
|
||||
// CHECK: ret void
|
||||
// CHECK: .cncl: ; preds = %[[VAL_27]]
|
||||
// CHECK: br label %[[VAL_19]]
|
||||
|
||||
llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
|
||||
omp.wsloop {
|
||||
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
|
||||
omp.cancel cancellation_construct_type(loop) if(%cond)
|
||||
omp.yield
|
||||
}
|
||||
}
|
||||
llvm.return
|
||||
}
|
||||
// CHECK-LABEL: define void @cancel_wsloop_if
|
||||
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_2:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_3:.*]] = alloca i32, align 4
|
||||
// CHECK: br label %[[VAL_4:.*]]
|
||||
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]]
|
||||
// CHECK: br label %[[VAL_6:.*]]
|
||||
// CHECK: entry: ; preds = %[[VAL_4]]
|
||||
// CHECK: br label %[[VAL_7:.*]]
|
||||
// CHECK: omp.wsloop.region: ; preds = %[[VAL_6]]
|
||||
// CHECK: %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
|
||||
// CHECK: %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
|
||||
// CHECK: %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
|
||||
// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
|
||||
// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
|
||||
// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
|
||||
// CHECK: %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
|
||||
// CHECK: %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
|
||||
// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
|
||||
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
|
||||
// CHECK: %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
|
||||
// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
|
||||
// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
|
||||
// CHECK: br label %[[VAL_24:.*]]
|
||||
// CHECK: omp_loop.preheader: ; preds = %[[VAL_7]]
|
||||
// CHECK: store i32 0, ptr %[[VAL_1]], align 4
|
||||
// CHECK: %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
|
||||
// CHECK: store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
|
||||
// CHECK: store i32 1, ptr %[[VAL_3]], align 4
|
||||
// CHECK: %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
|
||||
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
|
||||
// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
|
||||
// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
|
||||
// CHECK: %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
|
||||
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
|
||||
// CHECK: br label %[[VAL_31:.*]]
|
||||
// CHECK: omp_loop.header: ; preds = %[[VAL_32:.*]], %[[VAL_24]]
|
||||
// CHECK: %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
|
||||
// CHECK: br label %[[VAL_35:.*]]
|
||||
// CHECK: omp_loop.cond: ; preds = %[[VAL_31]]
|
||||
// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
|
||||
// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
|
||||
// CHECK: omp_loop.body: ; preds = %[[VAL_35]]
|
||||
// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
|
||||
// CHECK: %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
|
||||
// CHECK: %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
|
||||
// CHECK: br label %[[VAL_42:.*]]
|
||||
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_37]]
|
||||
// CHECK: br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
|
||||
// CHECK: 25: ; preds = %[[VAL_42]]
|
||||
// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
|
||||
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
|
||||
// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
|
||||
// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
|
||||
// CHECK: .split: ; preds = %[[VAL_44]]
|
||||
// CHECK: br label %[[VAL_51:.*]]
|
||||
// CHECK: 28: ; preds = %[[VAL_42]]
|
||||
// CHECK: br label %[[VAL_51]]
|
||||
// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]]
|
||||
// CHECK: br label %[[VAL_52:.*]]
|
||||
// CHECK: omp.region.cont1: ; preds = %[[VAL_51]]
|
||||
// CHECK: br label %[[VAL_32]]
|
||||
// CHECK: omp_loop.inc: ; preds = %[[VAL_52]]
|
||||
// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
|
||||
// CHECK: br label %[[VAL_31]]
|
||||
// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]]
|
||||
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
|
||||
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
|
||||
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
|
||||
// CHECK: br label %[[VAL_54:.*]]
|
||||
// CHECK: omp_loop.after: ; preds = %[[VAL_38]]
|
||||
// CHECK: br label %[[VAL_55:.*]]
|
||||
// CHECK: omp.region.cont: ; preds = %[[VAL_54]]
|
||||
// CHECK: ret void
|
||||
// CHECK: .cncl: ; preds = %[[VAL_44]]
|
||||
// CHECK: br label %[[VAL_38]]
|
||||
|
||||
@@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
|
||||
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
|
||||
omp.wsloop {
|
||||
// expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
|
||||
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
|
||||
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
|
||||
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
|
||||
omp.cancel cancellation_construct_type(loop)
|
||||
omp.yield
|
||||
}
|
||||
}
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @cancel_taskgroup() {
|
||||
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
|
||||
omp.taskgroup {
|
||||
|
||||
Reference in New Issue
Block a user