[flang][cuda] Add option to disable warp function in semantic (#143640)

The warp match functions (match_all_sync, match_any_sync) are not available
on some lower compute capabilities. Add a language feature option so that,
when it is disabled, semantic checking reports an error on calls to them.
This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2025-06-10 22:10:26 -07:00
committed by GitHub
parent 3ece9b06a2
commit a3201ce9e1
4 changed files with 101 additions and 44 deletions

View File

@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
InaccessibleDeferredOverride)
InaccessibleDeferredOverride, CudaWarpMatchFunction)
// Portability and suspicious usage warnings
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,

View File

@@ -17,6 +17,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/StringSet.h"
// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
using MaybeMsg = std::optional<parser::MessageFormattedText>;
static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
"match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
"match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
"match_any_syncjd"};
// Traverses an evaluate::Expr<> in search of unsupported operations
// on the device.
@@ -68,7 +74,7 @@ struct DeviceExprChecker
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
using Result = MaybeMsg;
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
DeviceExprChecker() : Base(*this) {}
explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
using Base::operator();
Result operator()(const evaluate::ProcedureDesignator &x) const {
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
if (auto attrs{subp->cudaSubprogramAttrs()}) {
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
*attrs == common::CUDASubprogramAttrs::Device) {
if (warpFunctions_.contains(sym->name().ToString()) &&
!context_.languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
return parser::MessageFormattedText(
"warp match function disabled"_err_en_US);
}
return {};
}
}
}
const Symbol &ultimate{sym->GetUltimate()};
const Scope &scope{ultimate.owner()};
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
// TODO(CUDA): Check for unsupported intrinsics here
return {};
}
return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}
SemanticsContext &context_;
};
struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
}
};
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
template <typename A>
static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
return DeviceExprChecker{}(expr->typedExpr);
return DeviceExprChecker{context}(expr->typedExpr);
}
return {};
}
@@ -144,104 +161,124 @@ template <typename A>
static void CheckUnwrappedExpr(
SemanticsContext &context, SourceName at, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
context.Say(at, std::move(*msg));
}
}
}
template <bool CUF_KERNEL> struct ActionStmtChecker {
template <typename A> static MaybeMsg WhyNotOk(const A &x) {
template <typename A>
static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
if constexpr (ConstraintTrait<A>) {
return WhyNotOk(x.thing);
return WhyNotOk(context, x.thing);
} else if constexpr (WrapperTrait<A>) {
return WhyNotOk(x.v);
return WhyNotOk(context, x.v);
} else if constexpr (UnionTrait<A>) {
return WhyNotOk(x.u);
return WhyNotOk(context, x.u);
} else if constexpr (TupleTrait<A>) {
return WhyNotOk(x.t);
return WhyNotOk(context, x.t);
} else {
return parser::MessageFormattedText{
"Statement may not appear in device code"_err_en_US};
}
}
template <typename A>
static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
return WhyNotOk(x.value());
static MaybeMsg WhyNotOk(
SemanticsContext &context, const common::Indirection<A> &x) {
return WhyNotOk(context, x.value());
}
template <typename... As>
static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::variant<As...> &x) {
return common::visit(
[&context](const auto &x) { return WhyNotOk(context, x); }, x);
}
template <std::size_t J = 0, typename... As>
static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::tuple<As...> &x) {
if constexpr (J == sizeof...(As)) {
return {};
} else if (auto msg{WhyNotOk(std::get<J>(x))}) {
} else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
return msg;
} else {
return WhyNotOk<(J + 1)>(x);
return WhyNotOk<(J + 1)>(context, x);
}
}
template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
template <typename A>
static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
for (const auto &y : x) {
if (MaybeMsg result{WhyNotOk(y)}) {
if (MaybeMsg result{WhyNotOk(context, y)}) {
return result;
}
}
return {};
}
template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
template <typename A>
static MaybeMsg WhyNotOk(
SemanticsContext &context, const std::optional<A> &x) {
if (x) {
return WhyNotOk(*x);
return WhyNotOk(context, *x);
} else {
return {};
}
}
template <typename A>
static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
return WhyNotOk(x.statement);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
return WhyNotOk(context, x.statement);
}
template <typename A>
static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
return WhyNotOk(x.statement);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::Statement<A> &x) {
return WhyNotOk(context, x.statement);
}
static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AllocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AllocateCoarraySpec &) {
return parser::MessageFormattedText(
"A coarray may not be allocated on the device"_err_en_US);
}
static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::DeallocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
return DeviceExprChecker{}(x.typedAssignment);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::AssignmentStmt &x) {
return DeviceExprChecker{context}(x.typedAssignment);
}
static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
return DeviceExprChecker{}(x.typedCall);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::CallStmt &x) {
return DeviceExprChecker{context}(x.typedCall);
}
static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
if (auto result{
CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::ContinueStmt &) {
return {};
}
static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
if (auto result{CheckUnwrappedExpr(
context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
return result;
}
return WhyNotOk(
return WhyNotOk(context,
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
.statement);
}
static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::NullifyStmt &x) {
for (const auto &y : x.v) {
if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
return result;
}
}
return {};
}
static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
return DeviceExprChecker{}(x.typedAssignment);
static MaybeMsg WhyNotOk(
SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
return DeviceExprChecker{context}(x.typedAssignment);
}
};
@@ -435,12 +472,14 @@ private:
ErrorIfHostSymbol(assign->lhs, source);
ErrorIfHostSymbol(assign->rhs, source);
}
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
[&](const auto &x) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
@@ -504,7 +543,7 @@ private:
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
}
void Check(const parser::Expr &expr) {
if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
context_.Say(expr.source, std::move(*msg));
}
}

View File

@@ -0,0 +1,8 @@
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
attributes(device) subroutine testMatch()
integer :: a, ipred, mask, v32
a = match_all_sync(mask, v32, ipred)
end subroutine
! CHECK: warp match function disabled

View File

@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
llvm::cl::desc("enable CUDA Fortran"),
llvm::cl::init(false));
static llvm::cl::opt<bool>
disableCUDAWarpFunction("fcuda-disable-warp-function",
llvm::cl::desc("Disable CUDA Warp Function"),
llvm::cl::init(false));
static llvm::cl::opt<std::string>
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
}
if (disableCUDAWarpFunction) {
options.features.Enable(
Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
}
if (enableGPUMode == "managed") {
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
} else if (enableGPUMode == "unified") {