[CudaSPIRV] Allow using integral non-type template parameters as attribute args (#131546)

Allow using integral non-type template parameters as attribute arguments
of
reqd_work_group_size and work_group_size_hint.

Test plan:
ninja check-all
This commit is contained in:
Alexander Shaposhnikov
2025-03-19 10:11:18 -07:00
committed by GitHub
parent d09ecb07c2
commit 297f0b3f4c
10 changed files with 221 additions and 41 deletions

View File

@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const MatchFinder::MatchResult &Result) {
bool IsNDRange = false;
if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
Attribute->getZDim() > 1)
auto Eval = [&](Expr *E) {
return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
.getExtValue();
};
if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
Eval(Attribute->getZDim()) > 1)
IsNDRange = true;
}
if (IsNDRange) // No warning if kernel is treated as an NDRange.

View File

@@ -3044,8 +3044,7 @@ def NoDeref : TypeAttr {
def ReqdWorkGroupSize : InheritableAttr {
// Does not have a [[]] spelling because it is an OpenCL-related attribute.
let Spellings = [GNU<"reqd_work_group_size">];
let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
UnsignedArgument<"ZDim">];
let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Documentation = [Undocumented];
}
@@ -3053,9 +3052,7 @@ def ReqdWorkGroupSize : InheritableAttr {
def WorkGroupSizeHint : InheritableAttr {
// Does not have a [[]] spelling because it is an OpenCL-related attribute.
let Spellings = [GNU<"work_group_size_hint">];
let Args = [UnsignedArgument<"XDim">,
UnsignedArgument<"YDim">,
UnsignedArgument<"ZDim">];
let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Documentation = [Undocumented];
}

View File

@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
}
if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
auto Eval = [&](Expr *E) {
return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
};
llvm::Metadata *AttrMDArgs[] = {
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, AttrMDArgs));
}
if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
auto Eval = [&](Expr *E) {
return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
};
llvm::Metadata *AttrMDArgs[] = {
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, AttrMDArgs));
}

View File

@@ -753,12 +753,16 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
int32_t *MaxThreadsVal) {
unsigned Min = 0;
unsigned Max = 0;
auto Eval = [&](Expr *E) {
return E->EvaluateKnownConstInt(getContext()).getExtValue();
};
if (FlatWGS) {
Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
Min = Eval(FlatWGS->getMin());
Max = Eval(FlatWGS->getMax());
}
if (ReqdWGS && Min == 0 && Max == 0)
Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
Eval(ReqdWGS->getZDim());
if (Min != 0) {
assert(Min <= Max && "Min must be less than or equal Max");

View File

@@ -50,24 +50,21 @@ void TCETargetCodeGenInfo::setTargetAttributes(
M.getModule().getOrInsertNamedMetadata(
"opencl.kernel_wg_size_info");
SmallVector<llvm::Metadata *, 5> Operands;
Operands.push_back(llvm::ConstantAsMetadata::get(F));
Operands.push_back(
auto Eval = [&](Expr *E) {
return E->EvaluateKnownConstInt(FD->getASTContext());
};
SmallVector<llvm::Metadata *, 5> Operands{
llvm::ConstantAsMetadata::get(F),
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
Operands.push_back(
M.Int32Ty, Eval(Attr->getXDim()))),
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
Operands.push_back(
M.Int32Ty, Eval(Attr->getYDim()))),
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
// Add a boolean constant operand for "required" (true) or "hint"
// (false) for implementing the work_group_size_hint attr later.
// Currently always true as the hint is not yet implemented.
Operands.push_back(
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getTrue(Context)));
M.Int32Ty, Eval(Attr->getZDim()))),
// Add a boolean constant operand for "required" (true) or "hint"
// (false) for implementing the work_group_size_hint attr later.
// Currently always true as the hint is not yet implemented.
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getTrue(Context))};
OpenCLMetadata->addOperand(llvm::MDNode::get(Context, Operands));
}
}

View File

