[flang][cuda] Allocate descriptor in managed memory when memref is a block argument (#123829)
This commit is contained in:
committed by
GitHub
parent
e45de3dba7
commit
c26e1a22df
@@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
|
||||
}
|
||||
};
|
||||
|
||||
static bool isDeviceAllocation(mlir::Value val) {
|
||||
static bool isDeviceAllocation(mlir::Value val, mlir::Value adaptorVal) {
|
||||
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
|
||||
return isDeviceAllocation(loadOp.getMemref());
|
||||
return isDeviceAllocation(loadOp.getMemref(), {});
|
||||
if (auto boxAddrOp =
|
||||
mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp()))
|
||||
return isDeviceAllocation(boxAddrOp.getVal());
|
||||
return isDeviceAllocation(boxAddrOp.getVal(), {});
|
||||
if (auto convertOp =
|
||||
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
|
||||
return isDeviceAllocation(convertOp.getValue());
|
||||
return isDeviceAllocation(convertOp.getValue(), {});
|
||||
if (!val.getDefiningOp() && adaptorVal) {
|
||||
if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) {
|
||||
if (blockArg.getOwner() && blockArg.getOwner()->getParentOp() &&
|
||||
blockArg.getOwner()->isEntryBlock()) {
|
||||
if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>(
|
||||
*blockArg.getOwner()->getParentOp())) {
|
||||
auto argAttrs = func.getArgAttrs(blockArg.getArgNumber());
|
||||
for (auto attr : argAttrs) {
|
||||
if (attr.getName().getValue().ends_with(cuf::getDataAttrName())) {
|
||||
auto dataAttr =
|
||||
mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue());
|
||||
if (dataAttr.getValue() != cuf::DataAttribute::Pinned &&
|
||||
dataAttr.getValue() != cuf::DataAttribute::Unified)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
|
||||
if (callOp.getCallee() &&
|
||||
(callOp.getCallee().value().getRootReference().getValue().starts_with(
|
||||
@@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
|
||||
if (fir::isDerivedTypeWithLenParams(boxTy))
|
||||
TODO(loc, "fir.embox codegen of derived with length parameters");
|
||||
mlir::Value result = placeInMemoryIfNotGlobalInit(
|
||||
rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
|
||||
rewriter, loc, boxTy, dest,
|
||||
isDeviceAllocation(xbox.getMemref(), adaptor.getMemref()));
|
||||
rewriter.replaceOp(xbox, result);
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -2052,9 +2073,9 @@ private:
|
||||
dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
|
||||
}
|
||||
dest = insertBaseAddress(rewriter, loc, dest, base);
|
||||
mlir::Value result =
|
||||
placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest,
|
||||
isDeviceAllocation(rebox.getBox()));
|
||||
mlir::Value result = placeInMemoryIfNotGlobalInit(
|
||||
rewriter, rebox.getLoc(), destBoxTy, dest,
|
||||
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
|
||||
rewriter.replaceOp(rebox, result);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -170,3 +170,20 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
|
||||
|
||||
// CHECK-LABEL: llvm.func @_QQmain()
|
||||
// CHECK-COUNT-3: llvm.call @_FortranACUFAllocDescriptor
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git@github.com:clementval/llvm-project.git efc2415bcce8e8a9e73e77aa122c8aba1c1fbbd2)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
|
||||
func.func @_QPouter(%arg0: !fir.ref<!fir.array<100x100xf64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c100 = arith.constant 100 : index
|
||||
%0 = fir.alloca tuple<!fir.box<!fir.array<100x100xf64>>>
|
||||
%1 = fir.coordinate_of %0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<100x100xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<100x100xf64>>>
|
||||
%2 = fircg.ext_embox %arg0(%c100, %c100) : (!fir.ref<!fir.array<100x100xf64>>, index, index) -> !fir.box<!fir.array<100x100xf64>>
|
||||
fir.store %2 to %1 : !fir.ref<!fir.box<!fir.array<100x100xf64>>>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @_QPouter
|
||||
// CHECK: _FortranACUFAllocDescriptor
|
||||
|
||||
Reference in New Issue
Block a user