[OpenMP 60] Initial parsing/sema for need_device_addr modifier on adjust_args clause (#143442)

Adds initial parsing and semantic analysis for `need_device_addr`
modifier on `adjust_args` clause.
This commit is contained in:
Fazlay Rabbi
2025-06-11 22:06:11 -07:00
committed by GitHub
parent 99638537cd
commit 02550da932
10 changed files with 80 additions and 29 deletions

View File

@@ -4630,6 +4630,7 @@ def OMPDeclareVariant : InheritableAttr {
OMPTraitInfoArgument<"TraitInfos">,
VariadicExprArgument<"AdjustArgsNothing">,
VariadicExprArgument<"AdjustArgsNeedDevicePtr">,
VariadicExprArgument<"AdjustArgsNeedDeviceAddr">,
VariadicOMPInteropInfoArgument<"AppendArgs">,
];
let AdditionalMembers = [{

View File

@@ -1581,8 +1581,10 @@ def err_omp_unexpected_append_op : Error<
"unexpected operation specified in 'append_args' clause, expected 'interop'">;
def err_omp_unexpected_execution_modifier : Error<
"unexpected 'execution' modifier in non-executable context">;
def err_omp_unknown_adjust_args_op : Error<
"incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'">;
def err_omp_unknown_adjust_args_op
: Error<
"incorrect 'adjust_args' type, expected 'need_device_ptr'%select{|, "
"'need_device_addr',}0 or 'nothing'">;
def err_omp_declare_variant_wrong_clause : Error<
"expected %select{'match'|'match', 'adjust_args', or 'append_args'}0 clause "
"on 'omp declare variant' directive">;

View File

@@ -214,6 +214,7 @@ OPENMP_ORIGINAL_SHARING_MODIFIER(default)
// Adjust-op kinds for the 'adjust_args' clause.
OPENMP_ADJUST_ARGS_KIND(nothing)
OPENMP_ADJUST_ARGS_KIND(need_device_ptr)
OPENMP_ADJUST_ARGS_KIND(need_device_addr)
// Binding kinds for the 'bind' clause.
OPENMP_BIND_KIND(teams)

View File

@@ -849,6 +849,7 @@ public:
FunctionDecl *FD, Expr *VariantRef, OMPTraitInfo &TI,
ArrayRef<Expr *> AdjustArgsNothing,
ArrayRef<Expr *> AdjustArgsNeedDevicePtr,
ArrayRef<Expr *> AdjustArgsNeedDeviceAddr,
ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
SourceLocation AppendArgsLoc, SourceRange SR);

View File

@@ -224,6 +224,12 @@ void OMPDeclareVariantAttr::printPrettyPragma(
PrintExprs(adjustArgsNeedDevicePtr_begin(), adjustArgsNeedDevicePtr_end());
OS << ")";
}
if (adjustArgsNeedDeviceAddr_size()) {
OS << " adjust_args(need_device_addr:";
PrintExprs(adjustArgsNeedDeviceAddr_begin(),
adjustArgsNeedDeviceAddr_end());
OS << ")";
}
auto PrintInteropInfo = [&OS](OMPInteropInfo *Begin, OMPInteropInfo *End) {
for (OMPInteropInfo *I = Begin; I != End; ++I) {

View File

@@ -1483,6 +1483,7 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
OMPTraitInfo &TI = ASTCtx.getNewOMPTraitInfo();
SmallVector<Expr *, 6> AdjustNothing;
SmallVector<Expr *, 6> AdjustNeedDevicePtr;
SmallVector<Expr *, 6> AdjustNeedDeviceAddr;
SmallVector<OMPInteropInfo, 3> AppendArgs;
SourceLocation AdjustArgsLoc, AppendArgsLoc;
@@ -1515,11 +1516,21 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
SmallVector<Expr *> Vars;
IsError = ParseOpenMPVarList(OMPD_declare_variant, OMPC_adjust_args,
Vars, Data);
if (!IsError)
llvm::append_range(Data.ExtraModifier == OMPC_ADJUST_ARGS_nothing
? AdjustNothing
: AdjustNeedDevicePtr,
Vars);
if (!IsError) {
switch (Data.ExtraModifier) {
case OMPC_ADJUST_ARGS_nothing:
llvm::append_range(AdjustNothing, Vars);
break;
case OMPC_ADJUST_ARGS_need_device_ptr:
llvm::append_range(AdjustNeedDevicePtr, Vars);
break;
case OMPC_ADJUST_ARGS_need_device_addr:
llvm::append_range(AdjustNeedDeviceAddr, Vars);
break;
default:
llvm_unreachable("Unexpected 'adjust_args' clause modifier.");
}
}
break;
}
case OMPC_append_args:
@@ -1559,8 +1570,8 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
if (DeclVarData && !TI.Sets.empty())
Actions.OpenMP().ActOnOpenMPDeclareVariantDirective(
DeclVarData->first, DeclVarData->second, TI, AdjustNothing,
AdjustNeedDevicePtr, AppendArgs, AdjustArgsLoc, AppendArgsLoc,
SourceRange(Loc, Tok.getLocation()));
AdjustNeedDevicePtr, AdjustNeedDeviceAddr, AppendArgs, AdjustArgsLoc,
AppendArgsLoc, SourceRange(Loc, Tok.getLocation()));
// Skip the last annot_pragma_openmp_end.
(void)ConsumeAnnotationToken();
@@ -4818,7 +4829,8 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind,
getLangOpts());
Data.ExtraModifierLoc = Tok.getLocation();
if (Data.ExtraModifier == OMPC_ADJUST_ARGS_unknown) {
Diag(Tok, diag::err_omp_unknown_adjust_args_op);
Diag(Tok, diag::err_omp_unknown_adjust_args_op)
<< (getLangOpts().OpenMP >= 60 ? 1 : 0);
SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch);
} else {
ConsumeToken();

View File

@@ -7122,6 +7122,7 @@ void SemaOpenMP::ActOnFinishedFunctionDefinitionInOpenMPDeclareVariantScope(
getASTContext(), VariantFuncRef, DVScope.TI,
/*NothingArgs=*/nullptr, /*NothingArgsSize=*/0,
/*NeedDevicePtrArgs=*/nullptr, /*NeedDevicePtrArgsSize=*/0,
/*NeedDeviceAddrArgs=*/nullptr, /*NeedDeviceAddrArgsSize=*/0,
/*AppendArgs=*/nullptr, /*AppendArgsSize=*/0);
for (FunctionDecl *BaseFD : Bases)
BaseFD->addAttr(OMPDeclareVariantA);
@@ -7553,6 +7554,7 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
FunctionDecl *FD, Expr *VariantRef, OMPTraitInfo &TI,
ArrayRef<Expr *> AdjustArgsNothing,
ArrayRef<Expr *> AdjustArgsNeedDevicePtr,
ArrayRef<Expr *> AdjustArgsNeedDeviceAddr,
ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
SourceLocation AppendArgsLoc, SourceRange SR) {
@@ -7564,6 +7566,7 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
SmallVector<Expr *, 8> AllAdjustArgs;
llvm::append_range(AllAdjustArgs, AdjustArgsNothing);
llvm::append_range(AllAdjustArgs, AdjustArgsNeedDevicePtr);
llvm::append_range(AllAdjustArgs, AdjustArgsNeedDeviceAddr);
if (!AllAdjustArgs.empty() || !AppendArgs.empty()) {
VariantMatchInfo VMI;
@@ -7614,6 +7617,8 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
const_cast<Expr **>(AdjustArgsNothing.data()), AdjustArgsNothing.size(),
const_cast<Expr **>(AdjustArgsNeedDevicePtr.data()),
AdjustArgsNeedDevicePtr.size(),
const_cast<Expr **>(AdjustArgsNeedDeviceAddr.data()),
AdjustArgsNeedDeviceAddr.size(),
const_cast<OMPInteropInfo *>(AppendArgs.data()), AppendArgs.size(), SR);
FD->addAttr(NewAttr);
}

