[flang][cuda] Allocate descriptor in managed memory on rebox block argument (#123971)
Another case where the descriptor must be allocated with the CUF runtime and not a simple alloca instruction.
This commit is contained in:
committed by
GitHub
parent
afcbcae668
commit
9f83c4ed1c
@@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
|
||||
getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
|
||||
|
||||
if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
|
||||
return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
|
||||
operands, rewriter);
|
||||
return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
|
||||
operands, rewriter);
|
||||
return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
|
||||
inputStrides, operands, rewriter);
|
||||
return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
|
||||
inputStrides, operands, rewriter);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Write resulting shape and base address in descriptor, and replace rebox
|
||||
/// op.
|
||||
llvm::LogicalResult
|
||||
finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
|
||||
mlir::Value base, mlir::ValueRange lbounds,
|
||||
mlir::ValueRange extents, mlir::ValueRange strides,
|
||||
finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
|
||||
mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
|
||||
mlir::ValueRange lbounds, mlir::ValueRange extents,
|
||||
mlir::ValueRange strides,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::Location loc = rebox.getLoc();
|
||||
mlir::Value zero =
|
||||
@@ -2075,15 +2076,15 @@ private:
|
||||
dest = insertBaseAddress(rewriter, loc, dest, base);
|
||||
mlir::Value result = placeInMemoryIfNotGlobalInit(
|
||||
rewriter, rebox.getLoc(), destBoxTy, dest,
|
||||
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
|
||||
isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
|
||||
rewriter.replaceOp(rebox, result);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Apply slice given the base address, extents and strides of the input box.
|
||||
llvm::LogicalResult
|
||||
sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
|
||||
mlir::Value base, mlir::ValueRange inputExtents,
|
||||
sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
|
||||
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
|
||||
mlir::ValueRange inputStrides, mlir::ValueRange operands,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::Location loc = rebox.getLoc();
|
||||
@@ -2109,7 +2110,7 @@ private:
|
||||
if (rebox.getSlice().empty())
|
||||
// The array section is of the form array[%component][substring], keep
|
||||
// the input array extents and strides.
|
||||
return finalizeRebox(rebox, destBoxTy, dest, base,
|
||||
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
|
||||
/*lbounds*/ std::nullopt, inputExtents, inputStrides,
|
||||
rewriter);
|
||||
|
||||
@@ -2158,15 +2159,16 @@ private:
|
||||
slicedStrides.emplace_back(stride);
|
||||
}
|
||||
}
|
||||
return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
|
||||
slicedExtents, slicedStrides, rewriter);
|
||||
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
|
||||
/*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
|
||||
rewriter);
|
||||
}
|
||||
|
||||
/// Apply a new shape to the data described by a box given the base address,
|
||||
/// extents and strides of the box.
|
||||
llvm::LogicalResult
|
||||
reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
|
||||
mlir::Value base, mlir::ValueRange inputExtents,
|
||||
reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
|
||||
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
|
||||
mlir::ValueRange inputStrides, mlir::ValueRange operands,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::ValueRange reboxShifts{
|
||||
@@ -2175,7 +2177,7 @@ private:
|
||||
rebox.getShift().size()};
|
||||
if (rebox.getShape().empty()) {
|
||||
// Only setting new lower bounds.
|
||||
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
|
||||
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
|
||||
inputExtents, inputStrides, rewriter);
|
||||
}
|
||||
|
||||
@@ -2199,8 +2201,8 @@ private:
|
||||
// nextStride = extent * stride;
|
||||
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
|
||||
}
|
||||
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
|
||||
newStrides, rewriter);
|
||||
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
|
||||
newExtents, newStrides, rewriter);
|
||||
}
|
||||
|
||||
/// Return scalar element type of the input box.
|
||||
|
||||
Reference in New Issue
Block a user