[flang][cuda] Allocate descriptor in managed memory on rebox block argument (#123971)

Another case where the descriptor must be allocated with the CUF runtime
and not a simple alloca instruction.
This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2025-01-22 10:04:39 -08:00
committed by GitHub
parent afcbcae668
commit 9f83c4ed1c
2 changed files with 31 additions and 18 deletions

View File

@@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
operands, rewriter);
return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
operands, rewriter);
return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
inputStrides, operands, rewriter);
return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
inputStrides, operands, rewriter);
}
private:
/// Write resulting shape and base address in descriptor, and replace rebox
/// op.
llvm::LogicalResult
finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange lbounds,
mlir::ValueRange extents, mlir::ValueRange strides,
finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
mlir::ValueRange lbounds, mlir::ValueRange extents,
mlir::ValueRange strides,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
mlir::Value zero =
@@ -2075,15 +2076,15 @@ private:
dest = insertBaseAddress(rewriter, loc, dest, base);
mlir::Value result = placeInMemoryIfNotGlobalInit(
rewriter, rebox.getLoc(), destBoxTy, dest,
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
rewriter.replaceOp(rebox, result);
return mlir::success();
}
// Apply slice given the base address, extents and strides of the input box.
llvm::LogicalResult
sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange inputExtents,
sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
@@ -2109,7 +2110,7 @@ private:
if (rebox.getSlice().empty())
// The array section is of the form array[%component][substring], keep
// the input array extents and strides.
return finalizeRebox(rebox, destBoxTy, dest, base,
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
/*lbounds*/ std::nullopt, inputExtents, inputStrides,
rewriter);
@@ -2158,15 +2159,16 @@ private:
slicedStrides.emplace_back(stride);
}
}
return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
slicedExtents, slicedStrides, rewriter);
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
/*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
rewriter);
}
/// Apply a new shape to the data described by a box given the base address,
/// extents and strides of the box.
llvm::LogicalResult
reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange inputExtents,
reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::ValueRange reboxShifts{
@@ -2175,7 +2177,7 @@ private:
rebox.getShift().size()};
if (rebox.getShape().empty()) {
// Only setting new lower bounds.
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
inputExtents, inputStrides, rewriter);
}
@@ -2199,8 +2201,8 @@ private:
// nextStride = extent * stride;
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
}
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
newStrides, rewriter);
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
newExtents, newStrides, rewriter);
}
/// Return scalar element type of the input box.

View File

@@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
// CHECK-LABEL: llvm.func @_QPouter
// CHECK: _FortranACUFAllocDescriptor
// -----
// Test case: fircg.ext_rebox applied directly to function (block) arguments.
// %arg0 ("da") and %arg1 ("db") are boxes tagged cuf.data_attr =
// #cuf.cuda<device>; %arg2 ("n") is unused by the body.
// The CHECK-COUNT-2 directive following this function expects two
// occurrences of _FortranACUFAllocDescriptor in the lowered output
// (presumably one per rebox below -- confirm against the lowering).
func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "da"}, %arg1: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "db"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}) {
// Rebox of the device block argument "da"; no slice, shape, or shift operands.
%0 = fircg.ext_rebox %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
// Rebox of the device block argument "db"; same form as above.
%1 = fircg.ext_rebox %arg1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
return
}
// CHECK-LABEL: llvm.func @_QMm1Psub1
// CHECK-COUNT-2: _FortranACUFAllocDescriptor