[flang] fix C_PTR function result lowering (#100082)

Functions returning C_PTR were lowered to function returning intptr (i64
on 64bit arch). This caused conflicts when these functions were defined
as returning !fir.ref<none>/llvm.ptr in other compiler generated
contexts (e.g., malloc).

Lower them to return !fir.ref<none>.

This should deal with https://github.com/llvm/llvm-project/issues/97325
and https://github.com/llvm/llvm-project/issues/98644.
This commit is contained in:
jeanPerier
2024-07-24 10:24:04 +02:00
committed by GitHub
parent a3de21cac1
commit 1ead51a86c
3 changed files with 110 additions and 88 deletions

View File

@@ -1541,21 +1541,44 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
zero);
}
static std::pair<mlir::Value, mlir::Type>
genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type cptrTy) {
auto recTy = mlir::cast<fir::RecordType>(cptrTy);
assert(recTy.getTypeList().size() == 1);
auto addrFieldName = recTy.getTypeList()[0].first;
mlir::Type addrFieldTy = recTy.getTypeList()[0].second;
auto fieldIndexType = fir::FieldType::get(cptrTy.getContext());
mlir::Value addrFieldIndex = builder.create<fir::FieldIndexOp>(
loc, fieldIndexType, addrFieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
return {addrFieldIndex, addrFieldTy};
}
mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr,
mlir::Type ty) {
assert(mlir::isa<fir::RecordType>(ty));
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
assert(recTy.getTypeList().size() == 1);
auto fieldName = recTy.getTypeList()[0].first;
mlir::Type fieldTy = recTy.getTypeList()[0].second;
auto fieldIndexType = fir::FieldType::get(ty.getContext());
mlir::Value field =
builder.create<fir::FieldIndexOp>(loc, fieldIndexType, fieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(fieldTy),
cPtr, field);
auto [addrFieldIndex, addrFieldTy] =
genCPtrOrCFunptrFieldIndex(builder, loc, ty);
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(addrFieldTy),
cPtr, addrFieldIndex);
}
mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr) {
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
if (fir::isa_ref_type(cPtr.getType())) {
mlir::Value cPtrAddr =
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
return builder.create<fir::LoadOp>(loc, cPtrAddr);
}
auto [addrFieldIndex, addrFieldTy] =
genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy);
auto arrayAttr =
builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)});
return builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr, arrayAttr);
}
fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
@@ -1596,15 +1619,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
return fir::BoxValue(box, lbounds, explicitTypeParams);
}
mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr) {
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
mlir::Value cPtrAddr =
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
return builder.create<fir::LoadOp>(loc, cPtrAddr);
}
mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Type boxType) {

View File

@@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
/*resultTypes=*/{});
}
static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
return fir::ReferenceType::get(mlir::NoneType::get(context));
}
/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
/// Follow the ABI for interoperability with C.
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
auto resultType = funcTy.getResult(0);
assert(fir::isa_builtin_cptr_type(resultType));
llvm::SmallVector<mlir::Type> outputTypes;
auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
outputTypes.emplace_back(recTy.getTypeList()[0].second);
assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
llvm::SmallVector<mlir::Type> outputTypes{
getVoidPtrType(funcTy.getContext())};
return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
outputTypes);
}
@@ -109,15 +111,11 @@ public:
saveResult.getTypeparams());
llvm::SmallVector<mlir::Type> newResultTypes;
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
Op newOp;
if (isResultBuiltinCPtr) {
auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType());
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
}
if (isResultBuiltinCPtr)
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
Op newOp;
// fir::CallOp specific handling.
if constexpr (std::is_same_v<Op, fir::CallOp>) {
if (op.getCallee()) {
@@ -175,7 +173,7 @@ public:
FirOpBuilder builder(rewriter, module);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
}
op->dropAllReferences();
rewriter.eraseOp(op);
@@ -210,42 +208,52 @@ public:
mlir::PatternRewriter &rewriter) const override {
auto loc = ret.getLoc();
rewriter.setInsertionPoint(ret);
auto returnedValue = ret.getOperand(0);
bool replacedStorage = false;
if (auto *op = returnedValue.getDefiningOp())
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
auto resultStorage = load.getMemref();
// The result alloca may be behind a fir.declare, if any.
if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
resultStorage.getDefiningOp()))
resultStorage = declare.getMemref();
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
rewriter.eraseOp(load);
auto module = ret->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, module);
mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, resultStorage, returnedValue.getType());
mlir::Value retValue = rewriter.create<fir::LoadOp>(
loc, fir::unwrapRefType(retAddr.getType()), retAddr);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
ret, mlir::ValueRange{retValue});
return mlir::success();
}
resultStorage.replaceAllUsesWith(newArg);
replacedStorage = true;
if (auto *alloc = resultStorage.getDefiningOp())
if (alloc->use_empty())
rewriter.eraseOp(alloc);
mlir::Value resultValue = ret.getOperand(0);
fir::LoadOp resultLoad;
mlir::Value resultStorage;
// Identify result local storage.
if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
resultLoad = load;
resultStorage = load.getMemref();
// The result alloca may be behind a fir.declare, if any.
if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
resultStorage = declare.getMemref();
}
// Replace old local storage with new storage argument, unless
// the derived type is C_PTR/C_FUN_PTR, in which case the return
// type is updated to return void* (no new argument is passed).
if (fir::isa_builtin_cptr_type(resultValue.getType())) {
auto module = ret->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, module);
mlir::Value cptr = resultValue;
if (resultLoad) {
// Replace whole derived type load by component load.
cptr = resultLoad.getMemref();
rewriter.setInsertionPoint(resultLoad);
}
// The result storage may have been optimized out by a memory to
// register pass, this is possible for fir.box results, or fir.record
// with no length parameters. Simply store the result in the result storage.
// at the return point.
if (!replacedStorage)
rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
mlir::Value newResultValue =
fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
newResultValue = builder.createConvert(
loc, getVoidPtrType(ret.getContext()), newResultValue);
rewriter.setInsertionPoint(ret);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
ret, mlir::ValueRange{newResultValue});
} else if (resultStorage) {
resultStorage.replaceAllUsesWith(newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
} else {
// The result storage may have been optimized out by a memory to
// register pass, this is possible for fir.box results, or fir.record
// with no length parameters. Simply store the result in the result
// storage. at the return point.
rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
}
// Delete result old local storage if unused.
if (resultStorage)
if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
if (alloc->use_empty())
rewriter.eraseOp(alloc);
return mlir::success();
}
@@ -263,8 +271,6 @@ public:
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
mlir::FunctionType newFuncTy;
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (oldFuncTy.getNumResults() != 0 &&
fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
newFuncTy = getCPtrFunctionType(oldFuncTy);
@@ -298,8 +304,6 @@ public:
// Convert function type itself if it has an abstract result.
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasAbstractResult(funcTy)) {
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
func.setType(getCPtrFunctionType(funcTy));
patterns.insert<ReturnOpConversion>(context, mlir::Value{});

View File

@@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
// FUNC-BOX: return
}
// FUNC-REF-LABEL: func @retcptr() -> i64
// FUNC-BOX-LABEL: func @retcptr() -> i64
// FUNC-REF-LABEL: func @retcptr() -> !fir.ref<none>
// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref<none>
func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {
%0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
%1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
@@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres
// FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
// FUNC-REF: return %[[VAL]] : i64
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
// FUNC-REF: return %[[CAST]] : !fir.ref<none>
// FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
// FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
// FUNC-BOX: return %[[VAL]] : i64
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
// FUNC-BOX: return %[[CAST]] : !fir.ref<none>
}
// FUNC-REF-LABEL: func private @arrayfunc_callee_declare(
@@ -311,8 +313,8 @@ func.func @test_address_of() {
}
// FUNC-REF-LABEL: func.func private @returns_null() -> i64
// FUNC-BOX-LABEL: func.func private @returns_null() -> i64
// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref<none>
// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref<none>
func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF-LABEL: func @test_address_of_cptr
@@ -323,12 +325,12 @@ func.func @test_address_of_cptr() {
fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> ()
return
// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
}
@@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) {
// FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
// FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
// FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
}
// ----------------------- Test GlobalOp rewrite ------------------------