View File

@@ -527,6 +527,7 @@ static void instantiateOMPDeclareVariantAttr(
SmallVector<Expr *, 8> NothingExprs;
SmallVector<Expr *, 8> NeedDevicePtrExprs;
SmallVector<Expr *, 8> NeedDeviceAddrExprs;
SmallVector<OMPInteropInfo, 4> AppendArgs;
for (Expr *E : Attr.adjustArgsNothing()) {
@@ -541,14 +542,20 @@ static void instantiateOMPDeclareVariantAttr(
continue;
NeedDevicePtrExprs.push_back(ER.get());
}
for (Expr *E : Attr.adjustArgsNeedDeviceAddr()) {
ExprResult ER = Subst(E);
if (ER.isInvalid())
continue;
NeedDeviceAddrExprs.push_back(ER.get());
}
for (OMPInteropInfo &II : Attr.appendArgs()) {
// When prefer_type is implemented for append_args handle them here too.
AppendArgs.emplace_back(II.IsTarget, II.IsTargetSync);
}
S.OpenMP().ActOnOpenMPDeclareVariantDirective(
FD, E, TI, NothingExprs, NeedDevicePtrExprs, AppendArgs, SourceLocation(),
SourceLocation(), Attr.getRange());
FD, E, TI, NothingExprs, NeedDevicePtrExprs, NeedDeviceAddrExprs,
AppendArgs, SourceLocation(), SourceLocation(), Attr.getRange());
}
static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(

View File

@@ -54,9 +54,9 @@ void foo_v3(float *AAA, float *BBB, int *I) {return;}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}foo_v1
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//PRINT: #pragma omp declare variant(foo_v3) match(construct={dispatch}, device={arch(x86, x86_64)}) adjust_args(nothing:I) adjust_args(need_device_ptr:BBB)
//PRINT: #pragma omp declare variant(foo_v3) match(construct={dispatch}, device={arch(x86, x86_64)}) adjust_args(nothing:I) adjust_args(need_device_ptr:BBB) adjust_args(need_device_addr:AAA)
//PRINT: #pragma omp declare variant(foo_v2) match(construct={dispatch}, device={arch(ppc)}) adjust_args(need_device_ptr:AAA)
//PRINT: #pragma omp declare variant(foo_v2) match(construct={dispatch}, device={arch(ppc)}) adjust_args(need_device_ptr:AAA) adjust_args(need_device_addr:BBB)
//PRINT: omp declare variant(foo_v1) match(construct={dispatch}, device={arch(arm)}) adjust_args(need_device_ptr:AAA,BBB)
@@ -66,42 +66,48 @@ void foo_v3(float *AAA, float *BBB, int *I) {return;}
#pragma omp declare variant(foo_v2) \
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_ptr:AAA)
adjust_args(need_device_ptr:AAA) \
adjust_args(need_device_addr:BBB)
#pragma omp declare variant(foo_v3) \
adjust_args(need_device_ptr:BBB) adjust_args(nothing:I) \
adjust_args(need_device_addr:AAA) \
match(construct={dispatch}, device={arch(x86,x86_64)})
void foo(float *AAA, float *BBB, int *I) {return;}
void Foo_Var(float *AAA, float *BBB) {return;}
void Foo_Var(float *AAA, float *BBB, float *CCC) {return;}
#pragma omp declare variant(Foo_Var) \
match(construct={dispatch}, device={arch(x86_64)}) \
adjust_args(need_device_ptr:AAA) adjust_args(nothing:BBB)
adjust_args(need_device_ptr:AAA) adjust_args(nothing:BBB) \
adjust_args(need_device_addr:CCC)
template<typename T>
void Foo(T *AAA, T *BBB) {return;}
void Foo(T *AAA, T *BBB, T *CCC) {return;}
//PRINT: #pragma omp declare variant(Foo_Var) match(construct={dispatch}, device={arch(x86_64)}) adjust_args(nothing:BBB) adjust_args(need_device_ptr:AAA)
//DUMP: FunctionDecl{{.*}} Foo 'void (T *, T *)'
//PRINT: #pragma omp declare variant(Foo_Var) match(construct={dispatch}, device={arch(x86_64)}) adjust_args(nothing:BBB) adjust_args(need_device_ptr:AAA) adjust_args(need_device_addr:CCC)
//DUMP: FunctionDecl{{.*}} Foo 'void (T *, T *, T *)'
//DUMP: OMPDeclareVariantAttr{{.*}}device={arch(x86_64)}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}Foo_Var
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'CCC'
//
//DUMP: FunctionDecl{{.*}} Foo 'void (float *, float *)'
//DUMP: FunctionDecl{{.*}} Foo 'void (float *, float *, float *)'
//DUMP: OMPDeclareVariantAttr{{.*}}device={arch(x86_64)}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}Foo_Var
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'CCC'
void func()
{
float *A;
float *B;
float *C;
//#pragma omp dispatch
Foo(A, B);
Foo(A, B, C);
}
typedef void *omp_interop_t;

