[flang][acc] Ensure fir.class is handled in type categorization (#146174)

fir.class is treated similarly as fir.box - but it has one key
distinction which is that it doesn't hold an element type. Thus the
categorization logic was mishandling this case for this reason (and also
the fact that it assumed that a base object is always a fir.ref).

This PR improves this handling and adds appropriate test exercising both
a class and a class field to ensure categorization works.
This commit is contained in:
Razvan Lupusoru
2025-06-30 15:04:14 -07:00
committed by GitHub
parent 6896d8a05d
commit f16983f7d0
3 changed files with 86 additions and 6 deletions

View File

@@ -306,6 +306,10 @@ static bool isArrayLike(mlir::Type type) {
}
static bool isCompositeLike(mlir::Type type) {
// class(*) is not a composite type since it does not have a determined type.
if (fir::isUnlimitedPolymorphicType(type))
return false;
return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
}
@@ -320,8 +324,18 @@ template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
mlir::Value var) const {
// Class-type does not behave like a normal box because it does not hold an
// element type. Thus special handle it here.
if (mlir::isa<fir::ClassType>(type)) {
// class(*) is not a composite type since it does not have a determined
// type.
if (fir::isUnlimitedPolymorphicType(type))
return mlir::acc::VariableTypeCategory::uncategorized;
return mlir::acc::VariableTypeCategory::composite;
}
mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
assert(eleTy && "expect to be able to unwrap the element type");
// If the type enclosed by the box is a mappable type, then have it
// provide the type category.
@@ -346,7 +360,7 @@ OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
return mlir::acc::VariableTypeCategory::nonscalar;
}
static mlir::TypedValue<mlir::acc::PointerLikeType>
static mlir::Value
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// If there is no defining op - the unwrapped reference is the base one.
mlir::Operation *op = varPtr.getDefiningOp();
@@ -372,7 +386,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
})
.Default([&](mlir::Operation *) { return varPtr; });
return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
return baseRef;
}
static mlir::acc::VariableTypeCategory
@@ -384,10 +398,17 @@ categorizePointee(mlir::Type pointer,
// value would both be represented as !fir.ref<f32>. We do not want to treat
// such a reference as a scalar. Thus unwrap interior pointer calculations.
auto baseRef = getBaseRef(varPtr);
mlir::Type eleTy = baseRef.getType().getElementType();
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(varPtr);
if (auto mappableTy =
mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
return mappableTy.getTypeCategory(baseRef);
// It must be a pointer-like type since it is not a MappableType.
auto ptrLikeTy = mlir::cast<mlir::acc::PointerLikeType>(baseRef.getType());
mlir::Type eleTy = ptrLikeTy.getElementType();
if (auto mappableEleTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableEleTy.getTypeCategory(varPtr);
if (isScalarLike(eleTy))
return mlir::acc::VariableTypeCategory::scalar;
@@ -397,8 +418,12 @@ categorizePointee(mlir::Type pointer,
return mlir::acc::VariableTypeCategory::composite;
if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
return mlir::acc::VariableTypeCategory::nonscalar;
// Assumed-type (type(*))does not have a determined type that can be
// categorized.
if (mlir::isa<mlir::NoneType>(eleTy))
return mlir::acc::VariableTypeCategory::uncategorized;
// "pointers" - in the sense of raw address point-of-view, are considered
// scalars. However
// scalars.
if (mlir::isa<fir::LLVMPointerType>(eleTy))
return mlir::acc::VariableTypeCategory::scalar;

View File

@@ -0,0 +1,46 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
module mm
type, public :: polyty
real :: field
end type
contains
subroutine init(this)
class(polyty), intent(inout) :: this
!$acc enter data copyin(this, this%field)
end subroutine
subroutine init_assumed_type(var)
type(*), intent(inout) :: var
!$acc enter data copyin(var)
end subroutine
subroutine init_unlimited(this)
class(*), intent(inout) :: this
!$acc enter data copyin(this)
select type(this)
type is(real)
!$acc enter data copyin(this)
class is(polyty)
!$acc enter data copyin(this, this%field)
end select
end subroutine
end module
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false}
! CHECK: Mappable: !fir.class<!fir.type<_QMmmTpolyty{field:f32}>>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: composite
! For unlimited polymorphic entities and assumed types - they effectively have
! no declared type. Thus the type categorizer cannot categorize it.
! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "var", structured = false}
! CHECK: Pointer-like: !fir.ref<none>
! CHECK: Type category: uncategorized
! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "this", structured = false}
! CHECK: Mappable: !fir.class<none>
! CHECK: Type category: uncategorized
! TODO: After using select type - the appropriate type category should be
! possible. Add the rest of the test once OpenACC lowering correctly handles
! unlimited polymorhic.

View File

@@ -6,11 +6,15 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Support/DataLayout.h"
using namespace mlir;
@@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces
StringRef getDescription() const final {
return "Test FIR implementation of the OpenACC interfaces.";
}
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<fir::FIROpsDialect, hlfir::hlfirDialect,
mlir::arith::ArithDialect, mlir::acc::OpenACCDialect,
mlir::DLTIDialect>();
}
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
auto datalayout =