[OpenMP][Flang][MLIR] Add lowering of TargetOp for host codegen to LLVM-IR

This patch adds lowering of TargetOps for the host. The lowering outlines the
target region function and uses the OpenMPIRBuilder support functions to emit
the function and call. Code generation for offloading will be done in later
patches.

Reviewed By: kiranchandramohan, jdoerfert, agozillon

Differential Revision: https://reviews.llvm.org/D147172
This commit is contained in:
Jan Sjodin
2023-04-03 10:46:21 -04:00
parent 17df2021a5
commit d3f9388ffb
8 changed files with 371 additions and 7 deletions

View File

@@ -1824,6 +1824,27 @@ public:
Value *IfCond, BodyGenCallbackTy ProcessMapOpCB,
BodyGenCallbackTy BodyGenCB = {});
using TargetBodyGenCallbackTy = function_ref<InsertPointTy(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
/// Generator for '#omp target'
///
/// \param Loc where the target data construct was encountered.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param NumTeams Number of teams specified in the num_teams clause.
/// \param NumThreads Number of teams specified in the thread_limit clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
InsertPointTy createTarget(const LocationDescription &Loc,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
int32_t NumThreads,
SmallVectorImpl<Value *> &Inputs,
TargetBodyGenCallbackTy BodyGenCB);
/// Declarations for LLVM-IR types (simple, array, function and structure) are
/// generated below. Their names are defined and used in OpenMPKinds.def. Here
/// we provide the declarations, the initializeTypes function will provide the

View File

@@ -4111,6 +4111,88 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
return Builder.saveIP();
}
static Function *
createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName,
SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
SmallVector<Type *> ParameterTypes;
for (auto &Arg : Inputs)
ParameterTypes.push_back(Arg->getType());
auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
Builder.GetInsertBlock()->getModule());
// Save insert point.
auto OldInsertPoint = Builder.saveIP();
// Generate the region into the function.
BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
Builder.SetInsertPoint(EntryBB);
Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
// Insert return instruction.
Builder.CreateRetVoid();
// Rewrite uses of input valus to parameters.
for (auto InArg : zip(Inputs, Func->args())) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);
// Collect all the instructions
for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, &Arg);
}
// Restore insert point.
Builder.restoreIP(OldInsertPoint);
return Func;
}
static void
emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
TargetRegionEntryInfo &EntryInfo,
Function *&OutlinedFn, int32_t NumTeams,
int32_t NumThreads, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc);
};
Constant *OutlinedFnID;
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
NumTeams, NumThreads, true, OutlinedFn,
OutlinedFnID);
}
static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
SmallVectorImpl<Value *> &Args) {
// TODO: Add kernel launch call when device codegen is supported.
Builder.CreateCall(OutlinedFn, Args);
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Args, TargetBodyGenCallbackTy CBFunc) {
if (!updateToLocation(Loc))
return InsertPointTy();
Builder.restoreIP(CodeGenIP);
Function *OutlinedFn;
emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams,
NumThreads, Args, CBFunc);
emitTargetCall(Builder, OutlinedFn, Args);
return Builder.saveIP();
}
std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
StringRef FirstSeparator,
StringRef Separator) {

View File

@@ -5119,6 +5119,62 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
TEST_F(OpenMPIRBuilderTest, TargetRegion) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
OpenMPIRBuilderConfig Config(false, false, false, false);
OMPBuilder.setConfig(Config);
F->setName("func");
IRBuilder<> Builder(BB);
auto Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
AllocaInst *CPtr = Builder.CreateAlloca(Int32Ty, nullptr, "c_ptr");
Builder.CreateStore(Builder.getInt32(10), APtr);
Builder.CreateStore(Builder.getInt32(20), BPtr);
auto BodyGenCB = [&](InsertPointTy AllocaIP,
InsertPointTy CodeGenIP) -> InsertPointTy {
Builder.restoreIP(CodeGenIP);
LoadInst *AVal = Builder.CreateLoad(Int32Ty, APtr);
LoadInst *BVal = Builder.CreateLoad(Int32Ty, BPtr);
Value *Sum = Builder.CreateAdd(AVal, BVal);
Builder.CreateStore(Sum, CPtr);
return Builder.saveIP();
};
llvm::SmallVector<llvm::Value *> Inputs;
Inputs.push_back(APtr);
Inputs.push_back(BPtr);
Inputs.push_back(CPtr);
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), EntryInfo,
-1, -1, Inputs, BodyGenCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();
// Check the outlined call
auto Iter = F->getEntryBlock().rbegin();
CallInst *Call = dyn_cast<CallInst>(&*(++Iter));
EXPECT_NE(Call, nullptr);
// Check that the correct aguments are passed in
for (auto ArgInput : zip(Call->args(), Inputs)) {
EXPECT_EQ(std::get<0>(ArgInput), std::get<1>(ArgInput));
}
// Check that the outlined function exists with the expected prefix
Function *OutlinedFunc = Call->getCalledFunction();
EXPECT_NE(OutlinedFunc, nullptr);
StringRef FunctionName = OutlinedFunc->getName();
EXPECT_TRUE(FunctionName.startswith("__omp_offloading"));
EXPECT_FALSE(verifyModule(*M, &errs()));
}
TEST_F(OpenMPIRBuilderTest, CreateTask) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);

