Files
clang-p2996/flang/lib/Optimizer/Dialect/FIRAttr.cpp
Valentin Clement (バレンタイン クレメン) 4e6745cc4d [flang][cuda] Lower simple host to device data transfer (#85960)
In CUDA Fortran data transfer can be done via assignment statements
between host and device variables.

This patch introduces a `fir.cuda_data_transfer` operation that
materializes the data transfer between two memory references.

Simple transfers from host to device that do not involve descriptors are
also lowered in this patch. When the rhs is an expression that requires
evaluation, a temporary is created. The evaluation is done on the host
and then the transfer is initiated.

Implicit transfer when device symbols are present on the rhs is not part
of this patch. Transfer from device to host is not part of this patch.
2024-03-25 11:53:39 -07:00

305 lines
11 KiB
C++

//===-- FIRAttr.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "flang/Optimizer/Dialect/FIRAttr.cpp.inc"
using namespace fir;
namespace fir::detail {
/// Storage for the `fir.real` attribute: an integer kind paired with the
/// floating-point constant it qualifies.
struct RealAttributeStorage : public mlir::AttributeStorage {
  /// Lookup key: (kind, value).
  using KeyTy = std::pair<int, llvm::APFloat>;

  RealAttributeStorage(int kind, const llvm::APFloat &value)
      : kind(kind), value(value) {}
  RealAttributeStorage(const KeyTy &key)
      : RealAttributeStorage(key.first, key.second) {}

  /// Hash the (kind, value) key.
  static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }

  /// Key equality.
  ///
  /// The value is compared bit for bit rather than numerically. A numeric
  /// comparison reports NaN as unordered (so a NaN key would never match its
  /// own storage) and reports +0.0 and -0.0 as equal (so two distinct
  /// constants could be merged). Neither is a usable identity for a key.
  bool operator==(const KeyTy &key) const {
    return key.first == kind && key.second.bitwiseIsEqual(value);
  }

  /// Build a new storage instance for \p key in memory owned by \p allocator.
  static RealAttributeStorage *
  construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
    return new (allocator.allocate<RealAttributeStorage>())
        RealAttributeStorage(key);
  }

  /// Return the stored kind.
  KindTy getFKind() const { return kind; }
  /// Return a copy of the stored floating-point value.
  llvm::APFloat getValue() const { return value; }

private:
  int kind;
  llvm::APFloat value;
};
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public mlir::AttributeStorage {
using KeyTy = mlir::Type;
TypeAttributeStorage(mlir::Type value) : value(value) {
assert(value && "must not be of Type null");
}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
/// Construct a new storage instance.
static TypeAttributeStorage *
construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
return new (allocator.allocate<TypeAttributeStorage>())
TypeAttributeStorage(key);
}
mlir::Type getType() const { return value; }
private:
mlir::Type value;
};
} // namespace fir::detail
//===----------------------------------------------------------------------===//
// Attributes for SELECT TYPE
//===----------------------------------------------------------------------===//
/// Build the `ExactTypeAttr` wrapping \p value. The attribute is created in
/// the MLIR context that owns \p value.
ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
  mlir::MLIRContext *owningContext = value.getContext();
  return Base::get(owningContext, value);
}
/// Return the type held in this attribute's storage.
mlir::Type fir::ExactTypeAttr::getType() const {
  const auto *storage = getImpl();
  return storage->getType();
}
/// Build the `SubclassAttr` wrapping \p value. The attribute is created in
/// the MLIR context that owns \p value.
SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
  mlir::MLIRContext *owningContext = value.getContext();
  return Base::get(owningContext, value);
}
/// Return the type held in this attribute's storage.
mlir::Type fir::SubclassAttr::getType() const {
  const auto *storage = getImpl();
  return storage->getType();
}
//===----------------------------------------------------------------------===//
// Attributes for SELECT CASE
//===----------------------------------------------------------------------===//
// File-local shorthand for MLIR's attribute uniquer.
using AttributeUniquer = mlir::detail::AttributeUniquer;
/// Fetch the `ClosedIntervalAttr` from the attribute uniquer of \p ctxt.
ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
  ClosedIntervalAttr interval = AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
  return interval;
}
/// Fetch the `UpperBoundAttr` from the attribute uniquer of \p ctxt.
UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
  UpperBoundAttr bound = AttributeUniquer::get<UpperBoundAttr>(ctxt);
  return bound;
}
/// Fetch the `LowerBoundAttr` from the attribute uniquer of \p ctxt.
LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
  LowerBoundAttr bound = AttributeUniquer::get<LowerBoundAttr>(ctxt);
  return bound;
}
/// Fetch the `PointIntervalAttr` from the attribute uniquer of \p ctxt.
PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
  PointIntervalAttr point = AttributeUniquer::get<PointIntervalAttr>(ctxt);
  return point;
}
//===----------------------------------------------------------------------===//
// RealAttr
//===----------------------------------------------------------------------===//
/// Build (or look up) the `RealAttr` identified by \p key in \p ctxt.
RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
                            const RealAttr::ValueType &key) {
  RealAttr realAttr = Base::get(ctxt, key);
  return realAttr;
}
/// Return the kind recorded in this attribute's storage.
KindTy fir::RealAttr::getFKind() const {
  const auto *storage = getImpl();
  return storage->getFKind();
}
/// Return the floating-point value recorded in this attribute's storage.
llvm::APFloat fir::RealAttr::getValue() const {
  const auto *storage = getImpl();
  return storage->getValue();
}
//===----------------------------------------------------------------------===//
// FIR attribute parsing
//===----------------------------------------------------------------------===//
/// Parse the body of a real attribute, i.e. everything after its name.
///
/// Two spellings are accepted:
///   `<` kind `,` float-literal `>`
///   `<` kind `,` `i` keyword `>`   where keyword is one marker character
///                                  followed by hexadecimal digits giving the
///                                  raw bit pattern (printed as `i x<hex>`).
///
/// \param dialect supplies the MLIR context used for the kind mapping and
///        for creating the attribute.
/// \param parser  the dialect assembly parser positioned after the name.
/// \param type    unused.
/// \returns the parsed `RealAttr`, or a null attribute after emitting a
///          diagnostic when the input is malformed.
static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
                                        mlir::DialectAsmParser &parser,
                                        mlir::Type type) {
  int kind = 0;
  if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
    parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
    return {};
  }
  KindMapping kindMap(dialect->getContext());
  llvm::APFloat value(0.);
  if (parser.parseOptionalKeyword("i")) {
    // `i` not present, so literal float must be present
    double dontCare;
    if (parser.parseFloat(dontCare) || parser.parseGreater()) {
      parser.emitError(parser.getNameLoc(), "expected real constant '>'");
      return {};
    }
    // The parsed double is discarded; the literal's text is recovered from
    // the full attribute spec (between the ',' and the '>') and converted
    // using the semantics selected by `kind`.
    auto fltStr = parser.getFullSymbolSpec()
                      .drop_until([](char c) { return c == ','; })
                      .drop_front()
                      .drop_while([](char c) { return c == ' ' || c == '\t'; })
                      .take_until([](char c) {
                        return c == '>' || c == ' ' || c == '\t';
                      });
    value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
  } else {
    // `i` is present, so literal bitstring (hex) must be present
    llvm::StringRef hex;
    if (parser.parseKeyword(&hex) || parser.parseGreater()) {
      parser.emitError(parser.getNameLoc(), "expected real constant '>'");
      return {};
    }
    const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
    unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
    // Skip the leading marker character of the keyword, then validate the
    // remaining digits before building the bit pattern. Constructing an APInt
    // directly from unchecked text asserts on an empty string, on non-hex
    // characters, and on values wider than the target format.
    llvm::StringRef digits = hex.drop_front();
    llvm::APInt bits;
    if (digits.getAsInteger(16, bits) || bits.getActiveBits() > numBits) {
      parser.emitError(parser.getNameLoc(),
                       "expected hexadecimal bit pattern of at most ")
          << numBits << " bits for real constant";
      return {};
    }
    value = llvm::APFloat(sem, bits.zextOrTrunc(numBits));
  }
  return RealAttr::get(dialect->getContext(), {kind, value});
}
/// Parse `<` [flag (`,` flag)*] `>` into a `FortranVariableFlagsAttr`.
///
/// An empty list (`<>`) yields an attribute with no flag set. Returns a null
/// attribute when parsing fails or a flag name is not recognized.
mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser,
                                                     mlir::Type type) {
  if (mlir::failed(parser.parseLess()))
    return {};

  fir::FortranVariableFlagsEnum flags = {};

  // `<>`: nothing to accumulate.
  if (mlir::succeeded(parser.parseOptionalGreater()))
    return FortranVariableFlagsAttr::get(parser.getContext(), flags);

  // Reads one flag keyword and ORs it into `flags`.
  auto parseOneFlag = [&]() -> mlir::ParseResult {
    llvm::StringRef flagName;
    if (mlir::failed(parser.parseKeyword(&flagName)))
      return mlir::failure();
    auto flag = fir::symbolizeFortranVariableFlagsEnum(flagName);
    if (flag) {
      flags = flags | *flag;
      return mlir::success();
    }
    return parser.emitError(parser.getNameLoc(),
                            "Unknown fortran variable attribute: ")
           << flagName;
  };

  if (mlir::failed(parser.parseCommaSeparatedList(parseOneFlag)))
    return {};
  if (parser.parseGreater())
    return {};
  return FortranVariableFlagsAttr::get(parser.getContext(), flags);
}
/// Parse a FIR dialect attribute.
///
/// The generated attribute parser is tried first. If it does not handle the
/// attribute, the name it read is matched against the attributes parsed by
/// hand in this file. Returns a null attribute on error.
mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
                                       mlir::DialectAsmParser &parser,
                                       mlir::Type type) {
  const auto nameLoc = parser.getNameLoc();
  llvm::StringRef mnemonic;
  mlir::Attribute generated;
  mlir::OptionalParseResult generatedResult =
      generatedAttributeParser(parser, &mnemonic, type, generated);
  if (generatedResult.has_value())
    return generated;
  if (mnemonic.empty())
    return {}; // error reported by generatedAttributeParser

  // Reads `<` type `>` into `ty`; yields true when any piece fails to parse.
  auto failsToParseAngleType = [&parser](mlir::Type &ty) -> bool {
    return parser.parseLess() || parser.parseType(ty) || parser.parseGreater();
  };

  if (mnemonic == ExactTypeAttr::getAttrName()) {
    mlir::Type exactTy;
    if (failsToParseAngleType(exactTy)) {
      parser.emitError(nameLoc, "expected a type");
      return {};
    }
    return ExactTypeAttr::get(exactTy);
  }

  if (mnemonic == SubclassAttr::getAttrName()) {
    mlir::Type subTy;
    if (failsToParseAngleType(subTy)) {
      parser.emitError(nameLoc, "expected a subtype");
      return {};
    }
    return SubclassAttr::get(subTy);
  }

  // Attributes without any payload.
  if (mnemonic == PointIntervalAttr::getAttrName())
    return PointIntervalAttr::get(dialect->getContext());
  if (mnemonic == LowerBoundAttr::getAttrName())
    return LowerBoundAttr::get(dialect->getContext());
  if (mnemonic == UpperBoundAttr::getAttrName())
    return UpperBoundAttr::get(dialect->getContext());
  if (mnemonic == ClosedIntervalAttr::getAttrName())
    return ClosedIntervalAttr::get(dialect->getContext());

  if (mnemonic == RealAttr::getAttrName())
    return parseFirRealAttr(dialect, parser, type);

  parser.emitError(nameLoc, "unknown FIR attribute: ") << mnemonic;
  return {};
}
//===----------------------------------------------------------------------===//
// FIR attribute pretty printer
//===----------------------------------------------------------------------===//
/// Print the flags as `<` stringified-flags `>`.
void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<" << fir::stringifyFortranVariableFlagsEnum(getFlags()) << ">";
}
/// Print a FIR dialect attribute.
///
/// Attributes printed by hand are handled first; anything else is forwarded
/// to the generated attribute printer. If that also fails, a placeholder is
/// written so the output still shows that an attribute was present.
///
/// \param dialect unused.
/// \param attr    the attribute to print.
/// \param p       the printer whose stream receives the text.
void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
                            mlir::DialectAsmPrinter &p) {
  auto &os = p.getStream();
  // Casts use the free functions (`mlir::dyn_cast` / `mlir::isa`) rather than
  // the deprecated member forms on `mlir::Attribute`.
  if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) {
    os << fir::ExactTypeAttr::getAttrName() << '<';
    p.printType(exact.getType());
    os << '>';
  } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) {
    os << fir::SubclassAttr::getAttrName() << '<';
    p.printType(sub.getType());
    os << '>';
  } else if (mlir::isa<fir::PointIntervalAttr>(attr)) {
    os << fir::PointIntervalAttr::getAttrName();
  } else if (mlir::isa<fir::ClosedIntervalAttr>(attr)) {
    os << fir::ClosedIntervalAttr::getAttrName();
  } else if (mlir::isa<fir::LowerBoundAttr>(attr)) {
    os << fir::LowerBoundAttr::getAttrName();
  } else if (mlir::isa<fir::UpperBoundAttr>(attr)) {
    os << fir::UpperBoundAttr::getAttrName();
  } else if (auto a = mlir::dyn_cast<fir::RealAttr>(attr)) {
    // Reals are always printed as their raw bit pattern in hexadecimal.
    os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
    llvm::SmallString<40> ss;
    a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
    os << ss << '>';
  } else if (mlir::failed(generatedAttributePrinter(attr, p))) {
    // don't know how to print the attribute, so use a default
    os << "<(unknown attribute)>";
  }
}
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
/// Register every attribute class of the FIR dialect.
void FIROpsDialect::registerAttributes() {
  // Core FIR attributes.
  addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
                LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
                UpperBoundAttr>();
  // CUDA-related attributes.
  addAttributes<CUDADataAttributeAttr, CUDAProcAttributeAttr,
                CUDALaunchBoundsAttr, CUDAClusterDimsAttr,
                CUDADataTransferKindAttr>();
}