[flang][cuda] Add option to disable warp function in semantic (#143640)
These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.
This commit is contained in:
committed by
GitHub
parent
3ece9b06a2
commit
a3201ce9e1
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
|
||||
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
|
||||
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
|
||||
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
|
||||
InaccessibleDeferredOverride)
|
||||
InaccessibleDeferredOverride, CudaWarpMatchFunction)
|
||||
|
||||
// Portability and suspicious usage warnings
|
||||
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "flang/Semantics/expression.h"
|
||||
#include "flang/Semantics/symbol.h"
|
||||
#include "flang/Semantics/tools.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
// Once labeled DO constructs have been canonicalized and their parse subtrees
|
||||
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
|
||||
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
|
||||
|
||||
using MaybeMsg = std::optional<parser::MessageFormattedText>;
|
||||
|
||||
static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
|
||||
"match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
|
||||
"match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
|
||||
"match_any_syncjd"};
|
||||
|
||||
// Traverses an evaluate::Expr<> in search of unsupported operations
|
||||
// on the device.
|
||||
|
||||
@@ -68,7 +74,7 @@ struct DeviceExprChecker
|
||||
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
|
||||
using Result = MaybeMsg;
|
||||
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
|
||||
DeviceExprChecker() : Base(*this) {}
|
||||
explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
|
||||
using Base::operator();
|
||||
Result operator()(const evaluate::ProcedureDesignator &x) const {
|
||||
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
|
||||
@@ -78,10 +84,17 @@ struct DeviceExprChecker
|
||||
if (auto attrs{subp->cudaSubprogramAttrs()}) {
|
||||
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
|
||||
*attrs == common::CUDASubprogramAttrs::Device) {
|
||||
if (warpFunctions_.contains(sym->name().ToString()) &&
|
||||
!context_.languageFeatures().IsEnabled(
|
||||
Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
|
||||
return parser::MessageFormattedText(
|
||||
"warp match function disabled"_err_en_US);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Symbol &ultimate{sym->GetUltimate()};
|
||||
const Scope &scope{ultimate.owner()};
|
||||
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
|
||||
@@ -94,9 +107,12 @@ struct DeviceExprChecker
|
||||
// TODO(CUDA): Check for unsupported intrinsics here
|
||||
return {};
|
||||
}
|
||||
|
||||
return parser::MessageFormattedText(
|
||||
"'%s' may not be called in device code"_err_en_US, x.GetName());
|
||||
}
|
||||
|
||||
SemanticsContext &context_;
|
||||
};
|
||||
|
||||
struct FindHostArray
|
||||
@@ -133,9 +149,10 @@ struct FindHostArray
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
|
||||
template <typename A>
|
||||
static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
|
||||
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
|
||||
return DeviceExprChecker{}(expr->typedExpr);
|
||||
return DeviceExprChecker{context}(expr->typedExpr);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
@@ -144,104 +161,124 @@ template <typename A>
|
||||
static void CheckUnwrappedExpr(
|
||||
SemanticsContext &context, SourceName at, const A &x) {
|
||||
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
|
||||
if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
|
||||
if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
|
||||
context.Say(at, std::move(*msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool CUF_KERNEL> struct ActionStmtChecker {
|
||||
template <typename A> static MaybeMsg WhyNotOk(const A &x) {
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
|
||||
if constexpr (ConstraintTrait<A>) {
|
||||
return WhyNotOk(x.thing);
|
||||
return WhyNotOk(context, x.thing);
|
||||
} else if constexpr (WrapperTrait<A>) {
|
||||
return WhyNotOk(x.v);
|
||||
return WhyNotOk(context, x.v);
|
||||
} else if constexpr (UnionTrait<A>) {
|
||||
return WhyNotOk(x.u);
|
||||
return WhyNotOk(context, x.u);
|
||||
} else if constexpr (TupleTrait<A>) {
|
||||
return WhyNotOk(x.t);
|
||||
return WhyNotOk(context, x.t);
|
||||
} else {
|
||||
return parser::MessageFormattedText{
|
||||
"Statement may not appear in device code"_err_en_US};
|
||||
}
|
||||
}
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
|
||||
return WhyNotOk(x.value());
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const common::Indirection<A> &x) {
|
||||
return WhyNotOk(context, x.value());
|
||||
}
|
||||
template <typename... As>
|
||||
static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
|
||||
return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const std::variant<As...> &x) {
|
||||
return common::visit(
|
||||
[&context](const auto &x) { return WhyNotOk(context, x); }, x);
|
||||
}
|
||||
template <std::size_t J = 0, typename... As>
|
||||
static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const std::tuple<As...> &x) {
|
||||
if constexpr (J == sizeof...(As)) {
|
||||
return {};
|
||||
} else if (auto msg{WhyNotOk(std::get<J>(x))}) {
|
||||
} else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
|
||||
return msg;
|
||||
} else {
|
||||
return WhyNotOk<(J + 1)>(x);
|
||||
return WhyNotOk<(J + 1)>(context, x);
|
||||
}
|
||||
}
|
||||
template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
|
||||
for (const auto &y : x) {
|
||||
if (MaybeMsg result{WhyNotOk(y)}) {
|
||||
if (MaybeMsg result{WhyNotOk(context, y)}) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const std::optional<A> &x) {
|
||||
if (x) {
|
||||
return WhyNotOk(*x);
|
||||
return WhyNotOk(context, *x);
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
|
||||
return WhyNotOk(x.statement);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
|
||||
return WhyNotOk(context, x.statement);
|
||||
}
|
||||
template <typename A>
|
||||
static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
|
||||
return WhyNotOk(x.statement);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::Statement<A> &x) {
|
||||
return WhyNotOk(context, x.statement);
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::AllocateStmt &) {
|
||||
return {}; // AllocateObjects are checked elsewhere
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::AllocateCoarraySpec &) {
|
||||
return parser::MessageFormattedText(
|
||||
"A coarray may not be allocated on the device"_err_en_US);
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::DeallocateStmt &) {
|
||||
return {}; // AllocateObjects are checked elsewhere
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
|
||||
return DeviceExprChecker{}(x.typedAssignment);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::AssignmentStmt &x) {
|
||||
return DeviceExprChecker{context}(x.typedAssignment);
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
|
||||
return DeviceExprChecker{}(x.typedCall);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::CallStmt &x) {
|
||||
return DeviceExprChecker{context}(x.typedCall);
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
|
||||
static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
|
||||
if (auto result{
|
||||
CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::ContinueStmt &) {
|
||||
return {};
|
||||
}
|
||||
static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
|
||||
if (auto result{CheckUnwrappedExpr(
|
||||
context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
|
||||
return result;
|
||||
}
|
||||
return WhyNotOk(
|
||||
return WhyNotOk(context,
|
||||
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
|
||||
.statement);
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::NullifyStmt &x) {
|
||||
for (const auto &y : x.v) {
|
||||
if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
|
||||
if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
|
||||
return DeviceExprChecker{}(x.typedAssignment);
|
||||
static MaybeMsg WhyNotOk(
|
||||
SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
|
||||
return DeviceExprChecker{context}(x.typedAssignment);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -435,12 +472,14 @@ private:
|
||||
ErrorIfHostSymbol(assign->lhs, source);
|
||||
ErrorIfHostSymbol(assign->rhs, source);
|
||||
}
|
||||
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
|
||||
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
|
||||
context_, x)}) {
|
||||
context_.Say(source, std::move(*msg));
|
||||
}
|
||||
},
|
||||
[&](const auto &x) {
|
||||
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
|
||||
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
|
||||
context_, x)}) {
|
||||
context_.Say(source, std::move(*msg));
|
||||
}
|
||||
},
|
||||
@@ -504,7 +543,7 @@ private:
|
||||
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
|
||||
}
|
||||
void Check(const parser::Expr &expr) {
|
||||
if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
|
||||
if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
|
||||
context_.Say(expr.source, std::move(*msg));
|
||||
}
|
||||
}
|
||||
|
||||
8
flang/test/Semantics/cuf22.cuf
Normal file
8
flang/test/Semantics/cuf22.cuf
Normal file
@@ -0,0 +1,8 @@
|
||||
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
attributes(device) subroutine testMatch()
|
||||
integer :: a, ipred, mask, v32
|
||||
a = match_all_sync(mask, v32, ipred)
|
||||
end subroutine
|
||||
|
||||
! CHECK: warp match function disabled
|
||||
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
|
||||
llvm::cl::desc("enable CUDA Fortran"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
static llvm::cl::opt<bool>
|
||||
disableCUDAWarpFunction("fcuda-disable-warp-function",
|
||||
llvm::cl::desc("Disable CUDA Warp Function"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
static llvm::cl::opt<std::string>
|
||||
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
|
||||
llvm::cl::init(""));
|
||||
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
|
||||
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
|
||||
}
|
||||
|
||||
if (disableCUDAWarpFunction) {
|
||||
options.features.Enable(
|
||||
Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
|
||||
}
|
||||
|
||||
if (enableGPUMode == "managed") {
|
||||
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
|
||||
} else if (enableGPUMode == "unified") {
|
||||
|
||||
Reference in New Issue
Block a user