View File

@@ -1,10 +1,10 @@
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 \
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 \
// RUN: -DNO_INTEROP_T_DEF -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -Wno-strict-prototypes -DC -x c -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -Wno-strict-prototypes -DC -x c -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-pc-windows-msvc -fms-compatibility \
// RUN: -fopenmp -Wno-strict-prototypes -DC -DWIN -x c -o - %s
// RUN: -fopenmp -fopenmp-version=60 -Wno-strict-prototypes -DC -DWIN -x c -o - %s
#ifdef NO_INTEROP_T_DEF
void foo_v1(float *, void *);
@@ -114,6 +114,16 @@ void vararg_bar2(const char *fmt) { return; }
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_ptr:AAA) adjust_args(nothing:AAA)
// expected-error@+3 {{'adjust_arg' argument 'AAA' used in multiple clauses}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(need_device_ptr:AAA,BBB) adjust_args(need_device_addr:AAA)
// expected-error@+3 {{'adjust_arg' argument 'AAA' used in multiple clauses}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_addr:AAA) adjust_args(nothing:AAA)
// expected-error@+2 {{use of undeclared identifier 'J'}}
#pragma omp declare variant(foo_v1) \
adjust_args(nothing:J) \
@@ -186,12 +196,12 @@ void vararg_bar2(const char *fmt) { return; }
// expected-error@+1 {{variant in '#pragma omp declare variant' with type 'void (float *, float *, int *, omp_interop_t)' (aka 'void (float *, float *, int *, void *)') is incompatible with type 'void (float *, float *, int *)'}}
#pragma omp declare variant(foo_v4) match(construct={dispatch})
// expected-error@+3 {{incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'}}
// expected-error@+3 {{incorrect 'adjust_args' type, expected 'need_device_ptr', 'need_device_addr', or 'nothing'}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(badaaop:AAA,BBB)
// expected-error@+3 {{incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'}}
// expected-error@+3 {{incorrect 'adjust_args' type, expected 'need_device_ptr', 'need_device_addr', or 'nothing'}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(badaaop AAA,BBB)