@@ -2914,21 +2914,70 @@ static void handleWeakImportAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
}
// Checks whether an argument of launch_bounds-like attribute is
// acceptable, performs implicit conversion to Rvalue, and returns
// non-nullptr Expr result on success. Otherwise, it returns nullptr
// and may output an error.
template <class Attribute>
static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
const unsigned Idx) {
if (S.DiagnoseUnexpandedParameterPack(E))
return nullptr;
// Accept template arguments for now as they depend on something else.
// We'll get to check them when they eventually get instantiated.
if (E->isValueDependent())
return E;
std::optional<llvm::APSInt> I = llvm::APSInt(64);
if (!(I = E->getIntegerConstantExpr(S.Context))) {
S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
<< &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
return nullptr;
}
// Make sure we can fit it in 32 bits.
if (!I->isIntN(32)) {
S.Diag(E->getExprLoc(), diag::err_ice_too_large)
<< toString(*I, 10, false) << 32 << /* Unsigned */ 1;
return nullptr;
}
if (*I < 0)
S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
<< &Attr << /*non-negative*/ 1 << E->getSourceRange();
// We may need to perform implicit conversion of the argument.
InitializedEntity Entity = InitializedEntity::InitializeParameter(
S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
assert(!ValArg.isInvalid() &&
"Unexpected PerformCopyInitialization() failure.");
return ValArg.getAs<Expr>();
}
// Handles reqd_work_group_size and work_group_size_hint.
template <typename WorkGroupAttr>
static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
uint32_t WGSize[3];
Expr *WGSize[3];
for (unsigned i = 0; i < 3; ++i) {
const Expr *E = AL.getArgAsExpr(i);
if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
/*StrictlyUnsigned=*/true))
if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
WGSize[i] = E;
else
return;
}
if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
auto IsZero = [&](Expr *E) {
if (E->isValueDependent())
return false;
std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
assert(I && "Non-integer constant expr");
return I->isZero();
};
if (!llvm::all_of(WGSize, IsZero)) {
for (unsigned i = 0; i < 3; ++i) {
const Expr *E = AL.getArgAsExpr(i);
if (WGSize[i] == 0) {
if (IsZero(WGSize[i])) {
S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
<< AL << E->getSourceRange();
return;
@@ -2936,10 +2985,22 @@ static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
}
}
auto Equal = [&](Expr *LHS, Expr *RHS) {
if (LHS->isValueDependent() || RHS->isValueDependent())
return true;
std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
assert(L && "Non-integer constant expr");
std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
assert(L && "Non-integer constant expr");
return L == R;
};
WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
if (Existing && !(Existing->getXDim() == WGSize[0] &&
Existing->getYDim() == WGSize[1] &&
Existing->getZDim() == WGSize[2]))
if (Existing &&
!llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
Existing->getYDim(),
Existing->getZDim()},
WGSize, Equal))
S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
D->addAttr(::new (S.Context)

View File

@@ -572,6 +572,32 @@ static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
}
static void instantiateDependentReqdWorkGroupSizeAttr(
Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
// Both min and max expression are constant expressions.
EnterExpressionEvaluationContext Unevaluated(
S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
ExprResult Result = S.SubstExpr(Attr.getXDim(), TemplateArgs);
if (Result.isInvalid())
return;
Expr *X = Result.getAs<Expr>();
Result = S.SubstExpr(Attr.getYDim(), TemplateArgs);
if (Result.isInvalid())
return;
Expr *Y = Result.getAs<Expr>();
Result = S.SubstExpr(Attr.getZDim(), TemplateArgs);
if (Result.isInvalid())
return;
Expr *Z = Result.getAs<Expr>();
ASTContext &Context = S.getASTContext();
New->addAttr(::new (Context) ReqdWorkGroupSizeAttr(Context, Attr, X, Y, Z));
}
ExplicitSpecifier Sema::instantiateExplicitSpecifier(
const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
continue;
}
if (const auto *ReqdWorkGroupSize =
dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
*ReqdWorkGroupSize, New);
}
if (const auto *AMDGPUFlatWorkGroupSize =
dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
instantiateDependentAMDGPUFlatWorkGroupSizeAttr(

View File

@@ -18,12 +18,18 @@ __global__ void vec_type_hint_int() {}
__attribute__((intel_reqd_sub_group_size(64)))
__global__ void intel_reqd_sub_group_size_64() {}
template <unsigned a, unsigned b, unsigned c>
__attribute__((reqd_work_group_size(a, b, c)))
__global__ void reqd_work_group_size_a_b_c() {}
template __global__ void reqd_work_group_size_a_b_c<256,1,1>(void);
// CHECK: define spir_kernel void @_Z26reqd_work_group_size_0_0_0v() #[[ATTR:[0-9]+]] !reqd_work_group_size ![[WG_SIZE_ZEROS:[0-9]+]]
// CHECK: define spir_kernel void @_Z28reqd_work_group_size_128_1_1v() #[[ATTR:[0-9]+]] !reqd_work_group_size ![[WG_SIZE:[0-9]+]]
// CHECK: define spir_kernel void @_Z26work_group_size_hint_2_2_2v() #[[ATTR]] !work_group_size_hint ![[WG_HINT:[0-9]+]]
// CHECK: define spir_kernel void @_Z17vec_type_hint_intv() #[[ATTR]] !vec_type_hint ![[VEC_HINT:[0-9]+]]
// CHECK: define spir_kernel void @_Z28intel_reqd_sub_group_size_64v() #[[ATTR]] !intel_reqd_sub_group_size ![[SUB_GRP:[0-9]+]]
// CHECK: define spir_kernel void @_Z26reqd_work_group_size_a_b_cILj256ELj1ELj1EEvv() #[[ATTR]] comdat !reqd_work_group_size ![[WG_SIZE_TMPL:[0-9]+]]
// CHECK: attributes #[[ATTR]] = { {{.*}} }
@@ -32,3 +38,4 @@ __global__ void intel_reqd_sub_group_size_64() {}
// CHECK: ![[WG_HINT]] = !{i32 2, i32 2, i32 2}
// CHECK: ![[VEC_HINT]] = !{i32 poison, i32 1}
// CHECK: ![[SUB_GRP]] = !{i32 64}
// CHECK: ![[WG_SIZE_TMPL]] = !{i32 256, i32 1, i32 1}

View File

@@ -0,0 +1,58 @@
// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
// RUN: -fcuda-is-device -verify -fsyntax-only %s
#define __global__ __attribute__((global))
__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
__global__ void TestTooBigArg1(void);
__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
__global__ void TestTooBigArg2(void);
template <int... Args>
__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
__global__ void TestTemplateVariadicArgs1(void) {}
template <int... Args>
__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
__global__ void TestTemplateVariadicArgs2(void) {}
template <class a> // expected-note {{declared here}}
__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
__global__ void TestTemplateArgClass1(void) {}
template <class a> // expected-note {{declared here}}
__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
__global__ void TestTemplateArgClass2(void) {}
constexpr int A = 512;
__attribute__((reqd_work_group_size(A, A, A)))
__global__ void TestConstIntArg1(void) {}
__attribute__((work_group_size_hint(A, A, A)))
__global__ void TestConstIntArg2(void) {}
int B = 512;
__attribute__((reqd_work_group_size(B, 1, 1))) // expected-error {{attribute requires parameter 0 to be an integer constant}}
__global__ void TestNonConstIntArg1(void) {}
__attribute__((work_group_size_hint(B, 1, 1))) // expected-error {{attribute requires parameter 0 to be an integer constant}}
__global__ void TestNonConstIntArg2(void) {}
constexpr int C = -512;
__attribute__((reqd_work_group_size(C, 1, 1))) // expected-error {{attribute requires a non-negative integral compile time constant expression}}
__global__ void TestNegativeConstIntArg1(void) {}
__attribute__((work_group_size_hint(C, 1, 1))) // expected-error {{attribute requires a non-negative integral compile time constant expression}}
__global__ void TestNegativeConstIntArg2(void) {}
__attribute__((reqd_work_group_size(A, 0, 1))) // expected-error {{attribute must be greater than 0}}
__global__ void TestZeroArg1(void) {}
__attribute__((work_group_size_hint(A, 0, 1))) // expected-error {{attribute must be greater than 0}}
__global__ void TestZeroArg2(void) {}

View File

@@ -3,14 +3,28 @@
// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
// RUN: -fcuda-is-device -verify -fsyntax-only %s
#include "Inputs/cuda.h"
#define __global__ __attribute__((global))
__attribute__((reqd_work_group_size(128, 1, 1)))
__global__ void reqd_work_group_size_128_1_1() {}
template <unsigned a, unsigned b, unsigned c>
__attribute__((reqd_work_group_size(a, b, c)))
__global__ void reqd_work_group_size_a_b_c() {}
template <>
__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
__attribute__((work_group_size_hint(2, 2, 2)))
__global__ void work_group_size_hint_2_2_2() {}
template <unsigned a, unsigned b, unsigned c>
__attribute__((work_group_size_hint(a, b, c)))
__global__ void work_group_size_hint_a_b_c() {}
template <>
__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
__attribute__((vec_type_hint(int)))
__global__ void vec_type_hint_int() {}