diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 43c0e2686a8c..6ff2c20d7445 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion { getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter); if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty()) - return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides, - operands, rewriter); - return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides, - operands, rewriter); + return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents, + inputStrides, operands, rewriter); + return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents, + inputStrides, operands, rewriter); } private: /// Write resulting shape and base address in descriptor, and replace rebox /// op. llvm::LogicalResult - finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest, - mlir::Value base, mlir::ValueRange lbounds, - mlir::ValueRange extents, mlir::ValueRange strides, + finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, + mlir::Type destBoxTy, mlir::Value dest, mlir::Value base, + mlir::ValueRange lbounds, mlir::ValueRange extents, + mlir::ValueRange strides, mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = rebox.getLoc(); mlir::Value zero = @@ -2075,15 +2076,15 @@ private: dest = insertBaseAddress(rewriter, loc, dest, base); mlir::Value result = placeInMemoryIfNotGlobalInit( rewriter, rebox.getLoc(), destBoxTy, dest, - isDeviceAllocation(rebox.getBox(), rebox.getBox())); + isDeviceAllocation(rebox.getBox(), adaptor.getBox())); rewriter.replaceOp(rebox, result); return mlir::success(); } // Apply slice given the base address, extents and strides of the input box. llvm::LogicalResult - sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest, - mlir::Value base, mlir::ValueRange inputExtents, + sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy, + mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents, mlir::ValueRange inputStrides, mlir::ValueRange operands, mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = rebox.getLoc(); @@ -2109,7 +2110,7 @@ private: if (rebox.getSlice().empty()) // The array section is of the form array[%component][substring], keep // the input array extents and strides. - return finalizeRebox(rebox, destBoxTy, dest, base, + return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, /*lbounds*/ std::nullopt, inputExtents, inputStrides, rewriter); @@ -2158,15 +2159,16 @@ private: slicedStrides.emplace_back(stride); } } - return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt, - slicedExtents, slicedStrides, rewriter); + return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, + /*lbounds*/ std::nullopt, slicedExtents, slicedStrides, + rewriter); } /// Apply a new shape to the data described by a box given the base address, /// extents and strides of the box. llvm::LogicalResult - reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest, - mlir::Value base, mlir::ValueRange inputExtents, + reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy, + mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents, mlir::ValueRange inputStrides, mlir::ValueRange operands, mlir::ConversionPatternRewriter &rewriter) const { mlir::ValueRange reboxShifts{ @@ -2175,7 +2177,7 @@ private: rebox.getShift().size()}; if (rebox.getShape().empty()) { // Only setting new lower bounds. - return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, + return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts, inputExtents, inputStrides, rewriter); } @@ -2199,8 +2201,8 @@ private: // nextStride = extent * stride; stride = rewriter.create(loc, idxTy, extent, stride); } - return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents, - newStrides, rewriter); + return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts, + newExtents, newStrides, rewriter); } /// Return scalar element type of the input box. diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir index 7ac89836a3ff..063454799502 100644 --- a/flang/test/Fir/CUDA/cuda-code-gen.mlir +++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir @@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vec // CHECK-LABEL: llvm.func @_QPouter // CHECK: _FortranACUFAllocDescriptor + +// ----- + +func.func @_QMm1Psub1(%arg0: !fir.box> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "da"}, %arg1: !fir.box> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "db"}, %arg2: !fir.ref {fir.bindc_name = "n"}) { + %0 = fircg.ext_rebox %arg0 : (!fir.box>) -> !fir.box> + %1 = fircg.ext_rebox %arg1 : (!fir.box>) -> !fir.box> + return +} + +// CHECK-LABEL: llvm.func @_QMm1Psub1 +// CHECK-COUNT-2: _FortranACUFAllocDescriptor