[flang][cuda] Allocate descriptor in managed memory when emboxing device memory (#120485)
When emboxing memory that comes from CUFMemAlloc, we need to allocate the descriptor in managed memory as it might be passed to a kernel.
This commit is contained in:
committed by
GitHub
parent
ffff7bb582
commit
4530273d7c
@@ -24,6 +24,7 @@
|
||||
#include "flang/Optimizer/Support/TypeCode.h"
|
||||
#include "flang/Optimizer/Support/Utils.h"
|
||||
#include "flang/Runtime/CUDA/descriptor.h"
|
||||
#include "flang/Runtime/CUDA/memory.h"
|
||||
#include "flang/Runtime/allocator-registry-consts.h"
|
||||
#include "flang/Runtime/descriptor-consts.h"
|
||||
#include "flang/Semantics/runtime-type-info.h"
|
||||
@@ -1141,6 +1142,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Produce an LLVM pointer to a module-level constant holding the file name
/// carried by \p loc, including a trailing NUL character.
///
/// The global is named with `fir::factory::uniqueCGIdent("cl", <file name>)`.
/// If \p mod already contains a `fir.global` or an `llvm.mlir.global` with
/// that name it is reused; otherwise a constant linkonce `llvm.mlir.global`
/// is appended to the module body. When \p loc is not a `FileLineColLoc`, a
/// null pointer (`llvm.mlir.zero`) is returned instead.
///
/// \param loc      location to extract the file name from; also used as the
///                 location of every operation created here.
/// \param mod      module that is searched for, and receives, the global.
/// \param rewriter rewriter used to create the operations. Its insertion
///                 point on return is the same as on entry.
/// \return the address of the file name global, or a null pointer value.
static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
                                 mlir::ConversionPatternRewriter &rewriter) {
  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());

  auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc);
  // No file information is attached to this location.
  if (!flc)
    return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);

  // The terminating NUL is part of the string: it contributes to both the
  // global's name and the size of its array type.
  auto fn = flc.getFilename().str() + '\0';
  std::string globalName = fir::factory::uniqueCGIdent("cl", fn);

  // Reuse a global created earlier, whether or not it was already converted
  // to the LLVM dialect.
  if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName))
    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
  if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName))
    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());

  mlir::LLVM::GlobalOp globalOp;
  {
    // The guard restores the caller's insertion point when this scope ends,
    // so the address-of operation below lands where the caller expects it.
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
    globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
        globalName, mlir::Attribute());

    // The initializer region returns the string constant.
    mlir::Region &region = globalOp.getInitializerRegion();
    mlir::Block *block = rewriter.createBlock(&region);
    rewriter.setInsertionPoint(block, block->begin());
    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
        loc, arrayTy, rewriter.getStringAttr(fn));
    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
  }

  return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
                                                  globalOp.getName());
}
|
||||
|
||||
/// Create an i32 `llvm.mlir.constant` holding the line number carried by
/// \p loc, or 0 when \p loc is not a `FileLineColLoc`.
///
/// \param loc      location to read the line from; also the location of the
///                 created constant.
/// \param rewriter rewriter used to create the constant.
/// \return the i32 constant value.
static mlir::Value genSourceLine(mlir::Location loc,
                                 mlir::ConversionPatternRewriter &rewriter) {
  // Default used when the location carries no line information.
  unsigned lineNumber = 0;
  if (auto fileLineCol = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
    lineNumber = fileLineCol.getLine();
  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
                                                 lineNumber);
}
|
||||
|
||||
/// Emit a call to the runtime entry point named
/// `RTNAME_STRING(CUFAllocDesciptor)` and return its result.
///
/// The call receives three operands: the size in bytes of the struct type
/// that \p boxTy converts to, a pointer to the source file name of \p loc,
/// and the source line of \p loc. A declaration of the entry point is
/// appended to \p mod when neither an `llvm.func` nor a `func.func` with that
/// name is present.
///
/// \param loc           location used for the created operations and for the
///                      file/line operands.
/// \param rewriter      rewriter used to create the operand values and call.
/// \param mod           module providing the data layout and holding the
///                      runtime function declaration.
/// \param boxTy         box type whose descriptor storage is requested.
/// \param typeConverter converter used to obtain the descriptor struct type
///                      and the pointer bit width.
/// \return the pointer returned by the runtime call, or a null `mlir::Value`
///         after an error has been emitted because no data layout could be
///         obtained for \p mod.
static mlir::Value
genCUFAllocDescriptor(mlir::Location loc,
                      mlir::ConversionPatternRewriter &rewriter,
                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
                      const fir::LLVMTypeConverter &typeConverter) {
  std::optional<mlir::DataLayout> dl =
      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
  if (!dl) {
    mlir::emitError(mod.getLoc(),
                    "module operation must carry a data layout attribute "
                    "to generate llvm IR from FIR");
    // The descriptor size cannot be computed without a data layout. Stop
    // here rather than dereferencing the empty optional further down.
    return {};
  }

  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
  mlir::Value sourceLine = genSourceLine(loc, rewriter);

  mlir::MLIRContext *ctx = mod.getContext();

  // Signature of the runtime entry point: (intptr, ptr, i32) -> ptr.
  mlir::LLVM::LLVMPointerType llvmPointerType =
      mlir::LLVM::LLVMPointerType::get(ctx);
  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
  mlir::Type llvmIntPtrType =
      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});

  // Declare the entry point unless the module already has it, either as an
  // llvm.func or as a func.func.
  const llvm::StringRef allocDescName = RTNAME_STRING(CUFAllocDesciptor);
  const bool isDeclared =
      mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(allocDescName) ||
      mod.lookupSymbol<mlir::func::FuncOp>(allocDescName);
  if (!isDeclared)
    mlir::OpBuilder::atBlockEnd(mod.getBody())
        .create<mlir::LLVM::LLVMFuncOp>(loc, allocDescName, fctTy);

  // Size of the descriptor struct, converted from bits to bytes.
  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
  mlir::Value sizeInBytes =
      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);

  llvm::SmallVector<mlir::Value> args{sizeInBytes, sourceFile, sourceLine};
  return rewriter
      .create<mlir::LLVM::CallOp>(loc, fctTy, allocDescName, args)
      .getResult();
}
|
||||
|
||||
/// Common base class for embox to descriptor conversion.
|
||||
template <typename OP>
|
||||
struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
|
||||
@@ -1554,15 +1642,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
|
||||
mlir::Value
|
||||
placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Type boxTy,
|
||||
mlir::Value boxValue) const {
|
||||
mlir::Value boxValue,
|
||||
bool needDeviceAllocation = false) const {
|
||||
if (isInGlobalOp(rewriter))
|
||||
return boxValue;
|
||||
mlir::Type llvmBoxTy = boxValue.getType();
|
||||
auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy,
|
||||
defaultAlign, rewriter);
|
||||
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
|
||||
mlir::Value storage;
|
||||
if (needDeviceAllocation) {
|
||||
auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>();
|
||||
auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy);
|
||||
storage =
|
||||
genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy());
|
||||
} else {
|
||||
storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign,
|
||||
rewriter);
|
||||
}
|
||||
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage);
|
||||
this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
|
||||
return alloca;
|
||||
return storage;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1614,6 +1711,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Return true when \p val is produced by a `fir.call` whose callee name
/// starts with `RTNAME_STRING(CUFMemAlloc)`. A single `fir.convert` applied
/// to the call result is looked through.
static bool isDeviceAllocation(mlir::Value val) {
  // Step over one conversion of the value, if there is one.
  mlir::Operation *producer = val.getDefiningOp();
  if (auto cvt = mlir::dyn_cast_or_null<fir::ConvertOp>(producer))
    producer = cvt.getValue().getDefiningOp();

  auto call = mlir::dyn_cast_or_null<fir::CallOp>(producer);
  if (!call)
    return false;

  // A call without a callee symbol cannot be matched by name.
  auto callee = call.getCallee();
  if (!callee)
    return false;

  return callee->getRootReference().getValue().starts_with(
      RTNAME_STRING(CUFMemAlloc));
}
|
||||
|
||||
/// Create a generic box on a memory reference.
|
||||
struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
|
||||
using EmboxCommonConversion::EmboxCommonConversion;
|
||||
@@ -1797,9 +1906,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
|
||||
dest = insertBaseAddress(rewriter, loc, dest, base);
|
||||
if (fir::isDerivedTypeWithLenParams(boxTy))
|
||||
TODO(loc, "fir.embox codegen of derived with length parameters");
|
||||
|
||||
mlir::Value result =
|
||||
placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
|
||||
mlir::Value result = placeInMemoryIfNotGlobalInit(
|
||||
rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
|
||||
rewriter.replaceOp(xbox, result);
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -2977,93 +3085,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
|
||||
mlir::ConversionPatternRewriter &rewriter) {
|
||||
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
|
||||
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
|
||||
auto fn = flc.getFilename().str() + '\0';
|
||||
std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
|
||||
|
||||
if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
|
||||
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
|
||||
} else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
|
||||
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
|
||||
}
|
||||
|
||||
auto crtInsPt = rewriter.saveInsertionPoint();
|
||||
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
|
||||
auto arrayTy = mlir::LLVM::LLVMArrayType::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
|
||||
mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
|
||||
loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
|
||||
globalName, mlir::Attribute());
|
||||
|
||||
mlir::Region &region = globalOp.getInitializerRegion();
|
||||
mlir::Block *block = rewriter.createBlock(&region);
|
||||
rewriter.setInsertionPoint(block, block->begin());
|
||||
mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
|
||||
loc, arrayTy, rewriter.getStringAttr(fn));
|
||||
rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
|
||||
rewriter.restoreInsertionPoint(crtInsPt);
|
||||
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
|
||||
globalOp.getName());
|
||||
}
|
||||
return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
|
||||
}
|
||||
|
||||
static mlir::Value genSourceLine(mlir::Location loc,
|
||||
mlir::ConversionPatternRewriter &rewriter) {
|
||||
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
|
||||
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
|
||||
flc.getLine());
|
||||
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
|
||||
}
|
||||
|
||||
static mlir::Value
|
||||
genCUFAllocDescriptor(mlir::Location loc,
|
||||
mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::ModuleOp mod, fir::BaseBoxType boxTy,
|
||||
const fir::LLVMTypeConverter &typeConverter) {
|
||||
std::optional<mlir::DataLayout> dl =
|
||||
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
|
||||
if (!dl)
|
||||
mlir::emitError(mod.getLoc(),
|
||||
"module operation must carry a data layout attribute "
|
||||
"to generate llvm IR from FIR");
|
||||
|
||||
mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
|
||||
mlir::Value sourceLine = genSourceLine(loc, rewriter);
|
||||
|
||||
mlir::MLIRContext *ctx = mod.getContext();
|
||||
|
||||
mlir::LLVM::LLVMPointerType llvmPointerType =
|
||||
mlir::LLVM::LLVMPointerType::get(ctx);
|
||||
mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
|
||||
mlir::Type llvmIntPtrType =
|
||||
mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
|
||||
auto fctTy = mlir::LLVM::LLVMFunctionType::get(
|
||||
llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
|
||||
|
||||
auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
|
||||
RTNAME_STRING(CUFAllocDesciptor));
|
||||
auto funcFunc =
|
||||
mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
|
||||
if (!llvmFunc && !funcFunc)
|
||||
mlir::OpBuilder::atBlockEnd(mod.getBody())
|
||||
.create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
|
||||
fctTy);
|
||||
|
||||
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
|
||||
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
|
||||
mlir::Value sizeInBytes =
|
||||
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
|
||||
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
|
||||
return rewriter
|
||||
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
|
||||
args)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
/// `fir.load` --> `llvm.load`
|
||||
struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
|
||||
|
||||
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
|
||||
|
||||
func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
|
||||
%c0 = arith.constant 0 : index
|
||||
%0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
|
||||
@@ -27,3 +26,33 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
|
||||
}
|
||||
func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
|
||||
// Embox a reference obtained by converting the result of a call to
// @_FortranACUFMemAlloc. The byte count passed to the call is 200 * 4 and the
// emboxed array type is 10x20xi32.
func.func @_QQmain() attributes {fir.bindc_name = "test"} {
  %extent0 = arith.constant 10 : index
  %extent1 = arith.constant 20 : index
  // Address of the global string "dummy.mlir\00" defined in this module.
  %file_str = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
  %four = arith.constant 4 : index
  %two_hundred = arith.constant 200 : index
  %nbytes = arith.muli %two_hundred, %four : index
  %six_i32 = arith.constant 6 : i32
  %zero_i32 = arith.constant 0 : i32
  %nbytes_i64 = fir.convert %nbytes : (index) -> i64
  %file_ptr = fir.convert %file_str : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
  %raw = fir.call @_FortranACUFMemAlloc(%nbytes_i64, %zero_i32, %file_ptr, %six_i32) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
  %array_ref = fir.convert %raw : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10x20xi32>>
  // The embox operand comes from the runtime allocation call through one
  // fir.convert.
  %box = fircg.ext_embox %array_ref(%extent0, %extent1) : (!fir.ref<!fir.array<10x20xi32>>, index, index) -> !fir.box<!fir.array<10x20xi32>>
  return
}
|
||||
fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
|
||||
%0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
|
||||
fir.has_value %0 : !fir.char<1,11>
|
||||
}
|
||||
func.func private @_FortranACUFMemAlloc(i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8> attributes {fir.runtime}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @_QQmain()
|
||||
// CHECK: llvm.call @_FortranACUFMemAlloc
|
||||
// CHECK: llvm.call @_FortranACUFAllocDesciptor
|
||||
|
||||
Reference in New Issue
Block a user