View File

@@ -162,13 +162,7 @@ public:
/// Returns the OpenMP IR builder associated with the LLVM IR module being
/// constructed.
llvm::OpenMPIRBuilder *getOpenMPBuilder() {
if (!ompBuilder) {
ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
ompBuilder->initialize();
}
return ompBuilder.get();
}
llvm::OpenMPIRBuilder *getOpenMPBuilder();
/// Translates the given location.
const llvm::DILocation *translateLoc(Location loc, llvm::DILocalScope *scope);

View File

@@ -39,6 +39,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRExport
MLIRLLVMDialect
MLIRLLVMIRTransforms
MLIRTranslateLib
MLIROpenMPDialect
)
add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration

View File

@@ -12,11 +12,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -24,6 +26,7 @@
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/FileSystem.h"
using namespace mlir;
@@ -1573,6 +1576,102 @@ LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
return success();
}
static llvm::TargetRegionEntryInfo
getTargetEntryUniqueInfo(omp::TargetOp targetOp,
llvm::StringRef parentName = "") {
auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
assert(fileLoc && "No file found from location");
StringRef fileName = fileLoc.getFilename().getValue();
llvm::sys::fs::UniqueID id;
if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
targetOp.emitError("Unable to get unique ID for file");
}
uint64_t line = fileLoc.getLine();
return llvm::TargetRegionEntryInfo(parentName, id.getDevice(), id.getFile(),
line);
}
static bool targetOpSupported(Operation &opInst) {
auto targetOp = cast<omp::TargetOp>(opInst);
if (targetOp.getIfExpr()) {
opInst.emitError("If clause not yet supported");
return false;
}
if (targetOp.getDevice()) {
opInst.emitError("Device clause not yet supported");
return false;
}
if (targetOp.getThreadLimit()) {
opInst.emitError("Thread limit clause not yet supported");
return false;
}
if (targetOp.getNowait()) {
opInst.emitError("Nowait clause not yet supported");
return false;
}
return true;
}
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
if (!targetOpSupported(opInst))
return failure();
bool isDevice = false;
if (auto offloadMod = dyn_cast<mlir::omp::OffloadModuleInterface>(
opInst.getParentOfType<mlir::ModuleOp>().getOperation())) {
isDevice = offloadMod.getIsDevice();
}
if (isDevice) // TODO: Implement device codegen.
return success();
auto targetOp = cast<omp::TargetOp>(opInst);
auto &targetRegion = targetOp.getRegion();
llvm::SetVector<Value> operandSet;
getUsedValuesDefinedAbove(targetRegion, operandSet);
// Collect the input arguments.
llvm::SmallVector<llvm::Value *> inputs;
for (Value operand : operandSet)
inputs.push_back(moduleTranslation.lookupValue(operand));
LogicalResult bodyGenStatus = success();
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
auto bodyCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> InsertPointTy {
builder.restoreIP(codeGenIP);
llvm::BasicBlock *exitBlock = convertOmpOpRegions(
targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
builder.SetInsertPoint(exitBlock);
return builder.saveIP();
};
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
StringRef parentName = opInst.getParentOfType<LLVM::LLVMFuncOp>().getName();
llvm::TargetRegionEntryInfo entryInfo =
getTargetEntryUniqueInfo(targetOp, parentName);
int32_t defaultValTeams = -1;
int32_t defaultValThreads = -1;
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, builder.saveIP(), entryInfo, defaultValTeams, defaultValThreads,
inputs, bodyCB));
return bodyGenStatus;
}
namespace {
/// Implementation of the dialect interface that converts operations belonging
@@ -1713,6 +1812,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
return convertOmpTargetData(op, builder, moduleTranslation);
})
.Case([&](omp::TargetOp) {
return convertOmpTarget(*op, builder, moduleTranslation);
})
.Default([&](Operation *inst) {
return inst->emitError("unsupported OpenMP operation: ")
<< inst->getName();

View File

@@ -21,6 +21,7 @@
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -449,6 +450,7 @@ ModuleTranslation::ModuleTranslation(Operation *module,
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
ModuleTranslation::~ModuleTranslation() {
if (ompBuilder)
ompBuilder->finalize();
@@ -1250,6 +1252,26 @@ SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
return remapped;
}
llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
if (!ompBuilder) {
ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
ompBuilder->initialize();
bool isDevice = false;
if (auto offloadMod =
dyn_cast<mlir::omp::OffloadModuleInterface>(mlirModule))
isDevice = offloadMod.getIsDevice();
// TODO: set the flags when available
llvm::OpenMPIRBuilderConfig Config(
isDevice, /* IsTargetCodegen */ false,
/* HasRequiresUnifiedSharedMemory */ false,
/* OpenMPOffloadMandatory */ false);
ompBuilder->setConfig(Config);
}
return ompBuilder.get();
}
const llvm::DILocation *
ModuleTranslation::translateLoc(Location loc, llvm::DILocalScope *scope) {
return debugTranslation->translateLoc(loc, scope);

View File

@@ -174,3 +174,89 @@ llvm.func @_QPomp_target_enter_exit(%1 : !llvm.ptr<array<1024 x i32>>, %3 : !llv
// CHECK: ret void
// -----
module attributes {omp.is_device = #omp.isdevice<is_device = false>} {
llvm.func @omp_target_region_() {
%0 = llvm.mlir.constant(20 : i32) : i32
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr<i32>
%4 = llvm.mlir.constant(1 : i64) : i64
%5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr<i32>
%6 = llvm.mlir.constant(1 : i64) : i64
%7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
llvm.store %1, %3 : !llvm.ptr<i32>
llvm.store %0, %5 : !llvm.ptr<i32>
omp.target {
%8 = llvm.load %3 : !llvm.ptr<i32>
%9 = llvm.load %5 : !llvm.ptr<i32>
%10 = llvm.add %8, %9 : i32
llvm.store %10, %7 : !llvm.ptr<i32>
omp.terminator
}
llvm.return
}
}
// CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_region__l[[LINE:.*]](ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}})
// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]](ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
// CHECK: %[[VAL_A:.*]] = load i32, ptr %[[ADDR_A]], align 4
// CHECK: %[[VAL_B:.*]] = load i32, ptr %[[ADDR_B]], align 4
// CHECK: %[[SUM:.*]] = add i32 %[[VAL_A]], %[[VAL_B]]
// CHECK: store i32 %[[SUM]], ptr %[[ADDR_C]], align 4
// -----
module attributes {omp.is_device = #omp.isdevice<is_device = false>} {
llvm.func @omp_target_region_() {
%0 = llvm.mlir.constant(20 : i32) : i32
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr<i32>
%4 = llvm.mlir.constant(1 : i64) : i64
%5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr<i32>
%6 = llvm.mlir.constant(1 : i64) : i64
%7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
llvm.store %1, %3 : !llvm.ptr<i32>
llvm.store %0, %5 : !llvm.ptr<i32>
omp.target {
omp.parallel {
%8 = llvm.load %3 : !llvm.ptr<i32>
%9 = llvm.load %5 : !llvm.ptr<i32>
%10 = llvm.add %8, %9 : i32
llvm.store %10, %7 : !llvm.ptr<i32>
omp.terminator
}
omp.terminator
}
llvm.return
}
}
// CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_region__l[[LINE:.*]](ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}})
// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]](ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
// CHECK: %[[STRUCTARG:.*]] = alloca { ptr, ptr, ptr }, align 8
// CHECK: %[[GEP1:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
// CHECK: store ptr %[[ADDR_A]], ptr %[[GEP1]], align 8
// CHECK: %[[GEP2:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 1
// CHECK: store ptr %[[ADDR_B]], ptr %[[GEP2]], align 8
// CHECK: %[[GEP3:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 2
// CHECK: store ptr %[[ADDR_C]], ptr %[[GEP3]], align 8
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]]..omp_par, ptr %[[STRUCTARG]])
// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_region__l[[LINE]]..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %[[STRUCTARG2:.*]]) #0 {
// CHECK: %[[GEP4:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[STRUCTARG2]], i32 0, i32 0
// CHECK: %[[LOADGEP1:.*]] = load ptr, ptr %[[GEP4]], align 8
// CHECK: %[[GEP5:.*]] = getelementptr { ptr, ptr, ptr }, ptr %0, i32 0, i32 1
// CHECK: %[[LOADGEP2:.*]] = load ptr, ptr %[[GEP5]], align 8
// CHECK: %[[GEP6:.*]] = getelementptr { ptr, ptr, ptr }, ptr %0, i32 0, i32 2
// CHECK: %[[LOADGEP3:.*]] = load ptr, ptr %[[GEP6]], align 8
// CHECK: %[[VAL_A:.*]] = load i32, ptr %[[LOADGEP1]], align 4
// CHECK: %[[VAL_B:.*]] = load i32, ptr %[[LOADGEP2]], align 4
// CHECK: %[[SUM:.*]] = add i32 %[[VAL_A]], %[[VAL_B]]
// CHECK: store i32 %[[SUM]], ptr %[[LOADGEP3]], align 4
// -----