Similar to 'enter data', except that the data clauses are preceded by a 'getdeviceptr' operation so that they can correctly use the 'exit' operation. While this is a touch awkward, it fits perfectly into the existing infrastructure. As with 'enter data', we had to add some add-functions for async and wait.
4052 lines
149 KiB
C++
4052 lines
149 KiB
C++
//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
|
|
//
|
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
// =============================================================================
|
|
|
|
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
using namespace mlir;
|
|
using namespace acc;
|
|
|
|
#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
|
|
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
|
|
#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
|
|
#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
|
|
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
|
|
|
|
namespace {
|
|
|
|
/// Returns true when `type` is an integer, index, floating-point, or complex
/// type.
static bool isScalarLikeType(Type type) {
  if (isa<ComplexType>(type))
    return true;
  return type.isIntOrIndexOrFloat();
}
|
|
|
|
struct MemRefPointerLikeModel
|
|
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
|
|
MemRefType> {
|
|
Type getElementType(Type pointer) const {
|
|
return cast<MemRefType>(pointer).getElementType();
|
|
}
|
|
mlir::acc::VariableTypeCategory
|
|
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
|
|
Type varType) const {
|
|
if (auto mappableTy = dyn_cast<MappableType>(varType)) {
|
|
return mappableTy.getTypeCategory(varPtr);
|
|
}
|
|
auto memrefTy = cast<MemRefType>(pointer);
|
|
if (!memrefTy.hasRank()) {
|
|
// This memref is unranked - aka it could have any rank, including a
|
|
// rank of 0 which could mean scalar. For now, return uncategorized.
|
|
return mlir::acc::VariableTypeCategory::uncategorized;
|
|
}
|
|
|
|
if (memrefTy.getRank() == 0) {
|
|
if (isScalarLikeType(memrefTy.getElementType())) {
|
|
return mlir::acc::VariableTypeCategory::scalar;
|
|
}
|
|
// Zero-rank non-scalar - need further analysis to determine the type
|
|
// category. For now, return uncategorized.
|
|
return mlir::acc::VariableTypeCategory::uncategorized;
|
|
}
|
|
|
|
// It has a rank - must be an array.
|
|
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
|
|
return mlir::acc::VariableTypeCategory::array;
|
|
}
|
|
};
|
|
|
|
struct LLVMPointerPointerLikeModel
|
|
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
|
|
LLVM::LLVMPointerType> {
|
|
Type getElementType(Type pointer) const { return Type(); }
|
|
};
|
|
|
|
/// Helper function for any of the times we need to modify an ArrayAttr based on
|
|
/// a device type list. Returns a new ArrayAttr with all of the
|
|
/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
|
|
/// list is empty).
|
|
mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
|
|
MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
|
|
llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
|
|
llvm::SmallVector<mlir::Attribute> deviceTypes;
|
|
if (existingDeviceTypes)
|
|
llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
|
|
|
|
if (newDeviceTypes.empty())
|
|
deviceTypes.push_back(
|
|
acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
|
|
|
|
for (DeviceType DT : newDeviceTypes)
|
|
deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
|
|
|
|
return mlir::ArrayAttr::get(context, deviceTypes);
|
|
}
|
|
|
|
/// Helper function for any of the times we need to add operands that are
|
|
/// affected by a device type list. Returns a new ArrayAttr with all of the
|
|
/// existingDeviceTypes, plus the effective new ones (or an added none, if the
|
|
/// new list is empty). Additionally, adds the arguments to the argCollection
|
|
/// the correct number of times. This will also update a 'segments' array, even
|
|
/// if it won't be used.
|
|
mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
|
|
MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
|
|
llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
|
|
mlir::MutableOperandRange argCollection,
|
|
llvm::SmallVector<int32_t> &segments) {
|
|
llvm::SmallVector<mlir::Attribute> deviceTypes;
|
|
if (existingDeviceTypes)
|
|
llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
|
|
|
|
if (newDeviceTypes.empty()) {
|
|
argCollection.append(arguments);
|
|
segments.push_back(arguments.size());
|
|
deviceTypes.push_back(
|
|
acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
|
|
}
|
|
|
|
for (DeviceType DT : newDeviceTypes) {
|
|
argCollection.append(arguments);
|
|
segments.push_back(arguments.size());
|
|
deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
|
|
}
|
|
|
|
return mlir::ArrayAttr::get(context, deviceTypes);
|
|
}
|
|
|
|
/// Overload for when the 'segments' aren't needed.
|
|
mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
|
|
MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
|
|
llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
|
|
mlir::MutableOperandRange argCollection) {
|
|
llvm::SmallVector<int32_t> segments;
|
|
return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
|
|
newDeviceTypes, arguments,
|
|
argCollection, segments);
|
|
}
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpenACC operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the dialect's generated operations, attributes, and types, then
/// attaches the pointer-like external models to `MemRefType` and
/// `LLVM::LLVMPointerType`.
void OpenACCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
      >();

  // Attaching the interfaces here makes the OpenACC dialect depend on the
  // memref and LLVM dialects. This is probably preferable to making those
  // dialects depend on OpenACC.
  MLIRContext &ctx = *getContext();
  MemRefType::attachInterface<MemRefPointerLikeModel>(ctx);
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(ctx);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// device_type support helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true when `arrayAttr` is engaged, holds a non-null ArrayAttr, and
/// that array has at least one element.
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
  return arrayAttr && *arrayAttr && arrayAttr->size() != 0;
}
|
|
|
|
/// Returns true when `arrayAttr` contains a DeviceTypeAttr whose value is
/// `deviceType`. An absent, null, or empty array yields false. Elements that
/// are not DeviceTypeAttr are skipped rather than dereferenced.
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
                          mlir::acc::DeviceType deviceType) {
  if (!hasDeviceTypeValues(arrayAttr))
    return false;

  for (mlir::Attribute attr : *arrayAttr) {
    // dyn_cast yields a null attribute for non-DeviceTypeAttr elements; guard
    // before calling getValue().
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    if (deviceTypeAttr && deviceTypeAttr.getValue() == deviceType)
      return true;
  }

  return false;
}
|
|
|
|
/// Prints `deviceTypes` as a bracketed, comma-separated list. Prints nothing
/// when the array is absent, null, or empty.
static void printDeviceTypes(mlir::OpAsmPrinter &p,
                             std::optional<mlir::ArrayAttr> deviceTypes) {
  if (hasDeviceTypeValues(deviceTypes)) {
    auto printOne = [&](mlir::Attribute attr) { p << attr; };
    p << "[";
    llvm::interleaveComma(*deviceTypes, p, printOne);
    p << "]";
  }
}
|
|
|
|
/// Returns the index of the first element of `segments` that is a
/// DeviceTypeAttr with value `deviceType`, or std::nullopt if there is none.
/// Elements that are not DeviceTypeAttr still count towards the index but are
/// never dereferenced.
static std::optional<unsigned> findSegment(ArrayAttr segments,
                                           mlir::acc::DeviceType deviceType) {
  unsigned segmentIdx = 0;
  for (mlir::Attribute attr : segments) {
    // dyn_cast yields a null attribute on a type mismatch; guard before
    // calling getValue().
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    if (deviceTypeAttr && deviceTypeAttr.getValue() == deviceType)
      return segmentIdx;
    ++segmentIdx;
  }
  return std::nullopt;
}
|
|
|
|
/// Returns the sub-range of `range` that belongs to `deviceType`.
///
/// `arrayAttr` lists one device type per segment and `segments` gives the
/// number of operands in each segment, in the same order. An empty range is
/// returned when either input is missing, when `deviceType` is not listed, or
/// when the matching index has no corresponding entry in `segments`.
static mlir::Operation::operand_range
getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
                      mlir::Operation::operand_range range,
                      std::optional<llvm::ArrayRef<int32_t>> segments,
                      mlir::acc::DeviceType deviceType) {
  // Both the device-type list and the segment sizes are needed to locate the
  // operands; without either there is nothing to return.
  if (!arrayAttr || !*arrayAttr || !segments)
    return range.take_front(0);

  std::optional<unsigned> pos = findSegment(*arrayAttr, deviceType);
  // Guard against a device-type list that is longer than the segment list.
  if (!pos || *pos >= segments->size())
    return range.take_front(0);

  // Skip the operands of every segment preceding the requested one.
  int32_t nbOperandsBefore = 0;
  for (unsigned i = 0; i < *pos; ++i)
    nbOperandsBefore += (*segments)[i];
  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
}
|
|
|
|
/// Returns the wait `devnum` operand for `deviceType`, or a null Value when
/// there is none.
///
/// `hasWaitDevnum` holds one BoolAttr per device-type segment; when the flag
/// for the matching segment is true, the first operand of that segment is the
/// devnum. A null Value is returned when the device-type list is empty, when
/// `hasWaitDevnum` is missing, when `deviceType` is not listed, when the flag
/// is absent or false, or when the segment has no operands.
static mlir::Value
getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
                   mlir::Operation::operand_range operands,
                   std::optional<llvm::ArrayRef<int32_t>> segments,
                   std::optional<mlir::ArrayAttr> hasWaitDevnum,
                   mlir::acc::DeviceType deviceType) {
  if (!hasDeviceTypeValues(deviceTypeAttr))
    return {};
  if (!hasWaitDevnum || !*hasWaitDevnum)
    return {};

  std::optional<unsigned> pos = findSegment(*deviceTypeAttr, deviceType);
  if (!pos || *pos >= hasWaitDevnum->size())
    return {};

  // Inspect the boolean stored in the attribute. Testing the attribute itself
  // only checks that it is non-null, which would treat `false` as "present".
  auto hasDevnum = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
  if (!hasDevnum || !hasDevnum.getValue())
    return {};

  auto values =
      getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
  if (values.empty())
    return {};
  return values.front(); // first value of the segment is devnum
}
|
|
|
|
/// Returns the wait operands for `deviceType` with the leading `devnum`
/// operand removed.
///
/// `hasWaitDevnum` holds one BoolAttr per device-type segment. The first
/// operand of the segment is dropped only when that flag exists and is true;
/// otherwise the segment is returned unchanged.
static mlir::Operation::operand_range
getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
                           mlir::Operation::operand_range operands,
                           std::optional<llvm::ArrayRef<int32_t>> segments,
                           std::optional<mlir::ArrayAttr> hasWaitDevnum,
                           mlir::acc::DeviceType deviceType) {
  auto range =
      getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
  if (range.empty())
    return range;
  if (!hasWaitDevnum || !*hasWaitDevnum)
    return range;

  // A non-empty range implies `deviceTypeAttr` is engaged.
  std::optional<unsigned> pos = findSegment(*deviceTypeAttr, deviceType);
  if (!pos || *pos >= hasWaitDevnum->size())
    return range;

  // dyn_cast yields a null attribute on a type mismatch; guard before reading
  // the flag.
  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
  if (boolAttr && boolAttr.getValue())
    return range.drop_front(1); // first value is devnum
  return range;
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult checkWaitAndAsyncConflict(Op op) {
|
|
for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
|
|
++dtypeInt) {
|
|
auto dtype = static_cast<acc::DeviceType>(dtypeInt);
|
|
|
|
// The asyncOnly attribute represent the async clause without value.
|
|
// Therefore the attribute and operand cannot appear at the same time.
|
|
if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
|
|
op.hasAsyncOnly(dtype))
|
|
return op.emitError(
|
|
"asyncOnly attribute cannot appear with asyncOperand");
|
|
|
|
// The wait attribute represent the wait clause without values. Therefore
|
|
// the attribute and operands cannot appear at the same time.
|
|
if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
|
|
op.hasWaitOnly(dtype))
|
|
return op.emitError("wait attribute cannot appear with waitOperands");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult checkVarAndVarType(Op op) {
|
|
if (!op.getVar())
|
|
return op.emitError("must have var operand");
|
|
|
|
if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
|
|
mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
|
|
// TODO: If a type implements both interfaces (mappable and pointer-like),
|
|
// it is unclear which semantics to apply without additional info which
|
|
// would need captured in the data operation. For now restrict this case
|
|
// unless a compelling reason to support disambiguating between the two.
|
|
return op.emitError("var must be mappable or pointer-like (not both)");
|
|
}
|
|
|
|
if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
|
|
!mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
|
|
return op.emitError("var must be mappable or pointer-like");
|
|
|
|
if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
|
|
op.getVarType() != op.getVar().getType())
|
|
return op.emitError("varType must match when var is mappable");
|
|
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult checkVarAndAccVar(Op op) {
|
|
if (op.getVar().getType() != op.getAccVar().getType())
|
|
return op.emitError("input and output types must match");
|
|
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult checkNoModifier(Op op) {
|
|
if (op.getModifiers() != acc::DataClauseModifier::none)
|
|
return op.emitError("no data clause modifiers are allowed");
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult
|
|
checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
|
|
if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
|
|
return op.emitError(
|
|
"invalid data clause modifiers: " +
|
|
acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Parses `varPtr(` or `var(` followed by an operand, storing the operand in
/// `var`. The closing parenthesis is not consumed here.
static ParseResult parseVar(mlir::OpAsmParser &parser,
                            OpAsmParser::UnresolvedOperand &var) {
  // One of the two keywords is required; `varPtr` is tried first and `var` is
  // mandatory if it is absent.
  if (failed(parser.parseOptionalKeyword("varPtr")) &&
      failed(parser.parseKeyword("var")))
    return failure();

  if (failed(parser.parseLParen()) || failed(parser.parseOperand(var)))
    return failure();

  return success();
}
|
|
|
|
/// Prints `varPtr(` when `var` has a pointer-like type and `var(` otherwise,
/// followed by the operand. The closing parenthesis is not printed here.
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
                     mlir::Value var) {
  const bool isPointerLike =
      mlir::isa<mlir::acc::PointerLikeType>(var.getType());
  p << (isPointerLike ? "varPtr(" : "var(");
  p.printOperand(var);
}
|
|
|
|
/// Parses `accPtr(<operand> : <type>)` or `accVar(<operand> : <type>)`,
/// storing the operand in `var` and the type in `accVarType`.
static ParseResult parseAccVar(mlir::OpAsmParser &parser,
                               OpAsmParser::UnresolvedOperand &var,
                               mlir::Type &accVarType) {
  // One of the two keywords is required; `accPtr` is tried first and `accVar`
  // is mandatory if it is absent.
  if (failed(parser.parseOptionalKeyword("accPtr")) &&
      failed(parser.parseKeyword("accVar")))
    return failure();

  // Short-circuiting keeps the parse order: '(' operand ':' type ')'.
  if (failed(parser.parseLParen()) || failed(parser.parseOperand(var)) ||
      failed(parser.parseColon()) || failed(parser.parseType(accVarType)) ||
      failed(parser.parseRParen()))
    return failure();

  return success();
}
|
|
|
|
/// Prints `accPtr(<operand> : <type>)` when `accVar` has a pointer-like type
/// and `accVar(<operand> : <type>)` otherwise.
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
                        mlir::Value accVar, mlir::Type accVarType) {
  const bool isPointerLike =
      mlir::isa<mlir::acc::PointerLikeType>(accVar.getType());
  p << (isPointerLike ? "accPtr(" : "accVar(");
  p.printOperand(accVar);
  p << " : ";
  p.printType(accVarType);
  p << ")";
}
|
|
|
|
/// Parses `<type>)` followed by an optional `varType(<type>)`.
///
/// The first type is stored in `varPtrType`. When `varType(...)` is present
/// its type is stored in `varTypeAttr`. Otherwise the variable type is
/// inferred: the element type of `varPtrType` if it is pointer-like, or
/// `varPtrType` itself if not. If the pointer-like type reports no element
/// type, inference is impossible and a parse error is emitted instead of
/// building a TypeAttr from a null type.
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
                                   mlir::Type &varPtrType,
                                   mlir::TypeAttr &varTypeAttr) {
  if (failed(parser.parseType(varPtrType)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("varType"))) {
    if (failed(parser.parseLParen()))
      return failure();
    mlir::Type varType;
    if (failed(parser.parseType(varType)))
      return failure();
    varTypeAttr = mlir::TypeAttr::get(varType);
    if (failed(parser.parseRParen()))
      return failure();
    return success();
  }

  // No explicit `varType`: derive it from the type of `varPtr`.
  mlir::Type inferredVarType = varPtrType;
  if (auto pointerLikeTy =
          mlir::dyn_cast<mlir::acc::PointerLikeType>(varPtrType))
    inferredVarType = pointerLikeTy.getElementType();

  // Some pointer-like types do not expose an element type.
  if (!inferredVarType)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected `varType` because the variable type "
                            "cannot be inferred from the pointer type");

  varTypeAttr = mlir::TypeAttr::get(inferredVarType);
  return success();
}
|
|
|
|
/// Prints `<varPtrType>)`, followed by ` varType(<type>)` only when the
/// variable type differs from the one that would be inferred: the element
/// type of `varPtrType` if it is pointer-like, or `varPtrType` itself if not.
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
                            mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
  p.printType(varPtrType);
  p << ")";

  mlir::Type varType = varTypeAttr.getValue();

  // Type that would be inferred if `varType` were omitted.
  mlir::Type inferredVarType = varPtrType;
  if (auto pointerLikeTy =
          mlir::dyn_cast<mlir::acc::PointerLikeType>(varPtrType))
    inferredVarType = pointerLikeTy.getElementType();

  if (inferredVarType == varType)
    return;

  p << " varType(";
  p.printType(varType);
  p << ")";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DataBoundsOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that at least one of `extent` or `upperbound` is present.
LogicalResult acc::DataBoundsOp::verify() {
  if (getExtent() || getUpperbound())
    return success();
  return emitError("expected extent or upperbound.");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PrivateOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_private`, that `var`/`varType` are
/// consistent, and that no data clause modifiers are set.
LogicalResult acc::PrivateOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_private)
    return emitError(
        "data clause associated with private operation must match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FirstprivateOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_firstprivate`, that `var`/`varType`
/// are consistent, and that no data clause modifiers are set.
LogicalResult acc::FirstprivateOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_firstprivate)
    return emitError("data clause associated with firstprivate operation must "
                     "match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReductionOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_reduction`, that `var`/`varType` are
/// consistent, and that no data clause modifiers are set.
LogicalResult acc::ReductionOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_reduction)
    return emitError("data clause associated with reduction operation must "
                     "match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DevicePtrOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_deviceptr`, that `var`/`varType` are
/// consistent, that `var` and `accVar` have the same type, and that no data
/// clause modifiers are set.
LogicalResult acc::DevicePtrOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_deviceptr)
    return emitError("data clause associated with deviceptr operation must "
                     "match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PresentOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_present`, that `var`/`varType` are
/// consistent, that `var` and `accVar` have the same type, and that no data
/// clause modifiers are set.
LogicalResult acc::PresentOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_present)
    return emitError(
        "data clause associated with present operation must match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CopyinOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the copyin operation. Unless the op is implicit, its data clause
/// must be one of the clauses a copyin can be decomposed from (`copyin`,
/// `copyin_readonly`, `copy`, `reduction`). Also checks `var`/`varType`
/// consistency, matching `var`/`accVar` types, and that only the `readonly`,
/// `always`, and `capture` modifiers are used.
LogicalResult acc::CopyinOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  auto isAllowedClause = [](acc::DataClause clause) {
    switch (clause) {
    case acc::DataClause::acc_copyin:
    case acc::DataClause::acc_copyin_readonly:
    case acc::DataClause::acc_copy:
    case acc::DataClause::acc_reduction:
      return true;
    default:
      return false;
    }
  };
  if (!getImplicit() && !isAllowedClause(getDataClause()))
    return emitError(
        "data clause associated with copyin operation must match its intent"
        " or specify original clause this operation was decomposed from");

  const acc::DataClauseModifier allowedModifiers =
      acc::DataClauseModifier::readonly | acc::DataClauseModifier::always |
      acc::DataClauseModifier::capture;

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkValidModifier(*this, allowedModifiers)))
    return failure();
  return success();
}
|
|
|
|
bool acc::CopyinOp::isCopyinReadonly() {
|
|
return getDataClause() == acc::DataClause::acc_copyin_readonly ||
|
|
acc::bitEnumContainsAny(getModifiers(),
|
|
acc::DataClauseModifier::readonly);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CreateOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the create operation. Its data clause must be one of the clauses
/// a create can be decomposed from (`create`, `create_zero`, `copyout`,
/// `copyout_zero`). Also checks `var`/`varType` consistency, matching
/// `var`/`accVar` types, and that only the `zero`, `always`, and `capture`
/// modifiers are used.
LogicalResult acc::CreateOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  switch (getDataClause()) {
  case acc::DataClause::acc_create:
  case acc::DataClause::acc_create_zero:
  case acc::DataClause::acc_copyout:
  case acc::DataClause::acc_copyout_zero:
    break;
  default:
    return emitError(
        "data clause associated with create operation must match its intent"
        " or specify original clause this operation was decomposed from");
  }

  // This op is the entry part of copyout, so it also needs to allow all
  // modifiers allowed on copyout.
  const acc::DataClauseModifier allowedModifiers =
      acc::DataClauseModifier::zero | acc::DataClauseModifier::always |
      acc::DataClauseModifier::capture;

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkValidModifier(*this, allowedModifiers)))
    return failure();
  return success();
}
|
|
|
|
bool acc::CreateOp::isCreateZero() {
|
|
// The zero modifier is encoded in the data clause.
|
|
return getDataClause() == acc::DataClause::acc_create_zero ||
|
|
getDataClause() == acc::DataClause::acc_copyout_zero ||
|
|
acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// NoCreateOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_no_create`, that `var`/`varType` are
/// consistent, that `var` and `accVar` have the same type, and that no data
/// clause modifiers are set.
LogicalResult acc::NoCreateOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_no_create)
    return emitError("data clause associated with no_create operation must "
                     "match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AttachOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_attach`, that `var`/`varType` are
/// consistent, that `var` and `accVar` have the same type, and that no data
/// clause modifiers are set.
LogicalResult acc::AttachOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_attach)
    return emitError(
        "data clause associated with attach operation must match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeclareDeviceResidentOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that the data clause is `acc_declare_device_resident`, that
/// `var`/`varType` are consistent, that `var` and `accVar` have the same
/// type, and that no data clause modifiers are set.
LogicalResult acc::DeclareDeviceResidentOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_declare_device_resident)
    return emitError("data clause associated with device_resident operation "
                     "must match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeclareLinkOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that the data clause is `acc_declare_link`, that `var`/`varType`
/// are consistent, that `var` and `accVar` have the same type, and that no
/// data clause modifiers are set.
LogicalResult acc::DeclareLinkOp::verify() {
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_declare_link)
    return emitError(
        "data clause associated with link operation must match its intent");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CopyoutOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the copyout operation. Its data clause must be one of the clauses
/// a copyout can be decomposed from (`copyout`, `copyout_zero`, `copy`,
/// `reduction`), and both `var` and `accVar` must be present. Also checks
/// `var`/`varType` consistency, matching `var`/`accVar` types, and that only
/// the `zero`, `always`, and `capture` modifiers are used.
LogicalResult acc::CopyoutOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  switch (getDataClause()) {
  case acc::DataClause::acc_copyout:
  case acc::DataClause::acc_copyout_zero:
  case acc::DataClause::acc_copy:
  case acc::DataClause::acc_reduction:
    break;
  default:
    return emitError(
        "data clause associated with copyout operation must match its intent"
        " or specify original clause this operation was decomposed from");
  }

  if (!getVar() || !getAccVar())
    return emitError("must have both host and device pointers");

  const acc::DataClauseModifier allowedModifiers =
      acc::DataClauseModifier::zero | acc::DataClauseModifier::always |
      acc::DataClauseModifier::capture;

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkValidModifier(*this, allowedModifiers)))
    return failure();
  return success();
}
|
|
|
|
bool acc::CopyoutOp::isCopyoutZero() {
|
|
return getDataClause() == acc::DataClause::acc_copyout_zero ||
|
|
acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeleteOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the delete operation. Its data clause must be one of the clauses
/// a delete can be decomposed from, `accVar` must be present, and only the
/// `zero`, `readonly`, `always`, and `capture` modifiers may be used.
LogicalResult acc::DeleteOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  switch (getDataClause()) {
  case acc::DataClause::acc_delete:
  case acc::DataClause::acc_create:
  case acc::DataClause::acc_create_zero:
  case acc::DataClause::acc_copyin:
  case acc::DataClause::acc_copyin_readonly:
  case acc::DataClause::acc_present:
  case acc::DataClause::acc_no_create:
  case acc::DataClause::acc_declare_device_resident:
  case acc::DataClause::acc_declare_link:
    break;
  default:
    return emitError(
        "data clause associated with delete operation must match its intent"
        " or specify original clause this operation was decomposed from");
  }

  if (!getAccVar())
    return emitError("must have device pointer");

  // This op is the exit part of copyin and create - thus allow all modifiers
  // allowed on either case.
  const acc::DataClauseModifier allowedModifiers =
      acc::DataClauseModifier::zero | acc::DataClauseModifier::readonly |
      acc::DataClauseModifier::always | acc::DataClauseModifier::capture;
  if (failed(checkValidModifier(*this, allowedModifiers)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DetachOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the detach operation. Its data clause must be `acc_detach` or
/// `acc_attach`, `accVar` must be present, and no data clause modifiers may
/// be set.
LogicalResult acc::DetachOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  const acc::DataClause clause = getDataClause();
  const bool isAllowedClause = clause == acc::DataClause::acc_detach ||
                               clause == acc::DataClause::acc_attach;
  if (!isAllowedClause)
    return emitError(
        "data clause associated with detach operation must match its intent"
        " or specify original clause this operation was decomposed from");

  if (!getAccVar())
    return emitError("must have device pointer");

  if (failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UpdateHostOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the update-host operation. Its data clause must be
/// `acc_update_host` or `acc_update_self`, and both `var` and `accVar` must
/// be present. Also checks `var`/`varType` consistency, matching
/// `var`/`accVar` types, and that no data clause modifiers are set.
LogicalResult acc::UpdateHostOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  const acc::DataClause clause = getDataClause();
  const bool isAllowedClause = clause == acc::DataClause::acc_update_host ||
                               clause == acc::DataClause::acc_update_self;
  if (!isAllowedClause)
    return emitError(
        "data clause associated with host operation must match its intent"
        " or specify original clause this operation was decomposed from");

  if (!getVar() || !getAccVar())
    return emitError("must have both host and device pointers");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UpdateDeviceOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_update_device`, that `var`/`varType`
/// are consistent, that `var` and `accVar` have the same type, and that no
/// data clause modifiers are set.
LogicalResult acc::UpdateDeviceOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_update_device)
    return emitError(
        "data clause associated with device operation must match its intent"
        " or specify original clause this operation was decomposed from");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UseDeviceOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies that the data clause is `acc_use_device`, that `var`/`varType`
/// are consistent, that `var` and `accVar` have the same type, and that no
/// data clause modifiers are set.
LogicalResult acc::UseDeviceOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  const acc::DataClause clause = getDataClause();
  if (clause != acc::DataClause::acc_use_device)
    return emitError(
        "data clause associated with use_device operation must match its intent"
        " or specify original clause this operation was decomposed from");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkNoModifier(*this)))
    return failure();
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CacheOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Verifies the cache operation. Its data clause must be `acc_cache` or
/// `acc_cache_readonly`. Also checks `var`/`varType` consistency, matching
/// `var`/`accVar` types, and that only the `readonly` modifier is used.
LogicalResult acc::CacheOp::verify() {
  // Test for all clauses this operation can be decomposed from:
  const acc::DataClause clause = getDataClause();
  const bool isAllowedClause = clause == acc::DataClause::acc_cache ||
                               clause == acc::DataClause::acc_cache_readonly;
  if (!isAllowedClause)
    return emitError(
        "data clause associated with cache operation must match its intent"
        " or specify original clause this operation was decomposed from");

  // Checks run in order; the first failing check emits the diagnostic.
  if (failed(checkVarAndVarType(*this)) || failed(checkVarAndAccVar(*this)) ||
      failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
    return failure();
  return success();
}
|
|
|
|
bool acc::CacheOp::isCacheReadonly() {
|
|
return getDataClause() == acc::DataClause::acc_cache_readonly ||
|
|
acc::bitEnumContainsAny(getModifiers(),
|
|
acc::DataClauseModifier::readonly);
|
|
}
|
|
|
|
template <typename StructureOp>
|
|
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
|
|
unsigned nRegions = 1) {
|
|
|
|
SmallVector<Region *, 2> regions;
|
|
for (unsigned i = 0; i < nRegions; ++i)
|
|
regions.push_back(state.addRegion());
|
|
|
|
for (Region *region : regions)
|
|
if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Returns true if \p op is one of the operation kinds listed by the
/// `ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS` macro.
static bool isComputeOperation(Operation *op) {
  const bool matches = isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
  return matches;
}
|
|
|
|
namespace {
|
|
/// Pattern to remove operation without region that have constant false `ifCond`
|
|
/// and remove the condition from the operation if the `ifCond` is a true
|
|
/// constant.
|
|
template <typename OpTy>
|
|
struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
PatternRewriter &rewriter) const override {
|
|
// Early return if there is no condition.
|
|
Value ifCond = op.getIfCond();
|
|
if (!ifCond)
|
|
return failure();
|
|
|
|
IntegerAttr constAttr;
|
|
if (!matchPattern(ifCond, m_Constant(&constAttr)))
|
|
return failure();
|
|
if (constAttr.getInt())
|
|
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
|
else
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Replaces the given op with the contents of the given single-block region,
|
|
/// using the operands of the block terminator to replace operation results.
|
|
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
|
|
Region ®ion, ValueRange blockArgs = {}) {
|
|
assert(llvm::hasSingleElement(region) && "expected single-region block");
|
|
Block *block = ®ion.front();
|
|
Operation *terminator = block->getTerminator();
|
|
ValueRange results = terminator->getOperands();
|
|
rewriter.inlineBlockBefore(block, op, blockArgs);
|
|
rewriter.replaceOp(op, results);
|
|
rewriter.eraseOp(terminator);
|
|
}
|
|
|
|
/// Pattern to remove operation with region that have constant false `ifCond`
|
|
/// and remove the condition from the operation if the `ifCond` is constant
|
|
/// true.
|
|
template <typename OpTy>
|
|
struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
PatternRewriter &rewriter) const override {
|
|
// Early return if there is no condition.
|
|
Value ifCond = op.getIfCond();
|
|
if (!ifCond)
|
|
return failure();
|
|
|
|
IntegerAttr constAttr;
|
|
if (!matchPattern(ifCond, m_Constant(&constAttr)))
|
|
return failure();
|
|
if (constAttr.getInt())
|
|
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
|
else
|
|
replaceOpWithRegion(rewriter, op, op.getRegion());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PrivateRecipeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Shared verifier for recipe regions whose first block argument must have
/// the recipe type (used for `init` and `destroy` regions).
///
/// \param op          Operation owning the region; diagnostics are emitted on
///                    it.
/// \param region      Region to verify.
/// \param regionType  Word used in diagnostics for the expected type (e.g.
///                    "privatization", "reduction").
/// \param regionName  Word used in diagnostics for the region (e.g. "init").
/// \param type        Type the first block argument (and, if requested, the
///                    yielded value) must have.
/// \param verifyYield When true, every `acc.yield` found by
///                    `region.getOps` must yield exactly one value of
///                    \p type.
/// \param optional    When true, an empty region is accepted.
/// \returns success, or failure after emitting an op error.
static LogicalResult verifyInitLikeSingleArgRegion(
    Operation *op, Region &region, StringRef regionType, StringRef regionName,
    Type type, bool verifyYield, bool optional = false) {
  if (region.empty()) {
    if (optional)
      return success();
    return op->emitOpError() << "expects non-empty " << regionName << " region";
  }

  Block &entryBlock = region.front();
  const bool firstArgMatches =
      entryBlock.getNumArguments() >= 1 &&
      entryBlock.getArgument(0).getType() == type;
  if (!firstArgMatches)
    return op->emitOpError() << "expects " << regionName
                             << " region first "
                                "argument of the "
                             << regionType << " type";

  if (!verifyYield)
    return success();

  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
    auto yielded = yieldOp.getOperands();
    if (yielded.size() == 1 && yielded.getTypes()[0] == type)
      continue;
    return op->emitOpError() << "expects " << regionName
                             << " region to "
                                "yield a value of the "
                             << regionType << " type";
  }
  return success();
}
|
|
|
|
/// Region verifier for `acc.private.recipe`: the init region is mandatory,
/// the destroy region is optional, and both are checked with
/// `verifyInitLikeSingleArgRegion` against the recipe type without yield
/// verification.
LogicalResult acc::PrivateRecipeOp::verifyRegions() {
  LogicalResult initCheck = verifyInitLikeSingleArgRegion(
      *this, getInitRegion(), "privatization", "init", getType(),
      /*verifyYield=*/false);
  if (failed(initCheck))
    return failure();

  LogicalResult destroyCheck = verifyInitLikeSingleArgRegion(
      *this, getDestroyRegion(), "privatization", "destroy", getType(),
      /*verifyYield=*/false, /*optional=*/true);
  if (failed(destroyCheck))
    return failure();

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FirstprivateRecipeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Region verifier for `acc.firstprivate.recipe`.
///
/// Checks, in order:
///   1. the mandatory init region (first argument of the recipe type);
///   2. the mandatory copy region, whose first block must start with two
///      arguments of the recipe type;
///   3. the optional destroy region (first argument of the recipe type).
///
/// \returns success, or failure after emitting an op error.
LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
                                           "privatization", "init", getType(),
                                           /*verifyYield=*/false)))
    return failure();

  if (getCopyRegion().empty())
    return emitOpError() << "expects non-empty copy region";

  // The diagnostic promises *two* arguments of the privatization type, so
  // both leading arguments are checked (as the reduction recipe does for its
  // combiner region).
  Block &copyBlock = getCopyRegion().front();
  if (copyBlock.getNumArguments() < 2 ||
      copyBlock.getArgument(0).getType() != getType() ||
      copyBlock.getArgument(1).getType() != getType())
    return emitOpError() << "expects copy region with two arguments of the "
                            "privatization type";

  // An absent destroy region is fine; a present one follows the init rules.
  return verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
                                       "privatization", "destroy", getType(),
                                       /*verifyYield=*/false,
                                       /*optional=*/true);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReductionRecipeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Region verifier for `acc.reduction.recipe`: checks the init region with
/// `verifyInitLikeSingleArgRegion`, then requires a non-empty combiner region
/// whose first two block arguments have the recipe type and whose `acc.yield`
/// operations each yield exactly one value of that type.
LogicalResult acc::ReductionRecipeOp::verifyRegions() {
  LogicalResult initCheck = verifyInitLikeSingleArgRegion(
      *this, getInitRegion(), "reduction", "init", getType(),
      /*verifyYield=*/false);
  if (failed(initCheck))
    return failure();

  Region &combiner = getCombinerRegion();
  if (combiner.empty())
    return emitOpError() << "expects non-empty combiner region";

  Block &combinerBlock = combiner.front();
  const bool leadingArgsMatch =
      combinerBlock.getNumArguments() >= 2 &&
      combinerBlock.getArgument(0).getType() == getType() &&
      combinerBlock.getArgument(1).getType() == getType();
  if (!leadingArgsMatch)
    return emitOpError() << "expects combiner region with the first two "
                         << "arguments of the reduction type";

  for (YieldOp yieldOp : combiner.getOps<YieldOp>()) {
    auto yielded = yieldOp.getOperands();
    if (yielded.size() == 1 && yielded.getTypes()[0] == getType())
      continue;
    return emitOpError() << "expects combiner region to yield a value "
                            "of the reduction type";
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Custom parser and printer verifier for private clause
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a comma-separated list of `symbol -> operand : type` entries.
/// Operands and types are appended to \p operands and \p types; the parsed
/// symbol references are returned together as \p symbols.
static ParseResult parseSymOperandList(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
  llvm::SmallVector<SymbolRefAttr> parsedSymbols;

  // One list entry: `@symbol -> %operand : type`.
  auto parseEntry = [&]() -> ParseResult {
    if (parser.parseAttribute(parsedSymbols.emplace_back()) ||
        parser.parseArrow() ||
        parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseEntry)))
    return failure();

  // ArrayAttr needs plain Attribute elements.
  llvm::SmallVector<mlir::Attribute> erased(parsedSymbols.begin(),
                                            parsedSymbols.end());
  symbols = ArrayAttr::get(parser.getContext(), erased);
  return success();
}
|
|
|
|
/// Prints the `symbol -> operand : type` list, comma separated. Symbols and
/// operands are paired positionally; the printed type is taken from each
/// operand itself. \p attributes is dereferenced unconditionally.
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
                                mlir::OperandRange operands,
                                mlir::TypeRange types,
                                std::optional<mlir::ArrayAttr> attributes) {
  bool needsSeparator = false;
  for (auto [symbol, operand] : llvm::zip(*attributes, operands)) {
    if (needsSeparator)
      p << ", ";
    needsSeparator = true;
    p << symbol << " -> " << operand << " : " << operand.getType();
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ParallelOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
|
|
template <typename Op>
|
|
static LogicalResult checkDataOperands(Op op,
|
|
const mlir::ValueRange &operands) {
|
|
for (mlir::Value operand : operands)
|
|
if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
|
|
acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
|
|
acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
|
|
operand.getDefiningOp()))
|
|
return op.emitError(
|
|
"expect data entry/exit operation or acc.getdeviceptr "
|
|
"as defining op");
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult
|
|
checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
|
|
mlir::OperandRange operands, llvm::StringRef operandName,
|
|
llvm::StringRef symbolName, bool checkOperandType = true) {
|
|
if (!operands.empty()) {
|
|
if (!attributes || attributes->size() != operands.size())
|
|
return op->emitOpError()
|
|
<< "expected as many " << symbolName << " symbol reference as "
|
|
<< operandName << " operands";
|
|
} else {
|
|
if (attributes)
|
|
return op->emitOpError()
|
|
<< "unexpected " << symbolName << " symbol reference";
|
|
return success();
|
|
}
|
|
|
|
llvm::DenseSet<Value> set;
|
|
for (auto args : llvm::zip(operands, *attributes)) {
|
|
mlir::Value operand = std::get<0>(args);
|
|
|
|
if (!set.insert(operand).second)
|
|
return op->emitOpError()
|
|
<< operandName << " operand appears more than once";
|
|
|
|
mlir::Type varType = operand.getType();
|
|
auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
|
|
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
|
|
if (!decl)
|
|
return op->emitOpError()
|
|
<< "expected symbol reference " << symbolRef << " to point to a "
|
|
<< operandName << " declaration";
|
|
|
|
if (checkOperandType && decl.getType() && decl.getType() != varType)
|
|
return op->emitOpError() << "expected " << operandName << " (" << varType
|
|
<< ") to be the same type as " << operandName
|
|
<< " declaration (" << decl.getType() << ")";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
unsigned ParallelOp::getNumDataOperands() {
|
|
return getReductionOperands().size() + getPrivateOperands().size() +
|
|
getFirstprivateOperands().size() + getDataClauseOperands().size();
|
|
}
|
|
|
|
/// Returns the \p i-th data operand. The operand index is \p i offset by the
/// number of wait, async, num_gangs, num_workers and vector_length operands
/// plus one for each of `ifCond` / `selfCond` that is present.
Value ParallelOp::getDataOperand(unsigned i) {
  unsigned leadingOperands = getWaitOperands().size();
  leadingOperands += getAsyncOperands().size();
  leadingOperands += getNumGangs().size();
  leadingOperands += getNumWorkers().size();
  leadingOperands += getVectorLength().size();
  if (getIfCond())
    ++leadingOperands;
  if (getSelfCond())
    ++leadingOperands;
  return getOperand(leadingOperands + i);
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
|
|
ArrayAttr deviceTypes,
|
|
llvm::StringRef keyword) {
|
|
if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
|
|
return op.emitOpError() << keyword << " operands count must match "
|
|
<< keyword << " device_type count";
|
|
return success();
|
|
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
|
|
Op op, OperandRange operands, DenseI32ArrayAttr segments,
|
|
ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
|
|
std::size_t numOperandsInSegments = 0;
|
|
std::size_t nbOfSegments = 0;
|
|
|
|
if (segments) {
|
|
for (auto segCount : segments.asArrayRef()) {
|
|
if (maxInSegment != 0 && segCount > maxInSegment)
|
|
return op.emitOpError() << keyword << " expects a maximum of "
|
|
<< maxInSegment << " values per segment";
|
|
numOperandsInSegments += segCount;
|
|
++nbOfSegments;
|
|
}
|
|
}
|
|
|
|
if ((numOperandsInSegments != operands.size()) ||
|
|
(!deviceTypes && !operands.empty()))
|
|
return op.emitOpError()
|
|
<< keyword << " operand count does not match count in segments";
|
|
if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
|
|
return op.emitOpError()
|
|
<< keyword << " segment count does not match device_type count";
|
|
return success();
|
|
}
|
|
|
|
/// Verifier for `acc.parallel`. Runs, in order: the recipe symbol list checks
/// (private, firstprivate, reduction), the segmented device_type checks
/// (num_gangs with at most 3 values per segment, wait), the flat device_type
/// checks (num_workers, vector_length, async), the wait/async conflict check
/// and finally the data clause operand check. The first failure wins.
LogicalResult acc::ParallelOp::verify() {
  // Recipe symbols must line up with their operands; operand types are not
  // compared against the recipe type here.
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)) ||
      failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations",
          /*checkOperandType=*/false)) ||
      failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();

  // Clauses stored as segments per device type.
  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)) ||
      failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  // Clauses stored as one operand per device type.
  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")) ||
      failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")) ||
      failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
    return failure();

  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
|
|
|
|
/// Returns the element of \p range at the position `findSegment` reports for
/// \p deviceType in \p arrayAttr, or a null value when the attribute is
/// absent or no position is found.
static mlir::Value
getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
                            mlir::Operation::operand_range range,
                            mlir::acc::DeviceType deviceType) {
  if (!arrayAttr)
    return {};
  auto position = findSegment(*arrayAttr, deviceType);
  if (!position)
    return {};
  return range[*position];
}
|
|
|
|
bool acc::ParallelOp::hasAsyncOnly() {
|
|
return hasAsyncOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getAsyncOnly(), deviceType);
|
|
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Value acc::ParallelOp::getAsyncValue() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getAsyncValue(noDeviceType);
}

/// Returns the async operand selected for \p deviceType by
/// `getValueInDeviceTypeSegment`, or a null value if there is none.
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncOperandsDeviceType();
  return getValueInDeviceTypeSegment(asyncDeviceTypes, getAsyncOperands(),
                                     deviceType);
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Value acc::ParallelOp::getNumWorkersValue() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getNumWorkersValue(noDeviceType);
}

/// Returns the num_workers operand selected for \p deviceType by
/// `getValueInDeviceTypeSegment`, or a null value if there is none.
mlir::Value
acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  auto numWorkersDeviceTypes = getNumWorkersDeviceType();
  return getValueInDeviceTypeSegment(numWorkersDeviceTypes, getNumWorkers(),
                                     deviceType);
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Value acc::ParallelOp::getVectorLengthValue() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getVectorLengthValue(noDeviceType);
}

/// Returns the vector_length operand selected for \p deviceType by
/// `getValueInDeviceTypeSegment`, or a null value if there is none.
mlir::Value
acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  auto vectorLengthDeviceTypes = getVectorLengthDeviceType();
  return getValueInDeviceTypeSegment(vectorLengthDeviceTypes,
                                     getVectorLength(), deviceType);
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getNumGangsValues(noDeviceType);
}

/// Returns the num_gangs operands that `getValuesFromSegments` selects for
/// \p deviceType.
mlir::Operation::operand_range
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  auto gangDeviceTypes = getNumGangsDeviceType();
  auto gangSegments = getNumGangsSegments();
  return getValuesFromSegments(gangDeviceTypes, getNumGangs(), gangSegments,
                               deviceType);
}
|
|
|
|
bool acc::ParallelOp::hasWaitOnly() {
|
|
return hasWaitOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getWaitOnly(), deviceType);
|
|
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Operation::operand_range ParallelOp::getWaitValues() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getWaitValues(noDeviceType);
}

/// Returns the wait operands that `getWaitValuesWithoutDevnum` selects for
/// \p deviceType.
mlir::Operation::operand_range
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto devnumFlags = getHasWaitDevnum();
  return getWaitValuesWithoutDevnum(waitDeviceTypes, getWaitOperands(),
                                    waitSegments, devnumFlags, deviceType);
}
|
|
|
|
/// Convenience overload: same query for `mlir::acc::DeviceType::None`.
mlir::Value ParallelOp::getWaitDevnum() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getWaitDevnum(noDeviceType);
}

/// Returns the value `getWaitDevnumValue` reports for \p deviceType from the
/// wait operands, their segments and the devnum flags.
mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto devnumFlags = getHasWaitDevnum();
  return getWaitDevnumValue(waitDeviceTypes, getWaitOperands(), waitSegments,
                            devnumFlags, deviceType);
}
|
|
|
|
/// Builder taking only operand lists. Forwards to the full builder, passing
/// `nullptr` for every attribute (device types, segments, keyword-only
/// markers, recipes, `selfAttr`, `defaultAttr`, `combined`).
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {
  ParallelOp::build(odsBuilder, odsState,
                    // async
                    asyncOperands,
                    /*asyncOperandsDeviceType=*/nullptr,
                    /*asyncOnly=*/nullptr,
                    // wait
                    waitOperands,
                    /*waitOperandsSegments=*/nullptr,
                    /*waitOperandsDeviceType=*/nullptr,
                    /*hasWaitDevnum=*/nullptr,
                    /*waitOnly=*/nullptr,
                    // num_gangs
                    numGangs,
                    /*numGangsSegments=*/nullptr,
                    /*numGangsDeviceType=*/nullptr,
                    // num_workers
                    numWorkers,
                    /*numWorkersDeviceType=*/nullptr,
                    // vector_length
                    vectorLength,
                    /*vectorLengthDeviceType=*/nullptr,
                    // conditions
                    ifCond, selfCond,
                    /*selfAttr=*/nullptr,
                    // reduction / private / firstprivate
                    reductionOperands,
                    /*reductionRecipes=*/nullptr, gangPrivateOperands,
                    /*privatizations=*/nullptr, gangFirstPrivateOperands,
                    /*firstprivatizations=*/nullptr,
                    // data clauses and trailing attributes
                    dataClauseOperands,
                    /*defaultAttr=*/nullptr,
                    /*combined=*/nullptr);
}
|
|
|
|
/// Adds \p newValue as a num_workers operand for \p effectiveDeviceTypes via
/// `addDeviceTypeAffectedOperandHelper` and stores the attribute it returns
/// as the new num_workers device_type attribute.
void acc::ParallelOp::addNumWorkersOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getNumWorkersMutable());
  setNumWorkersDeviceTypeAttr(updatedDeviceTypes);
}
|
|
/// Adds \p newValue as a vector_length operand for \p effectiveDeviceTypes
/// via `addDeviceTypeAffectedOperandHelper` and stores the attribute it
/// returns as the new vector_length device_type attribute.
void acc::ParallelOp::addVectorLengthOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getVectorLengthMutable());
  setVectorLengthDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Updates the `asyncOnly` attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` for \p effectiveDeviceTypes.
void acc::ParallelOp::addAsyncOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedAsyncOnly = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOnlyAttr(), effectiveDeviceTypes);
  setAsyncOnlyAttr(updatedAsyncOnly);
}
|
|
|
|
/// Adds \p newValue as an async operand for \p effectiveDeviceTypes via
/// `addDeviceTypeAffectedOperandHelper` and stores the attribute it returns
/// as the new async device_type attribute.
void acc::ParallelOp::addAsyncOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getAsyncOperandsMutable());
  setAsyncOperandsDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Adds \p newValues as num_gangs operands for \p effectiveDeviceTypes.
/// The existing segment sizes are copied into a local vector that is handed
/// to `addDeviceTypeAffectedOperandHelper`, then written back together with
/// the device_type attribute the helper returns.
void acc::ParallelOp::addNumGangsOperands(
    MLIRContext *context, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  llvm::SmallVector<int32_t> segments;
  if (auto existing = getNumGangsSegments())
    segments.append(existing->begin(), existing->end());

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getNumGangsMutable(), segments);
  setNumGangsDeviceTypeAttr(updatedDeviceTypes);

  setNumGangsSegments(segments);
}
|
|
/// Updates the `waitOnly` attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` for \p effectiveDeviceTypes.
void acc::ParallelOp::addWaitOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedWaitOnly = addDeviceTypeAffectedOperandHelper(
      context, getWaitOnlyAttr(), effectiveDeviceTypes);
  setWaitOnlyAttr(updatedWaitOnly);
}
|
|
/// Adds \p newValues as wait operands for \p effectiveDeviceTypes.
///
/// Updates, in order: the wait device_type attribute (result of
/// `addDeviceTypeAffectedOperandHelper`), the wait segments, and the
/// `hasWaitDevnum` array, which receives \p hasDevnum once per effective
/// device type (at least once).
void acc::ParallelOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Operands, device types and segment sizes.
  llvm::SmallVector<int32_t> segments;
  if (auto existing = getWaitOperandsSegments())
    segments.append(existing->begin(), existing->end());

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments);
  setWaitOperandsDeviceTypeAttr(updatedDeviceTypes);
  setWaitOperandsSegments(segments);

  // devnum flags: keep the old entries and append the new ones.
  llvm::SmallVector<mlir::Attribute> hasDevnums;
  if (auto existingFlags = getHasWaitDevnumAttr())
    hasDevnums.append(existingFlags.begin(), existingFlags.end());
  const size_t newEntries =
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1));
  hasDevnums.append(newEntries, mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
|
|
|
|
/// Parses a comma-separated list of groups of the form
/// `{operand : type, ...}` optionally followed by `[device_type]`.
/// Each group contributes one segment size (number of operands in the braces)
/// and one device type (`DeviceType::None` when no bracket follows).
static ParseResult parseNumGangs(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {
  llvm::SmallVector<DeviceTypeAttr> groupDeviceTypes;
  llvm::SmallVector<int32_t> groupSizes;

  // One `operand : type` entry inside the braces.
  auto parseTypedOperand = [&]() -> ParseResult {
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    return success();
  };

  while (true) {
    if (failed(parser.parseLBrace()))
      return failure();

    const int32_t sizeBefore = operands.size();
    if (failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::None, parseTypedOperand)))
      return failure();
    groupSizes.push_back(operands.size() - sizeBefore);

    if (failed(parser.parseRBrace()))
      return failure();

    if (failed(parser.parseOptionalLSquare())) {
      // No explicit device type for this group.
      groupDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    } else if (parser.parseAttribute(groupDeviceTypes.emplace_back()) ||
               parser.parseRSquare()) {
      return failure();
    }

    if (failed(parser.parseOptionalComma()))
      break;
  }

  llvm::SmallVector<mlir::Attribute> erased(groupDeviceTypes.begin(),
                                            groupDeviceTypes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), erased);
  segments = DenseI32ArrayAttr::get(parser.getContext(), groupSizes);

  return success();
}
|
|
|
|
/// Prints ` [attr]` after an operand or group unless \p attr is the
/// `DeviceType::None` device type, which is left implicit.
///
/// \param p    Printer to write to.
/// \param attr Attribute expected to be a `DeviceTypeAttr`. A null attribute
///             or one of another kind is printed verbatim rather than
///             dereferencing a failed cast.
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
  auto deviceTypeAttr =
      mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
  // Only a genuine `None` device type is elided.
  if (deviceTypeAttr &&
      deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
    return;
  p << " [" << attr << "]";
}
|
|
|
|
/// Prints the num_gangs groups: for each device type entry a
/// `{operand : type, ...}` group holding as many operands as the matching
/// segment says, followed by the device type as printed by
/// `printSingleDeviceType`. Groups are comma separated. \p deviceTypes and
/// \p segments are dereferenced unconditionally.
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
                          mlir::OperandRange operands, mlir::TypeRange types,
                          std::optional<mlir::ArrayAttr> deviceTypes,
                          std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Running index into the flat operand list, shared by all groups.
  unsigned nextOperand = 0;
  for (auto group : llvm::enumerate(*deviceTypes)) {
    if (group.index() != 0)
      p << ", ";
    p << "{";
    const int32_t groupSize = (*segments)[group.index()];
    for (int32_t i = 0; i < groupSize; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[nextOperand] << " : " << operands[nextOperand].getType();
      ++nextOperand;
    }
    p << "}";
    printSingleDeviceType(p, group.value());
  }
}
|
|
|
|
/// Parses a comma-separated list of groups of the form
/// `{operand : type, ...}` optionally followed by `[device_type]`.
/// Each group contributes one segment size (number of operands in the braces)
/// and one device type (`DeviceType::None` when no bracket follows).
static ParseResult parseDeviceTypeOperandsWithSegment(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {
  llvm::SmallVector<DeviceTypeAttr> groupDeviceTypes;
  llvm::SmallVector<int32_t> groupSizes;

  // One `operand : type` entry inside the braces.
  auto parseTypedOperand = [&]() -> ParseResult {
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    return success();
  };

  while (true) {
    if (failed(parser.parseLBrace()))
      return failure();

    const int32_t sizeBefore = operands.size();
    if (failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::None, parseTypedOperand)))
      return failure();
    groupSizes.push_back(operands.size() - sizeBefore);

    if (failed(parser.parseRBrace()))
      return failure();

    if (failed(parser.parseOptionalLSquare())) {
      // No explicit device type for this group.
      groupDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    } else if (parser.parseAttribute(groupDeviceTypes.emplace_back()) ||
               parser.parseRSquare()) {
      return failure();
    }

    if (failed(parser.parseOptionalComma()))
      break;
  }

  llvm::SmallVector<mlir::Attribute> erased(groupDeviceTypes.begin(),
                                            groupDeviceTypes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), erased);
  segments = DenseI32ArrayAttr::get(parser.getContext(), groupSizes);

  return success();
}
|
|
|
|
/// Prints segmented operands: for each device type entry a
/// `{operand : type, ...}` group holding as many operands as the matching
/// segment says, followed by the device type as printed by
/// `printSingleDeviceType`. Groups are comma separated. \p deviceTypes and
/// \p segments are dereferenced unconditionally.
static void printDeviceTypeOperandsWithSegment(
    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Running index into the flat operand list, shared by all groups.
  unsigned nextOperand = 0;
  for (auto group : llvm::enumerate(*deviceTypes)) {
    if (group.index() != 0)
      p << ", ";
    p << "{";
    const int32_t groupSize = (*segments)[group.index()];
    for (int32_t i = 0; i < groupSize; ++i) {
      if (i != 0)
        p << ", ";
      p << operands[nextOperand] << " : " << operands[nextOperand].getType();
      ++nextOperand;
    }
    p << "}";
    printSingleDeviceType(p, group.value());
  }
}
|
|
|
|
/// Parses the custom `wait` clause syntax.
///
/// Accepted forms:
///   * nothing (no `(`): keyword-only wait; \p keywordOnly is set to a single
///     `DeviceType::None` entry and nothing else is touched;
///   * `([dt, ...])`: keyword-only wait for explicit device types; only
///     \p keywordOnly is set. This is the form `printWaitClause` emits when
///     there are keyword-only device types but no operand groups;
///   * `([dt, ...], group, ...)` or `(group, ...)` where each group is
///     `{[devnum:] operand : type, ...}` optionally followed by `[dt]`.
///
/// \param parser      Parser to read from.
/// \param operands    Receives the parsed operands of all groups.
/// \param types       Receives the operand types.
/// \param deviceTypes Receives one device type per group
///                    (`DeviceType::None` when no bracket follows the group).
/// \param segments    Receives the operand count of each group.
/// \param hasDevNum   Receives one bool per group: whether it started with
///                    `devnum:`.
/// \param keywordOnly Receives the keyword-only device types.
/// \returns success, or failure if the input does not match the syntax.
static ParseResult parseWaitClause(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
    mlir::ArrayAttr &keywordOnly) {
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
  llvm::SmallVector<int32_t> seg;

  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(keywordAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();

    // `([dt, ...])`: keyword-only device types with no operand group. The
    // printer produces exactly this, so it has to be accepted here.
    if (succeeded(parser.parseOptionalRParen())) {
      keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
      return success();
    }

    // Otherwise operand groups must follow, separated by a comma.
    if (failed(parser.parseComma()))
      return failure();
  }

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    // An optional `devnum:` prefix marks the first operand of the group.
    const bool groupHasDevnum =
        succeeded(parser.parseOptionalKeyword("devnum"));
    if (groupHasDevnum && failed(parser.parseColon()))
      return failure();
    devnum.push_back(BoolAttr::get(parser.getContext(), groupHasDevnum));

    if (failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::None, [&]() {
              if (parser.parseOperand(operands.emplace_back()) ||
                  parser.parseColonType(types.emplace_back()))
                return failure();
              return success();
            })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);

  return success();
}
|
|
|
|
/// Returns true when \p attrs passes `hasDeviceTypeValues`, holds exactly one
/// element, and that element is a `DeviceTypeAttr` with value
/// `DeviceType::None`.
static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
  if (!hasDeviceTypeValues(attrs) || attrs->size() != 1)
    return false;
  auto onlyEntry = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]);
  if (!onlyEntry)
    return false;
  return onlyEntry.getValue() == mlir::acc::DeviceType::None;
}
|
|
|
|
/// Prints the custom `wait` clause syntax.
///
/// Prints nothing when there are no operands and the keyword-only list is
/// exactly `DeviceType::None`. Otherwise prints, inside parentheses, the
/// keyword-only device types (via `printDeviceTypes`) followed by the operand
/// groups `{[devnum: ]operand : type, ...}`, each followed by its device type
/// as printed by `printSingleDeviceType`.
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
                            mlir::OperandRange operands, mlir::TypeRange types,
                            std::optional<mlir::ArrayAttr> deviceTypes,
                            std::optional<mlir::DenseI32ArrayAttr> segments,
                            std::optional<mlir::ArrayAttr> hasDevNum,
                            std::optional<mlir::ArrayAttr> keywordOnly) {
  // Plain keyword: nothing beyond the keyword itself.
  if (operands.empty() && hasOnlyDeviceTypeNone(keywordOnly))
    return;

  p << "(";

  printDeviceTypes(p, keywordOnly);
  const bool hasGroups = hasDeviceTypeValues(deviceTypes);
  if (hasDeviceTypeValues(keywordOnly) && hasGroups)
    p << ", ";

  if (hasGroups) {
    // Running index into the flat operand list, shared by all groups.
    unsigned nextOperand = 0;
    for (auto group : llvm::enumerate(*deviceTypes)) {
      if (group.index() != 0)
        p << ", ";
      p << "{";
      auto devnumFlag =
          mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[group.index()]);
      if (devnumFlag && devnumFlag.getValue())
        p << "devnum: ";
      const int32_t groupSize = (*segments)[group.index()];
      for (int32_t i = 0; i < groupSize; ++i) {
        if (i != 0)
          p << ", ";
        p << operands[nextOperand] << " : "
          << operands[nextOperand].getType();
        ++nextOperand;
      }
      p << "}";
      printSingleDeviceType(p, group.value());
    }
  }

  p << ")";
}
|
|
|
|
/// Parses a comma-separated list of `operand : type` entries, each optionally
/// followed by `[device_type]`. Entries without a bracket get
/// `DeviceType::None`. The device types are returned in \p deviceTypes, one
/// per operand.
static ParseResult parseDeviceTypeOperands(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
  llvm::SmallVector<DeviceTypeAttr> perOperandDeviceTypes;

  auto parseEntry = [&]() -> ParseResult {
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();

    if (failed(parser.parseOptionalLSquare())) {
      // No explicit device type for this operand.
      perOperandDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
      return success();
    }
    if (parser.parseAttribute(perOperandDeviceTypes.emplace_back()) ||
        parser.parseRSquare())
      return failure();
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseEntry)))
    return failure();

  llvm::SmallVector<mlir::Attribute> erased(perOperandDeviceTypes.begin(),
                                            perOperandDeviceTypes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), erased);
  return success();
}
|
|
|
|
static void
|
|
printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
|
|
mlir::OperandRange operands, mlir::TypeRange types,
|
|
std::optional<mlir::ArrayAttr> deviceTypes) {
|
|
if (!hasDeviceTypeValues(deviceTypes))
|
|
return;
|
|
llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
|
|
p << std::get<1>(it) << " : " << std::get<1>(it).getType();
|
|
printSingleDeviceType(p, std::get<0>(it));
|
|
});
|
|
}
|
|
|
|
/// Parses a clause that can appear either as a bare keyword or with operands.
///
/// Accepted forms:
///   * nothing (no `(`): keyword only; \p keywordOnlyDeviceType is set to a
///     single `DeviceType::None` entry and nothing else is touched;
///   * `(operand : type [dt], ...)`;
///   * `([dt, ...], operand : type [dt], ...)`, where the leading bracket
///     list holds the keyword-only device types.
///
/// \param parser                Parser to read from.
/// \param operands              Receives the parsed operands.
/// \param types                 Receives the operand types.
/// \param deviceTypes           Receives one device type per operand
///                              (`DeviceType::None` when no bracket follows).
/// \param keywordOnlyDeviceType Receives the keyword-only device types; left
///                              untouched when the parenthesised form has no
///                              leading bracket list.
/// \returns success, or failure if the input does not match the syntax.
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::ArrayAttr &keywordOnlyDeviceType) {

  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Keyword only
    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    keywordOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  keywordOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  llvm::SmallVector<DeviceTypeAttr> attributes;
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);

  // Hand back the keyword-only device types parsed from the leading `[...]`
  // list; without this they would be silently dropped. Nothing is stored when
  // no such list was present, so that case behaves as before.
  if (!keywordOnlyDeviceTypeAttributes.empty())
    keywordOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
  return success();
}
|
|
|
|
/// Prints the parenthesized part of a clause that mixes keyword-only device
/// types with device-type-qualified operands.
///
/// Nothing is printed when there are no operands and
/// hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes) holds. Otherwise the output
/// is `(`, the keyword-only device types, a `, ` separator when both attribute
/// lists have values, the operands, and `)`.
static void printDeviceTypeOperandsWithKeywordOnly(
    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
  const bool noOperands = operands.begin() == operands.end();
  if (noOperands && hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes))
    return;

  p << "(";
  printDeviceTypes(p, keywordOnlyDeviceTypes);

  const bool needsSeparator = hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
                              hasDeviceTypeValues(deviceTypes);
  if (needsSeparator)
    p << ", ";

  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
  p << ")";
}
|
|
|
|
/// Parses either nothing (bare keyword) or `(` operand `:` type `)`.
///
/// When no `(` follows, `attr` is set to a UnitAttr and `operand` is left
/// untouched. Otherwise the parsed operand is stored in `operand` and its
/// type in `operandType`.
static ParseResult parseOperandWithKeywordOnly(
    mlir::OpAsmParser &parser,
    std::optional<OpAsmParser::UnresolvedOperand> &operand,
    mlir::Type &operandType, mlir::UnitAttr &attr) {
  const bool keywordOnly = failed(parser.parseOptionalLParen());
  if (keywordOnly) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  OpAsmParser::UnresolvedOperand parsedOperand;
  if (parser.parseOperand(parsedOperand))
    return failure();
  operand = parsedOperand;

  if (parser.parseColon() || parser.parseType(operandType) ||
      parser.parseRParen())
    return failure();
  return success();
}
|
|
|
|
/// Prints `(operand : type)` unless `attr` is set, in which case the clause
/// is keyword-only and nothing is printed. `op` is not read.
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p,
                                        mlir::Operation *op,
                                        std::optional<mlir::Value> operand,
                                        mlir::Type operandType,
                                        mlir::UnitAttr attr) {
  if (!attr) {
    p << "(";
    p.printOperand(*operand);
    p << " : " << operandType << ")";
  }
}
|
|
|
|
/// Parses either nothing (bare keyword) or
/// `(` operand (`,` operand)* `:` type (`,` type)* `)`.
///
/// When no `(` follows, `attr` is set to a UnitAttr. Otherwise the operands
/// and types are appended to `operands` and `types`.
static ParseResult parseOperandsWithKeywordOnly(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
  const bool keywordOnly = failed(parser.parseOptionalLParen());
  if (keywordOnly) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  auto parseOneOperand = [&]() -> ParseResult {
    return parser.parseOperand(operands.emplace_back());
  };
  auto parseOneType = [&]() -> ParseResult {
    return parser.parseType(types.emplace_back());
  };

  if (parser.parseCommaSeparatedList(parseOneOperand) || parser.parseColon() ||
      parser.parseCommaSeparatedList(parseOneType) || parser.parseRParen())
    return failure();
  return success();
}
|
|
|
|
/// Prints `(operands : types)` with both lists comma-separated. Nothing is
/// printed when `attr` is set (keyword-only form). `op` is not read.
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p,
                                         mlir::Operation *op,
                                         mlir::OperandRange operands,
                                         mlir::TypeRange types,
                                         mlir::UnitAttr attr) {
  if (!attr) {
    p << "(";
    llvm::interleaveComma(operands, p,
                          [&p](mlir::Value operand) { p << operand; });
    p << " : ";
    llvm::interleaveComma(types, p, [&p](mlir::Type type) { p << type; });
    p << ")";
  }
}
|
|
|
|
/// Parses one of the keywords `kernels`, `parallel` or `serial` and stores
/// the matching CombinedConstructsTypeAttr in `attr`. Emits
/// "expected compute construct name" and fails on anything else.
static ParseResult
parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
                            mlir::acc::CombinedConstructsTypeAttr &attr) {
  mlir::acc::CombinedConstructsType construct;
  if (succeeded(parser.parseOptionalKeyword("kernels")))
    construct = mlir::acc::CombinedConstructsType::KernelsLoop;
  else if (succeeded(parser.parseOptionalKeyword("parallel")))
    construct = mlir::acc::CombinedConstructsType::ParallelLoop;
  else if (succeeded(parser.parseOptionalKeyword("serial")))
    construct = mlir::acc::CombinedConstructsType::SerialLoop;
  else
    return parser.emitError(parser.getCurrentLocation(),
                            "expected compute construct name");

  attr = mlir::acc::CombinedConstructsTypeAttr::get(parser.getContext(),
                                                    construct);
  return success();
}
|
|
|
|
static void
|
|
printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
|
|
mlir::acc::CombinedConstructsTypeAttr attr) {
|
|
if (attr) {
|
|
switch (attr.getValue()) {
|
|
case mlir::acc::CombinedConstructsType::KernelsLoop:
|
|
p << "kernels";
|
|
break;
|
|
case mlir::acc::CombinedConstructsType::ParallelLoop:
|
|
p << "parallel";
|
|
break;
|
|
case mlir::acc::CombinedConstructsType::SerialLoop:
|
|
p << "serial";
|
|
break;
|
|
};
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SerialOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
unsigned SerialOp::getNumDataOperands() {
|
|
return getReductionOperands().size() + getPrivateOperands().size() +
|
|
getFirstprivateOperands().size() + getDataClauseOperands().size();
|
|
}
|
|
|
|
/// Returns the operand at index `i` after skipping the wait operands, the
/// async operands, and the if/self conditions when they are present.
Value SerialOp::getDataOperand(unsigned i) {
  unsigned numLeading = getWaitOperands().size() + getAsyncOperands().size();
  if (getIfCond())
    ++numLeading;
  if (getSelfCond())
    ++numLeading;
  return getOperand(numLeading + i);
}
|
|
|
|
bool acc::SerialOp::hasAsyncOnly() {
|
|
return hasAsyncOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Returns the result of hasDeviceType for `deviceType` on the asyncOnly
/// attribute.
bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
  auto asyncOnlyDeviceTypes = getAsyncOnly();
  return hasDeviceType(asyncOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: returns the async value for DeviceType::None.
mlir::Value acc::SerialOp::getAsyncValue() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getAsyncValue(defaultDeviceType);
}
|
|
|
|
/// Looks up the async operand associated with `deviceType` through
/// getValueInDeviceTypeSegment, using the async operands and their
/// device-type attribute.
mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncOperandsDeviceType();
  auto asyncOperands = getAsyncOperands();
  return getValueInDeviceTypeSegment(asyncDeviceTypes, asyncOperands,
                                     deviceType);
}
|
|
|
|
bool acc::SerialOp::hasWaitOnly() {
|
|
return hasWaitOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Returns the result of hasDeviceType for `deviceType` on the waitOnly
/// attribute.
bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
  auto waitOnlyDeviceTypes = getWaitOnly();
  return hasDeviceType(waitOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait values for DeviceType::None.
mlir::Operation::operand_range SerialOp::getWaitValues() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getWaitValues(defaultDeviceType);
}
|
|
|
|
/// Returns the wait operands for `deviceType` as computed by
/// getWaitValuesWithoutDevnum from the wait operands, their device types,
/// their segments and the hasWaitDevnum attribute.
mlir::Operation::operand_range
SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitValuesWithoutDevnum(waitDeviceTypes, getWaitOperands(),
                                    waitSegments, hasDevnum, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait devnum for DeviceType::None.
mlir::Value SerialOp::getWaitDevnum() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getWaitDevnum(defaultDeviceType);
}
|
|
|
|
/// Returns the devnum value for `deviceType` as computed by
/// getWaitDevnumValue from the wait operands, their device types, their
/// segments and the hasWaitDevnum attribute.
mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitDevnumValue(waitDeviceTypes, getWaitOperands(), waitSegments,
                            hasDevnum, deviceType);
}
|
|
|
|
/// Verifies acc.serial. Checks run in this order and stop at the first
/// failure: private / firstprivate / reduction recipe lists, wait operand
/// segments vs. device types, async operand vs. device type counts, the
/// wait/async conflict check, and finally the data clause operands.
LogicalResult acc::SerialOp::verify() {
  // Recipe symbol lists must agree with their operand lists.
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)) ||
      failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations",
          /*checkOperandType=*/false)) ||
      failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();

  // Device-type bookkeeping for wait and async.
  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")) ||
      failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")) ||
      failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
    return failure();

  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
|
|
|
|
/// Replaces the asyncOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`.
void acc::SerialOp::addAsyncOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedAsyncOnly = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOnlyAttr(), effectiveDeviceTypes);
  setAsyncOnlyAttr(updatedAsyncOnly);
}
|
|
|
|
/// Passes `newValue`, the mutable async operand range and
/// `effectiveDeviceTypes` to addDeviceTypeAffectedOperandHelper and stores
/// the attribute it returns as the async operands' device-type attribute.
void acc::SerialOp::addAsyncOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getAsyncOperandsMutable());
  setAsyncOperandsDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Replaces the waitOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`.
void acc::SerialOp::addWaitOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedWaitOnly = addDeviceTypeAffectedOperandHelper(
      context, getWaitOnlyAttr(), effectiveDeviceTypes);
  setWaitOnlyAttr(updatedWaitOnly);
}
|
|
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
///
/// The existing segment sizes are copied, handed to
/// addDeviceTypeAffectedOperandHelper together with the mutable wait operand
/// range, and written back afterwards. The hasWaitDevnum array is then
/// extended with `hasDevnum`, once per effective device type (at least once).
void acc::SerialOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Start from the segment sizes already on the op, if any.
  llvm::SmallVector<int32_t> segments;
  if (auto existingSegments = getWaitOperandsSegments())
    segments.append(existingSegments->begin(), existingSegments->end());

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments);
  setWaitOperandsDeviceTypeAttr(updatedDeviceTypes);
  setWaitOperandsSegments(segments);

  // Keep one devnum flag per added device-type entry.
  llvm::SmallVector<mlir::Attribute> devnumFlags;
  if (auto existingFlags = getHasWaitDevnumAttr())
    devnumFlags.append(existingFlags.begin(), existingFlags.end());

  const size_t numNewEntries =
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1));
  mlir::Attribute newFlag = mlir::BoolAttr::get(context, hasDevnum);
  devnumFlags.append(numNewEntries, newFlag);
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, devnumFlags));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KernelsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
unsigned KernelsOp::getNumDataOperands() {
|
|
return getDataClauseOperands().size();
|
|
}
|
|
|
|
/// Returns the operand at index `i` after skipping the async, wait,
/// num_gangs, num_workers and vector_length operands, plus the if/self
/// conditions when they are present.
Value KernelsOp::getDataOperand(unsigned i) {
  unsigned numLeading = getAsyncOperands().size() + getWaitOperands().size() +
                        getNumGangs().size() + getNumWorkers().size() +
                        getVectorLength().size();
  if (getIfCond())
    ++numLeading;
  if (getSelfCond())
    ++numLeading;
  return getOperand(numLeading + i);
}
|
|
|
|
bool acc::KernelsOp::hasAsyncOnly() {
|
|
return hasAsyncOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Returns the result of hasDeviceType for `deviceType` on the asyncOnly
/// attribute.
bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
  auto asyncOnlyDeviceTypes = getAsyncOnly();
  return hasDeviceType(asyncOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: returns the async value for DeviceType::None.
mlir::Value acc::KernelsOp::getAsyncValue() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getAsyncValue(defaultDeviceType);
}
|
|
|
|
/// Looks up the async operand associated with `deviceType` through
/// getValueInDeviceTypeSegment, using the async operands and their
/// device-type attribute.
mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncOperandsDeviceType();
  auto asyncOperands = getAsyncOperands();
  return getValueInDeviceTypeSegment(asyncDeviceTypes, asyncOperands,
                                     deviceType);
}
|
|
|
|
/// Convenience overload: returns the num_workers value for DeviceType::None.
mlir::Value acc::KernelsOp::getNumWorkersValue() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getNumWorkersValue(defaultDeviceType);
}
|
|
|
|
/// Looks up the num_workers operand associated with `deviceType` through
/// getValueInDeviceTypeSegment.
mlir::Value
acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  auto numWorkersDeviceTypes = getNumWorkersDeviceType();
  auto numWorkersOperands = getNumWorkers();
  return getValueInDeviceTypeSegment(numWorkersDeviceTypes, numWorkersOperands,
                                     deviceType);
}
|
|
|
|
/// Convenience overload: returns the vector_length value for
/// DeviceType::None.
mlir::Value acc::KernelsOp::getVectorLengthValue() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getVectorLengthValue(defaultDeviceType);
}
|
|
|
|
/// Looks up the vector_length operand associated with `deviceType` through
/// getValueInDeviceTypeSegment.
mlir::Value
acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  auto vectorLengthDeviceTypes = getVectorLengthDeviceType();
  auto vectorLengthOperands = getVectorLength();
  return getValueInDeviceTypeSegment(vectorLengthDeviceTypes,
                                     vectorLengthOperands, deviceType);
}
|
|
|
|
/// Convenience overload: returns the num_gangs values for DeviceType::None.
mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getNumGangsValues(defaultDeviceType);
}
|
|
|
|
/// Returns the num_gangs operands for `deviceType` as computed by
/// getValuesFromSegments from the num_gangs operands, their device types and
/// their segments.
mlir::Operation::operand_range
KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  auto numGangsDeviceTypes = getNumGangsDeviceType();
  auto numGangsSegments = getNumGangsSegments();
  return getValuesFromSegments(numGangsDeviceTypes, getNumGangs(),
                               numGangsSegments, deviceType);
}
|
|
|
|
bool acc::KernelsOp::hasWaitOnly() {
|
|
return hasWaitOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Returns the result of hasDeviceType for `deviceType` on the waitOnly
/// attribute.
bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
  auto waitOnlyDeviceTypes = getWaitOnly();
  return hasDeviceType(waitOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait values for DeviceType::None.
mlir::Operation::operand_range KernelsOp::getWaitValues() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getWaitValues(defaultDeviceType);
}
|
|
|
|
/// Returns the wait operands for `deviceType` as computed by
/// getWaitValuesWithoutDevnum from the wait operands, their device types,
/// their segments and the hasWaitDevnum attribute.
mlir::Operation::operand_range
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitValuesWithoutDevnum(waitDeviceTypes, getWaitOperands(),
                                    waitSegments, hasDevnum, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait devnum for DeviceType::None.
mlir::Value KernelsOp::getWaitDevnum() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getWaitDevnum(defaultDeviceType);
}
|
|
|
|
/// Returns the devnum value for `deviceType` as computed by
/// getWaitDevnumValue from the wait operands, their device types, their
/// segments and the hasWaitDevnum attribute.
mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitDevnumValue(waitDeviceTypes, getWaitOperands(), waitSegments,
                            hasDevnum, deviceType);
}
|
|
|
|
/// Verifies acc.kernels. Checks run in this order and stop at the first
/// failure: num_gangs and wait segment/device-type counts, then num_workers,
/// vector_length and async device-type counts, the wait/async conflict check,
/// and finally the data clause operands.
LogicalResult acc::KernelsOp::verify() {
  // Segmented operand groups.
  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)) ||
      failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  // Operand groups checked against their device-type attribute only.
  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")) ||
      failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")) ||
      failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
    return failure();

  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
|
|
|
|
/// Passes `newValue`, the mutable num_workers operand range and
/// `effectiveDeviceTypes` to addDeviceTypeAffectedOperandHelper and stores
/// the attribute it returns as the num_workers device-type attribute.
void acc::KernelsOp::addNumWorkersOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getNumWorkersMutable());
  setNumWorkersDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Passes `newValue`, the mutable vector_length operand range and
/// `effectiveDeviceTypes` to addDeviceTypeAffectedOperandHelper and stores
/// the attribute it returns as the vector_length device-type attribute.
void acc::KernelsOp::addVectorLengthOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getVectorLengthMutable());
  setVectorLengthDeviceTypeAttr(updatedDeviceTypes);
}
|
|
/// Replaces the asyncOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`.
void acc::KernelsOp::addAsyncOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedAsyncOnly = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOnlyAttr(), effectiveDeviceTypes);
  setAsyncOnlyAttr(updatedAsyncOnly);
}
|
|
|
|
/// Passes `newValue`, the mutable async operand range and
/// `effectiveDeviceTypes` to addDeviceTypeAffectedOperandHelper and stores
/// the attribute it returns as the async operands' device-type attribute.
void acc::KernelsOp::addAsyncOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getAsyncOperandsMutable());
  setAsyncOperandsDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Adds `newValues` as num_gangs operands for `effectiveDeviceTypes`.
///
/// The existing segment sizes are copied, handed to
/// addDeviceTypeAffectedOperandHelper together with the mutable num_gangs
/// operand range, and written back to the op afterwards.
void acc::KernelsOp::addNumGangsOperands(
    MLIRContext *context, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Start from the segment sizes already on the op, if any.
  llvm::SmallVector<int32_t> segments;
  if (getNumGangsSegmentsAttr()) {
    auto existingSegments = *getNumGangsSegments();
    segments.append(existingSegments.begin(), existingSegments.end());
  }

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getNumGangsMutable(), segments);
  setNumGangsDeviceTypeAttr(updatedDeviceTypes);
  setNumGangsSegments(segments);
}
|
|
|
|
/// Replaces the waitOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`.
void acc::KernelsOp::addWaitOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedWaitOnly = addDeviceTypeAffectedOperandHelper(
      context, getWaitOnlyAttr(), effectiveDeviceTypes);
  setWaitOnlyAttr(updatedWaitOnly);
}
|
|
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
///
/// The existing segment sizes are copied, handed to
/// addDeviceTypeAffectedOperandHelper together with the mutable wait operand
/// range, and written back afterwards. The hasWaitDevnum array is then
/// extended with `hasDevnum`, once per effective device type (at least once).
void acc::KernelsOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Start from the segment sizes already on the op, if any.
  llvm::SmallVector<int32_t> segments;
  if (auto existingSegments = getWaitOperandsSegments())
    segments.append(existingSegments->begin(), existingSegments->end());

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments);
  setWaitOperandsDeviceTypeAttr(updatedDeviceTypes);
  setWaitOperandsSegments(segments);

  // Keep one devnum flag per added device-type entry.
  llvm::SmallVector<mlir::Attribute> devnumFlags;
  if (auto existingFlags = getHasWaitDevnumAttr())
    devnumFlags.append(existingFlags.begin(), existingFlags.end());

  const size_t numNewEntries =
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1));
  mlir::Attribute newFlag = mlir::BoolAttr::get(context, hasDevnum);
  devnumFlags.append(numNewEntries, newFlag);
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, devnumFlags));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// HostDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies acc.host_data: at least one data clause operand must be present
/// and every such operand must be the result of an acc::UseDeviceOp.
LogicalResult acc::HostDataOp::verify() {
  if (getDataClauseOperands().empty())
    return emitError("at least one operand must appear on the host_data "
                     "operation");

  for (mlir::Value operand : getDataClauseOperands()) {
    // getDefiningOp() is null for block arguments; the templated form folds
    // the null check and the op-kind check together, so such operands yield
    // the diagnostic below instead of calling isa<> on a null pointer.
    if (!operand.getDefiningOp<acc::UseDeviceOp>())
      return emitError("expect data entry operation as defining op");
  }
  return success();
}
|
|
|
|
/// Registers the canonicalization patterns of acc.host_data: currently only
/// RemoveConstantIfConditionWithRegion instantiated for HostDataOp.
void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  using ConstantIfPattern = RemoveConstantIfConditionWithRegion<HostDataOp>;
  results.add<ConstantIfPattern>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LoopOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tries to parse one `keyword = operand : type` gang value.
///
/// When `keyword` is not the next token nothing is consumed and the function
/// succeeds without touching any output. When it is, the operand and type are
/// appended to `operands` / `types`, `gangArgType` is appended to
/// `attributes`, and both `needCommaBetweenValues` and `newValue` are set to
/// true.
static ParseResult parseGangValue(
    OpAsmParser &parser, llvm::StringRef keyword,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
    llvm::SmallVectorImpl<Type> &types,
    llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
    bool &needCommaBetweenValues, bool &newValue) {
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  if (parser.parseEqual() || parser.parseOperand(operands.emplace_back()) ||
      parser.parseColonType(types.emplace_back()))
    return failure();

  attributes.push_back(gangArgType);
  needCommaBetweenValues = true;
  newValue = true;
  return success();
}
|
|
|
|
/// Parses the optional parenthesized part of a gang clause.
///
/// Forms handled:
///   * no `(`: `gangOnlyDeviceType` is set to an array holding
///     DeviceType::None and parsing succeeds immediately.
///   * `(` (`[` attr (`,` attr)* `]`)? (`,`? group (`,` group)*)? `)` where
///     a group is `{` value (`,` value)* `}` optionally followed by
///     `[` attr `]`, and a value is `keyword = operand : type` using the
///     keywords returned by LoopOp::getGangNumKeyword(),
///     getGangDimKeyword() and getGangStaticKeyword().
///
/// Outputs: operands/types are appended to `gangOperands` /
/// `gangOperandsType`; `gangArgType` holds one GangArgTypeAttr per operand;
/// `deviceType` and `segments` hold one entry per group (DeviceType::None
/// when the group has no bracketed attribute, and the number of values in the
/// group); `gangOnlyDeviceType` holds the leading bracketed list.
static ParseResult parseGangClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
    llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
    mlir::ArrayAttr &gangOnlyDeviceType) {
  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
  llvm::SmallVector<int32_t> seg;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Gang only keyword
    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    gangOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse gang only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  gangOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Num);
  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Dim);
  auto argStatic = mlir::acc::GangArgTypeAttr::get(
      parser.getContext(), mlir::acc::GangArgType::Static);

  do {
    // After a gang-only list, the first pass only consumes the separating
    // comma: `continue` jumps straight to the loop condition below.
    if (needCommaBeforeOperands) {
      needCommaBeforeOperands = false;
      continue;
    }

    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = gangOperands.size();
    // The comma requirement applies within a single `{...}` group, so it
    // starts out cleared for every group. Carrying it over from the previous
    // group would make the first value of this group demand a leading comma
    // and the group would parse no values at all.
    bool needCommaBetweenValues = false;
    while (true) {
      bool newValue = false;
      bool needValue = false;
      if (needCommaBetweenValues) {
        if (succeeded(parser.parseOptionalComma()))
          needValue = true; // expect a new value after comma.
        else
          break;
      }

      if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argNum,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argDim,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argStatic,
                                needCommaBetweenValues, newValue)))
        return failure();

      if (!newValue && needValue) {
        parser.emitError(parser.getCurrentLocation(),
                         "new value expected after comma");
        return failure();
      }

      if (!newValue)
        break;
    }

    // NOTE(review): this looks at the total operand count, so an empty
    // `{}` group is only rejected when no earlier group produced a value.
    // Confirm whether a per-group check is intended.
    if (gangOperands.empty())
      return parser.emitError(
          parser.getCurrentLocation(),
          "expect at least one of num, dim or static values");

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }

    // Number of values contributed by this group.
    seg.push_back(gangOperands.size() - crtOperandsSize);

  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
                                               gangArgTypeAttributes.end());
  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
  gangOnlyDeviceType =
      ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);

  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  return success();
}
|
|
|
|
/// Prints the parenthesized part of a gang clause.
///
/// Nothing is printed when there are no operands and
/// hasOnlyDeviceTypeNone(gangOnlyDeviceTypes) holds. Otherwise the output is
/// `(`, the gang-only device types, a `, ` separator when both device-type
/// lists have values, one `{keyword=operand : type, ...}` group per entry of
/// `deviceTypes` (each followed by its device type), and `)`. The number of
/// values in each group is read from `segments`; `op` and `types` are not
/// read.
void printGangClause(OpAsmPrinter &p, Operation *op,
                     mlir::OperandRange operands, mlir::TypeRange types,
                     std::optional<mlir::ArrayAttr> gangArgTypes,
                     std::optional<mlir::ArrayAttr> deviceTypes,
                     std::optional<mlir::DenseI32ArrayAttr> segments,
                     std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {

  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
    return;
  }

  p << "(";

  printDeviceTypes(p, gangOnlyDeviceTypes);

  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";

  if (hasDeviceTypeValues(deviceTypes)) {
    // Running index into `operands` / `gangArgTypes` across all groups.
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto group) {
      p << "{";
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[group.index()]), p, [&](auto) {
            // The result is used unconditionally, so use cast<> (which
            // asserts on a wrong attribute kind) rather than an unchecked
            // dyn_cast<>.
            auto gangArgTypeAttr = mlir::cast<mlir::acc::GangArgTypeAttr>(
                (*gangArgTypes)[opIdx]);
            if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
              p << LoopOp::getGangNumKeyword();
            else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
              p << LoopOp::getGangDimKeyword();
            else if (gangArgTypeAttr.getValue() ==
                     mlir::acc::GangArgType::Static)
              p << LoopOp::getGangStaticKeyword();
            p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, group.value());
    });
  }
  p << ")";
}
|
|
|
|
bool hasDuplicateDeviceTypes(
|
|
std::optional<mlir::ArrayAttr> segments,
|
|
llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
|
|
if (!segments)
|
|
return false;
|
|
for (auto attr : *segments) {
|
|
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
|
|
if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Check for duplicates in the DeviceType array attribute.
|
|
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
|
|
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
|
|
if (!deviceTypes)
|
|
return success();
|
|
for (auto attr : deviceTypes) {
|
|
auto deviceTypeAttr =
|
|
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
|
|
if (!deviceTypeAttr)
|
|
return failure();
|
|
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult acc::LoopOp::verify() {
|
|
if (getUpperbound().size() != getStep().size())
|
|
return emitError() << "number of upperbounds expected to be the same as "
|
|
"number of steps";
|
|
|
|
if (getUpperbound().size() != getLowerbound().size())
|
|
return emitError() << "number of upperbounds expected to be the same as "
|
|
"number of lowerbounds";
|
|
|
|
if (!getUpperbound().empty() && getInclusiveUpperbound() &&
|
|
(getUpperbound().size() != getInclusiveUpperbound()->size()))
|
|
return emitError() << "inclusiveUpperbound size is expected to be the same"
|
|
<< " as upperbound size";
|
|
|
|
// Check collapse
|
|
if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
|
|
return emitOpError() << "collapse device_type attr must be define when"
|
|
<< " collapse attr is present";
|
|
|
|
if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
|
|
getCollapseAttr().getValue().size() !=
|
|
getCollapseDeviceTypeAttr().getValue().size())
|
|
return emitOpError() << "collapse attribute count must match collapse"
|
|
<< " device_type count";
|
|
if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
|
|
return emitOpError()
|
|
<< "duplicate device_type found in collapseDeviceType attribute";
|
|
|
|
// Check gang
|
|
if (!getGangOperands().empty()) {
|
|
if (!getGangOperandsArgType())
|
|
return emitOpError() << "gangOperandsArgType attribute must be defined"
|
|
<< " when gang operands are present";
|
|
|
|
if (getGangOperands().size() !=
|
|
getGangOperandsArgTypeAttr().getValue().size())
|
|
return emitOpError() << "gangOperandsArgType attribute count must match"
|
|
<< " gangOperands count";
|
|
}
|
|
if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
|
|
return emitOpError() << "duplicate device_type found in gang attribute";
|
|
|
|
if (failed(verifyDeviceTypeAndSegmentCountMatch(
|
|
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
|
|
getGangOperandsDeviceTypeAttr(), "gang")))
|
|
return failure();
|
|
|
|
// Check worker
|
|
if (failed(checkDeviceTypes(getWorkerAttr())))
|
|
return emitOpError() << "duplicate device_type found in worker attribute";
|
|
if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
|
|
return emitOpError() << "duplicate device_type found in "
|
|
"workerNumOperandsDeviceType attribute";
|
|
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
|
|
getWorkerNumOperandsDeviceTypeAttr(),
|
|
"worker")))
|
|
return failure();
|
|
|
|
// Check vector
|
|
if (failed(checkDeviceTypes(getVectorAttr())))
|
|
return emitOpError() << "duplicate device_type found in vector attribute";
|
|
if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
|
|
return emitOpError() << "duplicate device_type found in "
|
|
"vectorOperandsDeviceType attribute";
|
|
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
|
|
getVectorOperandsDeviceTypeAttr(),
|
|
"vector")))
|
|
return failure();
|
|
|
|
if (failed(verifyDeviceTypeAndSegmentCountMatch(
|
|
*this, getTileOperands(), getTileOperandsSegmentsAttr(),
|
|
getTileOperandsDeviceTypeAttr(), "tile")))
|
|
return failure();
|
|
|
|
// auto, independent and seq attribute are mutually exclusive.
|
|
llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
|
|
if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
|
|
hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
|
|
hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
|
|
return emitError() << "only one of auto, independent, seq can be present "
|
|
"at the same time";
|
|
}
|
|
|
|
// Check that at least one of auto, independent, or seq is present
|
|
// for the device-independent default clauses.
|
|
auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
|
|
return attr.getValue() == mlir::acc::DeviceType::None;
|
|
};
|
|
bool hasDefaultSeq =
|
|
getSeqAttr()
|
|
? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
|
|
hasDeviceNone)
|
|
: false;
|
|
bool hasDefaultIndependent =
|
|
getIndependentAttr()
|
|
? llvm::any_of(
|
|
getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
|
|
hasDeviceNone)
|
|
: false;
|
|
bool hasDefaultAuto =
|
|
getAuto_Attr()
|
|
? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
|
|
hasDeviceNone)
|
|
: false;
|
|
if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
|
|
return emitError()
|
|
<< "at least one of auto, independent, seq must be present";
|
|
}
|
|
|
|
// Gang, worker and vector are incompatible with seq.
|
|
if (getSeqAttr()) {
|
|
for (auto attr : getSeqAttr()) {
|
|
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
|
|
if (hasVector(deviceTypeAttr.getValue()) ||
|
|
getVectorValue(deviceTypeAttr.getValue()) ||
|
|
hasWorker(deviceTypeAttr.getValue()) ||
|
|
getWorkerValue(deviceTypeAttr.getValue()) ||
|
|
hasGang(deviceTypeAttr.getValue()) ||
|
|
getGangValue(mlir::acc::GangArgType::Num,
|
|
deviceTypeAttr.getValue()) ||
|
|
getGangValue(mlir::acc::GangArgType::Dim,
|
|
deviceTypeAttr.getValue()) ||
|
|
getGangValue(mlir::acc::GangArgType::Static,
|
|
deviceTypeAttr.getValue()))
|
|
return emitError() << "gang, worker or vector cannot appear with seq";
|
|
}
|
|
}
|
|
|
|
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
|
|
*this, getPrivatizationRecipes(), getPrivateOperands(), "private",
|
|
"privatizations", false)))
|
|
return failure();
|
|
|
|
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
|
|
*this, getReductionRecipes(), getReductionOperands(), "reduction",
|
|
"reductions", false)))
|
|
return failure();
|
|
|
|
if (getCombined().has_value() &&
|
|
(getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
|
|
getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
|
|
getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
|
|
return emitError("unexpected combined constructs attribute");
|
|
}
|
|
|
|
// Check non-empty body().
|
|
if (getRegion().empty())
|
|
return emitError("expected non-empty body.");
|
|
|
|
// When it is container-like - it is expected to hold a loop-like operation.
|
|
if (isContainerLike()) {
|
|
// Obtain the maximum collapse count - we use this to check that there
|
|
// are enough loops contained.
|
|
uint64_t collapseCount = getCollapseValue().value_or(1);
|
|
if (getCollapseAttr()) {
|
|
for (auto collapseEntry : getCollapseAttr()) {
|
|
auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
|
|
if (intAttr.getValue().getZExtValue() > collapseCount)
|
|
collapseCount = intAttr.getValue().getZExtValue();
|
|
}
|
|
}
|
|
|
|
// We want to check that we find enough loop-like operations inside.
|
|
// PreOrder walk allows us to walk in a breadth-first manner at each nesting
|
|
// level.
|
|
mlir::Operation *expectedParent = this->getOperation();
|
|
bool foundSibling = false;
|
|
getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
|
|
if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
|
|
// This effectively checks that we are not looking at a sibling loop.
|
|
if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
|
|
expectedParent) {
|
|
foundSibling = true;
|
|
return mlir::WalkResult::interrupt();
|
|
}
|
|
|
|
collapseCount--;
|
|
expectedParent = op;
|
|
}
|
|
// We found enough contained loops.
|
|
if (collapseCount == 0)
|
|
return mlir::WalkResult::interrupt();
|
|
return mlir::WalkResult::advance();
|
|
});
|
|
|
|
if (foundSibling)
|
|
return emitError("found sibling loops inside container-like acc.loop");
|
|
if (collapseCount != 0)
|
|
return emitError("failed to find enough loop-like operations inside "
|
|
"container-like acc.loop");
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
unsigned LoopOp::getNumDataOperands() {
|
|
return getReductionOperands().size() + getPrivateOperands().size();
|
|
}
|
|
|
|
/// Returns the operand located `i` positions after the loop-bound, gang,
/// vector, worker, tile and cache operand groups.
Value LoopOp::getDataOperand(unsigned i) {
  // Accumulate the size of every operand group that is skipped over.
  unsigned leading = getLowerbound().size();
  leading += getUpperbound().size();
  leading += getStep().size();
  leading += getGangOperands().size();
  leading += getVectorOperands().size();
  leading += getWorkerNumOperands().size();
  leading += getTileOperands().size();
  leading += getCacheOperands().size();
  return getOperand(leading + i);
}
|
|
|
|
bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
|
|
|
|
bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getAuto_(), deviceType);
|
|
}
|
|
|
|
bool LoopOp::hasIndependent() {
|
|
return hasIndependent(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getIndependent(), deviceType);
|
|
}
|
|
|
|
bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
|
|
|
|
bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getSeq(), deviceType);
|
|
}
|
|
|
|
/// Returns the vector operand associated with the `None` device type.
mlir::Value LoopOp::getVectorValue() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getVectorValue(noDevice);
}

/// Looks up the vector operand recorded for `deviceType` using
/// getValueInDeviceTypeSegment.
mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
  auto vectorDeviceTypes = getVectorOperandsDeviceType();
  auto vectorOperands = getVectorOperands();
  return getValueInDeviceTypeSegment(vectorDeviceTypes, vectorOperands,
                                     deviceType);
}
|
|
|
|
bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
|
|
|
|
bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getVector(), deviceType);
|
|
}
|
|
|
|
/// Returns the worker operand associated with the `None` device type.
mlir::Value LoopOp::getWorkerValue() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getWorkerValue(noDevice);
}

/// Looks up the worker-num operand recorded for `deviceType` using
/// getValueInDeviceTypeSegment.
mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
  auto workerDeviceTypes = getWorkerNumOperandsDeviceType();
  auto workerOperands = getWorkerNumOperands();
  return getValueInDeviceTypeSegment(workerDeviceTypes, workerOperands,
                                     deviceType);
}
|
|
|
|
bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
|
|
|
|
bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getWorker(), deviceType);
|
|
}
|
|
|
|
/// Returns the tile operands associated with the `None` device type.
mlir::Operation::operand_range LoopOp::getTileValues() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getTileValues(noDevice);
}

/// Returns the tile operands recorded for `deviceType`, resolved through
/// getValuesFromSegments.
mlir::Operation::operand_range
LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  auto tileDeviceTypes = getTileOperandsDeviceType();
  auto tileOperands = getTileOperands();
  auto tileSegments = getTileOperandsSegments();
  return getValuesFromSegments(tileDeviceTypes, tileOperands, tileSegments,
                               deviceType);
}
|
|
|
|
/// Returns the collapse value associated with the `None` device type, if any.
std::optional<int64_t> LoopOp::getCollapseValue() {
  return getCollapseValue(mlir::acc::DeviceType::None);
}

/// Returns the collapse value recorded for `deviceType`.
///
/// Yields std::nullopt when the op carries no collapse attribute or when
/// findSegment reports no entry for `deviceType`.
std::optional<int64_t>
LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
  if (!getCollapseAttr())
    return std::nullopt;
  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
    // The entry is dereferenced unconditionally, so use cast<> (which asserts
    // on a type mismatch) rather than dyn_cast<> (which would silently yield a
    // null attribute that is then dereferenced).
    auto intAttr =
        mlir::cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
    return intAttr.getValue().getZExtValue();
  }
  return std::nullopt;
}
|
|
|
|
/// Returns the gang operand of kind `gangArgType` for the `None` device type.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
}

/// Returns the gang operand of kind `gangArgType` recorded for `deviceType`.
///
/// A null Value is returned when there are no gang operands, when findSegment
/// reports no segment for `deviceType`, or when that segment holds no operand
/// of the requested kind.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
                                 mlir::acc::DeviceType deviceType) {
  if (getGangOperands().empty())
    return {};
  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
    // Count the operands that belong to the segments preceding `*pos`.
    int32_t nbOperandsBefore = 0;
    for (unsigned i = 0; i < *pos; ++i)
      nbOperandsBefore += (*getGangOperandsSegments())[i];
    mlir::Operation::operand_range values =
        getGangOperands()
            .drop_front(nbOperandsBefore)
            .take_front((*getGangOperandsSegments())[*pos]);

    // The arg-type array is indexed in parallel with the flat operand list.
    int32_t argTypeIdx = nbOperandsBefore;
    for (auto value : values) {
      // The attribute is dereferenced unconditionally, so cast<> (asserting)
      // is used instead of dyn_cast<> (which may return null).
      auto gangArgTypeAttr = mlir::cast<mlir::acc::GangArgTypeAttr>(
          (*getGangOperandsArgType())[argTypeIdx]);
      if (gangArgTypeAttr.getValue() == gangArgType)
        return value;
      ++argTypeIdx;
    }
  }
  return {};
}
|
|
|
|
bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
|
|
|
|
bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getGang(), deviceType);
|
|
}
|
|
|
|
/// Returns a one-element list holding the op's region.
llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
  llvm::SmallVector<mlir::Region *> loopRegions;
  loopRegions.push_back(&getRegion());
  return loopRegions;
}
|
|
|
|
/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
|
|
/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
|
|
/// `(` ssa-id-and-type-list `)`
|
|
/// region
|
|
ParseResult
|
|
parseLoopControl(OpAsmParser &parser, Region ®ion,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
|
|
SmallVectorImpl<Type> &lowerboundType,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
|
|
SmallVectorImpl<Type> &upperboundType,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
|
|
SmallVectorImpl<Type> &stepType) {
|
|
|
|
SmallVector<OpAsmParser::Argument> inductionVars;
|
|
if (succeeded(
|
|
parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
|
|
if (parser.parseLParen() ||
|
|
parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
|
|
/*allowType=*/true) ||
|
|
parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
|
|
parser.parseOperandList(lowerbound, inductionVars.size(),
|
|
OpAsmParser::Delimiter::None) ||
|
|
parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
|
|
parser.parseKeyword("to") || parser.parseLParen() ||
|
|
parser.parseOperandList(upperbound, inductionVars.size(),
|
|
OpAsmParser::Delimiter::None) ||
|
|
parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
|
|
parser.parseKeyword("step") || parser.parseLParen() ||
|
|
parser.parseOperandList(step, inductionVars.size(),
|
|
OpAsmParser::Delimiter::None) ||
|
|
parser.parseColonTypeList(stepType) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
return parser.parseRegion(region, inductionVars);
|
|
}
|
|
|
|
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
|
|
ValueRange lowerbound, TypeRange lowerboundType,
|
|
ValueRange upperbound, TypeRange upperboundType,
|
|
ValueRange steps, TypeRange stepType) {
|
|
ValueRange regionArgs = region.front().getArguments();
|
|
if (!regionArgs.empty()) {
|
|
p << acc::LoopOp::getControlKeyword() << "(";
|
|
llvm::interleaveComma(regionArgs, p,
|
|
[&p](Value v) { p << v << " : " << v.getType(); });
|
|
p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
|
|
<< upperbound << " : " << upperboundType << ") " << " step (" << steps
|
|
<< " : " << stepType << ") ";
|
|
}
|
|
p.printRegion(region, /*printEntryBlockArgs=*/false);
|
|
}
|
|
|
|
void acc::LoopOp::addSeq(MLIRContext *context,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
|
|
effectiveDeviceTypes));
|
|
}
|
|
|
|
void acc::LoopOp::addIndependent(
|
|
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setIndependentAttr(addDeviceTypeAffectedOperandHelper(
|
|
context, getIndependentAttr(), effectiveDeviceTypes));
|
|
}
|
|
|
|
void acc::LoopOp::addAuto(MLIRContext *context,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
|
|
effectiveDeviceTypes));
|
|
}
|
|
|
|
/// Appends a collapse `value` for each of `effectiveDeviceTypes` (or for the
/// `None` device type when that list is empty), keeping all existing entries.
///
/// The collapse and collapse-device-type attributes are expected to be either
/// both present or both absent, and `value` must be 64 bits wide.
void acc::LoopOp::setCollapseForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::APInt value) {
  assert((getCollapseAttr() == nullptr) ==
         (getCollapseDeviceTypeAttr() == nullptr));
  assert(value.getBitWidth() == 64);

  llvm::SmallVector<mlir::Attribute> collapseValues;
  llvm::SmallVector<mlir::Attribute> collapseDeviceTypes;

  // Carry over whatever is already recorded on the op.
  if (getCollapseAttr()) {
    for (auto [oldValue, oldDeviceType] :
         llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
      collapseValues.push_back(oldValue);
      collapseDeviceTypes.push_back(oldDeviceType);
    }
  }

  // Adds one (value, device type) pair to the pending lists.
  auto appendEntry = [&](DeviceType deviceType) {
    collapseValues.push_back(
        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
    collapseDeviceTypes.push_back(
        acc::DeviceTypeAttr::get(context, deviceType));
  };

  if (!effectiveDeviceTypes.empty()) {
    for (DeviceType deviceType : effectiveDeviceTypes)
      appendEntry(deviceType);
  } else {
    // An empty effective device-type list means no device_type has been
    // applied yet, so the value is recorded under 'none'.
    appendEntry(DeviceType::None);
  }

  setCollapseAttr(ArrayAttr::get(context, collapseValues));
  setCollapseDeviceTypeAttr(ArrayAttr::get(context, collapseDeviceTypes));
}
|
|
|
|
/// Adds tile operand `values` for `effectiveDeviceTypes`, updating the tile
/// device-type attribute and the tile segment sizes.
void acc::LoopOp::setTileForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    ValueRange values) {
  // Start from the segment sizes already stored on the op, if any.
  llvm::SmallVector<int32_t> tileSegments;
  if (auto existingSegments = getTileOperandsSegments())
    tileSegments.append(existingSegments->begin(), existingSegments->end());

  // The helper receives the mutable operand range and the segment list.
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getTileOperandsMutable(), tileSegments);
  setTileOperandsDeviceTypeAttr(updatedDeviceTypes);

  setTileOperandsSegments(tileSegments);
}
|
|
|
|
void acc::LoopOp::addVectorOperand(
|
|
MLIRContext *context, mlir::Value newValue,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
|
|
context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
|
|
newValue, getVectorOperandsMutable()));
|
|
}
|
|
|
|
void acc::LoopOp::addEmptyVector(
|
|
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
|
|
effectiveDeviceTypes));
|
|
}
|
|
|
|
void acc::LoopOp::addWorkerNumOperand(
|
|
MLIRContext *context, mlir::Value newValue,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
|
|
context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
|
|
newValue, getWorkerNumOperandsMutable()));
|
|
}
|
|
|
|
void acc::LoopOp::addEmptyWorker(
|
|
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
|
|
effectiveDeviceTypes));
|
|
}
|
|
|
|
/// Records an operand-less `gang` for `effectiveDeviceTypes` by updating the
/// `gang` attribute.
void acc::LoopOp::addEmptyGang(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedGang = addDeviceTypeAffectedOperandHelper(
      context, getGangAttr(), effectiveDeviceTypes);
  setGangAttr(updatedGang);
}
|
|
|
|
/// Returns true when `dt` appears in any of the `seq`, `independent` or
/// `auto` attributes of this loop.
bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
  // True when `list` is present and one of its entries names `dt`.
  auto mentionsDevice = [dt](ArrayAttr list) -> bool {
    if (!list)
      return false;
    for (DeviceTypeAttr entry : list.getAsRange<DeviceTypeAttr>())
      if (entry.getValue() == dt)
        return true;
    return false;
  };

  return mentionsDevice(getSeqAttr()) ||
         mentionsDevice(getIndependentAttr()) ||
         mentionsDevice(getAuto_Attr());
}
|
|
|
|
bool acc::LoopOp::hasDefaultGangWorkerVector() {
|
|
return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
|
|
hasGang() || getGangValue(GangArgType::Num) ||
|
|
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
|
|
}
|
|
|
|
/// Adds gang operand `values` (with kinds `argTypes`) for
/// `effectiveDeviceTypes`, updating the gang device-type attribute, the gang
/// segment sizes and the gang arg-type attribute.
void acc::LoopOp::addGangOperands(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
  // Seed the segment list with the sizes already stored on the op.
  llvm::SmallVector<int32_t> gangSegments;
  if (std::optional<ArrayRef<int32_t>> stored = getGangOperandsSegments())
    gangSegments.append(stored->begin(), stored->end());

  const unsigned segmentsBefore = gangSegments.size();

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getGangOperandsMutable(), gangSegments);
  setGangOperandsDeviceTypeAttr(updatedDeviceTypes);

  setGangOperandsSegments(gangSegments);

  // The arg-type list has to grow once per segment the helper added. A
  // dedicated helper along the lines of addDeviceTypeAffectedOperandHelper
  // would be excessive for this one-off case.
  const unsigned segmentsAdded = gangSegments.size() - segmentsBefore;
  if (segmentsAdded == 0)
    return;

  llvm::SmallVector<mlir::Attribute> gangArgTypes;
  if (ArrayAttr storedTypes = getGangOperandsArgTypeAttr())
    gangArgTypes.append(storedTypes.begin(), storedTypes.end());

  for (unsigned added = 0; added < segmentsAdded; ++added)
    for (mlir::acc::GangArgType gangTy : argTypes)
      gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(context, gangTy));

  setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangArgTypes));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the acc.data operation.
///
/// Fails when the op has neither operands nor a default attribute, when a
/// data clause operand is not produced by one of the accepted data
/// entry/exit operations or acc.getdeviceptr, or when
/// checkWaitAndAsyncConflict reports a conflict.
LogicalResult acc::DataOp::verify() {
  // 2.6.5. Data Construct restriction
  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
  // attach, or default clause must appear on a data construct.
  if (getOperands().empty() && !getDefaultAttr())
    return emitError("at least one operand or the default attribute "
                     "must appear on the data operation");

  // getDefiningOp() is null for block arguments; isa_and_present tolerates
  // that, so such operands get the diagnostic below instead of tripping the
  // null-pointer assertion inside isa<>.
  for (mlir::Value operand : getDataClauseOperands())
    if (!llvm::isa_and_present<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp,
                               acc::CreateOp, acc::DeleteOp, acc::DetachOp,
                               acc::DevicePtrOp, acc::GetDevicePtrOp,
                               acc::NoCreateOp, acc::PresentOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");

  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
    return failure();

  return success();
}
|
|
|
|
/// Returns the number of data clause operands.
unsigned DataOp::getNumDataOperands() {
  const unsigned dataOperandCount = getDataClauseOperands().size();
  return dataOperandCount;
}
|
|
|
|
/// Returns the operand located `i` positions after the if-condition, async
/// and wait operands.
Value DataOp::getDataOperand(unsigned i) {
  unsigned numOptional = getIfCond() ? 1 : 0;
  // The async operand list may hold several values (addAsyncOperand appends
  // one per call), so every one of them must be skipped, not just one.
  numOptional += getAsyncOperands().size();
  numOptional += getWaitOperands().size();
  return getOperand(numOptional + i);
}
|
|
|
|
bool acc::DataOp::hasAsyncOnly() {
|
|
return hasAsyncOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getAsyncOnly(), deviceType);
|
|
}
|
|
|
|
/// Returns the async operand associated with the `None` device type.
mlir::Value DataOp::getAsyncValue() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getAsyncValue(noDevice);
}

/// Looks up the async operand recorded for `deviceType` using
/// getValueInDeviceTypeSegment.
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncOperandsDeviceType();
  auto asyncOperands = getAsyncOperands();
  return getValueInDeviceTypeSegment(asyncDeviceTypes, asyncOperands,
                                     deviceType);
}
|
|
|
|
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
|
|
|
|
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
|
|
return hasDeviceType(getWaitOnly(), deviceType);
|
|
}
|
|
|
|
/// Returns the wait operands associated with the `None` device type.
mlir::Operation::operand_range DataOp::getWaitValues() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getWaitValues(noDevice);
}

/// Returns the wait operands recorded for `deviceType`, as computed by
/// getWaitValuesWithoutDevnum.
mlir::Operation::operand_range
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitOperands = getWaitOperands();
  auto waitSegments = getWaitOperandsSegments();
  auto devnumFlags = getHasWaitDevnum();
  return getWaitValuesWithoutDevnum(waitDeviceTypes, waitOperands,
                                    waitSegments, devnumFlags, deviceType);
}
|
|
|
|
/// Returns the wait devnum value associated with the `None` device type.
mlir::Value DataOp::getWaitDevnum() {
  constexpr auto noDevice = mlir::acc::DeviceType::None;
  return getWaitDevnum(noDevice);
}

/// Returns the wait devnum value recorded for `deviceType`, as computed by
/// getWaitDevnumValue.
mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitOperands = getWaitOperands();
  auto waitSegments = getWaitOperandsSegments();
  auto devnumFlags = getHasWaitDevnum();
  return getWaitDevnumValue(waitDeviceTypes, waitOperands, waitSegments,
                            devnumFlags, deviceType);
}
|
|
|
|
/// Records an operand-less `async` for `effectiveDeviceTypes` by updating the
/// async-only attribute.
void acc::DataOp::addAsyncOnly(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedAsyncOnly = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOnlyAttr(), effectiveDeviceTypes);
  setAsyncOnlyAttr(updatedAsyncOnly);
}
|
|
|
|
/// Adds `newValue` as an async operand for `effectiveDeviceTypes` and stores
/// the updated async-operand device-type attribute.
void acc::DataOp::addAsyncOperand(
    MLIRContext *context, mlir::Value newValue,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
      getAsyncOperandsMutable());
  setAsyncOperandsDeviceTypeAttr(updatedDeviceTypes);
}
|
|
|
|
/// Records an operand-less `wait` for `effectiveDeviceTypes` by updating the
/// wait-only attribute.
void acc::DataOp::addWaitOnly(MLIRContext *context,
                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  auto updatedWaitOnly = addDeviceTypeAffectedOperandHelper(
      context, getWaitOnlyAttr(), effectiveDeviceTypes);
  setWaitOnlyAttr(updatedWaitOnly);
}
|
|
|
|
/// Adds wait operands `newValues` for `effectiveDeviceTypes`, updating the
/// wait device-type attribute, the wait segment sizes and the has-devnum
/// flags. One `hasDevnum` flag is appended per effective device type, or a
/// single flag when that list is empty.
void acc::DataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  // Seed the segment list with the sizes already stored on the op.
  llvm::SmallVector<int32_t> waitSegments;
  if (auto storedSegments = getWaitOperandsSegments())
    waitSegments.append(storedSegments->begin(), storedSegments->end());

  auto updatedDeviceTypes = addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), waitSegments);
  setWaitOperandsDeviceTypeAttr(updatedDeviceTypes);
  setWaitOperandsSegments(waitSegments);

  // Extend the devnum flags, preserving the ones already present.
  llvm::SmallVector<mlir::Attribute> devnumFlags;
  if (ArrayAttr storedFlags = getHasWaitDevnumAttr())
    devnumFlags.append(storedFlags.begin(), storedFlags.end());
  const size_t flagsToAdd =
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1));
  devnumFlags.append(flagsToAdd, mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, devnumFlags));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExitDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the acc.exit_data operation: data clause operands must be
/// present, and the async/wait attributes and operands must be consistent.
LogicalResult acc::ExitDataOp::verify() {
  // 2.6.6. Data Exit Directive restriction
  // At least one copyout, delete, or detach clause must appear on an exit data
  // directive.
  const bool noDataOperands = getDataClauseOperands().empty();
  if (noDataOperands)
    return emitError("at least one operand must be present in dataOperands on "
                     "the exit data operation");

  // The async attribute represents the async clause without a value, so the
  // attribute and the operand cannot appear at the same time.
  if (getAsyncOperand() && getAsync())
    return emitError("async attribute cannot appear with asyncOperand");

  const bool noWaitOperands = getWaitOperands().empty();

  // The wait attribute represents the wait clause without values, so the
  // attribute and the operands cannot appear at the same time.
  if (!noWaitOperands && getWait())
    return emitError("wait attribute cannot appear with waitOperands");

  if (getWaitDevnum() && noWaitOperands)
    return emitError("wait_devnum cannot appear without waitOperands");

  return success();
}
|
|
|
|
unsigned ExitDataOp::getNumDataOperands() {
|
|
return getDataClauseOperands().size();
|
|
}
|
|
|
|
Value ExitDataOp::getDataOperand(unsigned i) {
|
|
unsigned numOptional = getIfCond() ? 1 : 0;
|
|
numOptional += getAsyncOperand() ? 1 : 0;
|
|
numOptional += getWaitDevnum() ? 1 : 0;
|
|
return getOperand(getWaitOperands().size() + numOptional + i);
|
|
}
|
|
|
|
/// Registers the RemoveConstantIfCondition pattern for acc.exit_data.
void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  using IfConditionPattern = RemoveConstantIfCondition<ExitDataOp>;
  results.add<IfConditionPattern>(context);
}
|
|
|
|
void ExitDataOp::addAsyncOnly(MLIRContext *context,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
assert(effectiveDeviceTypes.empty());
|
|
assert(!getAsyncAttr());
|
|
assert(!getAsyncOperand());
|
|
|
|
setAsyncAttr(mlir::UnitAttr::get(context));
|
|
}
|
|
|
|
void ExitDataOp::addAsyncOperand(
|
|
MLIRContext *context, mlir::Value newValue,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
assert(effectiveDeviceTypes.empty());
|
|
assert(!getAsyncAttr());
|
|
assert(!getAsyncOperand());
|
|
|
|
getAsyncOperandMutable().append(newValue);
|
|
}
|
|
|
|
/// Sets the unit `wait` attribute. Expects an empty device-type list and no
/// wait attribute, wait operands or wait devnum to be set already.
void ExitDataOp::addWaitOnly(MLIRContext *context,
                             llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Preconditions, checked in assert-enabled builds.
  assert(effectiveDeviceTypes.empty());
  assert(!getWaitAttr());
  assert(getWaitOperands().empty());
  assert(!getWaitDevnum());

  auto waitFlag = mlir::UnitAttr::get(context);
  setWaitAttr(waitFlag);
}
|
|
|
|
/// Adds wait operands to the op. Expects an empty device-type list and no
/// wait attribute, wait operands or wait devnum to be set already.
///
/// When `hasDevnum` is true, the first element of `newValues` becomes the
/// wait devnum and the remaining elements become the wait operands;
/// otherwise all of `newValues` become wait operands.
void ExitDataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  assert(effectiveDeviceTypes.empty());
  assert(!getWaitAttr());
  assert(getWaitOperands().empty());
  assert(!getWaitDevnum());
  // front() below requires at least one value when a devnum is expected.
  assert((!hasDevnum || !newValues.empty()) &&
         "hasDevnum requires at least one value (the devnum)");

  // if hasDevnum, the first value is the devnum. The 'rest' go into the
  // operands list.
  if (hasDevnum) {
    getWaitDevnumMutable().append(newValues.front());
    newValues = newValues.drop_front();
  }

  getWaitOperandsMutable().append(newValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// EnterDataOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the acc.enter_data operation: data clause operands must be
/// present and defined by attach/create/copyin operations, and the async/wait
/// attributes and operands must be consistent.
LogicalResult acc::EnterDataOp::verify() {
  // 2.6.6. Data Enter Directive restriction
  // At least one copyin, create, or attach clause must appear on an enter data
  // directive.
  if (getDataClauseOperands().empty())
    return emitError("at least one operand must be present in dataOperands on "
                     "the enter data operation");

  // The async attribute represents the async clause without a value, so the
  // attribute and the operand cannot appear at the same time.
  if (getAsyncOperand() && getAsync())
    return emitError("async attribute cannot appear with asyncOperand");

  // The wait attribute represents the wait clause without values, so the
  // attribute and the operands cannot appear at the same time.
  if (!getWaitOperands().empty() && getWait())
    return emitError("wait attribute cannot appear with waitOperands");

  if (getWaitDevnum() && getWaitOperands().empty())
    return emitError("wait_devnum cannot appear without waitOperands");

  // getDefiningOp() is null for block arguments; isa_and_present tolerates
  // that, so such operands get the diagnostic below instead of tripping the
  // null-pointer assertion inside isa<>.
  for (mlir::Value operand : getDataClauseOperands())
    if (!llvm::isa_and_present<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry operation as defining op");

  return success();
}
|
|
|
|
unsigned EnterDataOp::getNumDataOperands() {
|
|
return getDataClauseOperands().size();
|
|
}
|
|
|
|
Value EnterDataOp::getDataOperand(unsigned i) {
|
|
unsigned numOptional = getIfCond() ? 1 : 0;
|
|
numOptional += getAsyncOperand() ? 1 : 0;
|
|
numOptional += getWaitDevnum() ? 1 : 0;
|
|
return getOperand(getWaitOperands().size() + numOptional + i);
|
|
}
|
|
|
|
/// Registers the RemoveConstantIfCondition pattern for acc.enter_data.
void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  using IfConditionPattern = RemoveConstantIfCondition<EnterDataOp>;
  results.add<IfConditionPattern>(context);
}
|
|
|
|
void EnterDataOp::addAsyncOnly(
|
|
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
assert(effectiveDeviceTypes.empty());
|
|
assert(!getAsyncAttr());
|
|
assert(!getAsyncOperand());
|
|
|
|
setAsyncAttr(mlir::UnitAttr::get(context));
|
|
}
|
|
|
|
void EnterDataOp::addAsyncOperand(
|
|
MLIRContext *context, mlir::Value newValue,
|
|
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
|
|
assert(effectiveDeviceTypes.empty());
|
|
assert(!getAsyncAttr());
|
|
assert(!getAsyncOperand());
|
|
|
|
getAsyncOperandMutable().append(newValue);
|
|
}
|
|
|
|
/// Sets the unit `wait` attribute. Expects an empty device-type list and no
/// wait attribute, wait operands or wait devnum to be set already.
void EnterDataOp::addWaitOnly(MLIRContext *context,
                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  // Preconditions, checked in assert-enabled builds.
  assert(effectiveDeviceTypes.empty());
  assert(!getWaitAttr());
  assert(getWaitOperands().empty());
  assert(!getWaitDevnum());

  auto waitFlag = mlir::UnitAttr::get(context);
  setWaitAttr(waitFlag);
}
|
|
|
|
/// Adds wait operands to the op. Expects an empty device-type list and no
/// wait attribute, wait operands or wait devnum to be set already.
///
/// When `hasDevnum` is true, the first element of `newValues` becomes the
/// wait devnum and the remaining elements become the wait operands;
/// otherwise all of `newValues` become wait operands.
void EnterDataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  assert(effectiveDeviceTypes.empty());
  assert(!getWaitAttr());
  assert(getWaitOperands().empty());
  assert(!getWaitDevnum());
  // front() below requires at least one value when a devnum is expected.
  assert((!hasDevnum || !newValues.empty()) &&
         "hasDevnum requires at least one value (the devnum)");

  // if hasDevnum, the first value is the devnum. The 'rest' go into the
  // operands list.
  if (hasDevnum) {
    getWaitDevnumMutable().append(newValues.front());
    newValues = newValues.drop_front();
  }

  getWaitOperandsMutable().append(newValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AtomicReadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the op by delegating to verifyCommon().
LogicalResult AtomicReadOp::verify() {
  LogicalResult status = verifyCommon();
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AtomicWriteOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the op by delegating to verifyCommon().
LogicalResult AtomicWriteOp::verify() {
  LogicalResult status = verifyCommon();
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AtomicUpdateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Canonicalizes an atomic update: erases it when isNoOp() holds, or replaces
/// it with an AtomicWriteOp when getWriteOpVal() yields a value. Returns
/// failure when neither rewrite applies.
LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
                                           PatternRewriter &rewriter) {
  if (op.isNoOp()) {
    rewriter.eraseOp(op);
    return success();
  }

  Value writtenValue = op.getWriteOpVal();
  if (!writtenValue)
    return failure();

  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writtenValue);
  return success();
}
|
|
|
|
/// Verifies the op by delegating to verifyCommon().
LogicalResult AtomicUpdateOp::verify() {
  LogicalResult status = verifyCommon();
  return status;
}

/// Verifies the op's regions by delegating to verifyRegionsCommon().
LogicalResult AtomicUpdateOp::verifyRegions() {
  LogicalResult status = verifyRegionsCommon();
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AtomicCaptureOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
|
|
if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
|
|
return op;
|
|
return dyn_cast<AtomicReadOp>(getSecondOp());
|
|
}
|
|
|
|
AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
|
|
if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
|
|
return op;
|
|
return dyn_cast<AtomicWriteOp>(getSecondOp());
|
|
}
|
|
|
|
AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
|
|
if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
|
|
return op;
|
|
return dyn_cast<AtomicUpdateOp>(getSecondOp());
|
|
}
|
|
|
|
/// Verifies the op's regions by delegating to verifyRegionsCommon().
LogicalResult AtomicCaptureOp::verifyRegions() {
  LogicalResult status = verifyRegionsCommon();
  return status;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeclareEnterOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename Op>
|
|
static LogicalResult
|
|
checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
|
|
bool requireAtLeastOneOperand = true) {
|
|
if (operands.empty() && requireAtLeastOneOperand)
|
|
return emitError(
|
|
op->getLoc(),
|
|
"at least one operand must appear on the declare operation");
|
|
|
|
for (mlir::Value operand : operands) {
|
|
if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
|
|
acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
|
|
acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
|
|
operand.getDefiningOp()))
|
|
return op.emitError(
|
|
"expect valid declare data entry operation or acc.getdeviceptr "
|
|
"as defining op");
|
|
|
|
mlir::Value var{getVar(operand.getDefiningOp())};
|
|
assert(var && "declare operands can only be data entry operations which "
|
|
"must have var");
|
|
(void)var;
|
|
std::optional<mlir::acc::DataClause> dataClauseOptional{
|
|
getDataClause(operand.getDefiningOp())};
|
|
assert(dataClauseOptional.has_value() &&
|
|
"declare operands can only be data entry operations which must have "
|
|
"dataClause");
|
|
(void)dataClauseOptional;
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Verifies the data clause operands through checkDeclareOperands, with its
/// default requirement of at least one operand.
LogicalResult acc::DeclareEnterOp::verify() {
  auto dataOperands = this->getDataClauseOperands();
  return checkDeclareOperands(*this, dataOperands);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeclareExitOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the data clause operands through checkDeclareOperands. At least
/// one operand is required unless the op has a token.
LogicalResult acc::DeclareExitOp::verify() {
  const bool operandRequired = !getToken();
  return checkDeclareOperands(*this, this->getDataClauseOperands(),
                              /*requireAtLeastOneOperand=*/operandRequired);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DeclareOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the op by running the shared declare-operand check over its data
/// clause operands; at least one operand is required.
LogicalResult acc::DeclareOp::verify() {
  auto declaredOperands = this->getDataClauseOperands();
  return checkDeclareOperands(*this, declaredOperands);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RoutineOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Counts how many of the parallelism specifiers (gang, worker, vector, seq)
/// are present on \p op for \p dtype. A gang dimension value counts as a gang
/// specifier, and gang plus gang dimension together count once.
static unsigned getParallelismForDeviceType(acc::RoutineOp op,
                                            acc::DeviceType dtype) {
  const bool gangLike =
      op.hasGang(dtype) || op.getGangDimValue(dtype).has_value();

  unsigned count = 0;
  if (gangLike)
    ++count;
  if (op.hasWorker(dtype))
    ++count;
  if (op.hasVector(dtype))
    ++count;
  if (op.hasSeq(dtype))
    ++count;
  return count;
}
|
|
|
|
/// Verifies that at most one of `gang`, `worker`, `vector` and `seq` is in
/// effect: at most one for DeviceType::None, at most one for every other
/// device type, and none for a specific device type when DeviceType::None
/// already has one.
LogicalResult acc::RoutineOp::verify() {
  auto emitConflict = [&]() -> LogicalResult {
    return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                          "be present at the same time";
  };

  unsigned baseParallelism =
      getParallelismForDeviceType(*this, acc::DeviceType::None);

  if (baseParallelism > 1)
    return emitConflict();

  // The bound is inclusive: getMaxEnumValForDeviceType() is the largest
  // enumerator value, so stopping at `!=` would leave the last device type
  // unchecked.
  for (uint32_t dtypeInt = 0; dtypeInt <= acc::getMaxEnumValForDeviceType();
       ++dtypeInt) {
    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
    if (dtype == acc::DeviceType::None)
      continue;
    unsigned parallelism = getParallelismForDeviceType(*this, dtype);

    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
      return emitConflict();
  }

  return success();
}
|
|
|
|
/// Parses a comma-separated list of bind-name entries. Each entry is an
/// attribute, optionally followed by `[` device-type-attribute `]`. Entries
/// without the bracketed part are paired with DeviceType::None.
///
/// On success \p bindName and \p deviceTypes receive parallel arrays with one
/// element per entry.
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
                                 mlir::ArrayAttr &deviceTypes) {
  mlir::MLIRContext *ctx = parser.getContext();
  llvm::SmallVector<mlir::Attribute> names;
  llvm::SmallVector<mlir::Attribute> types;

  auto parseEntry = [&]() -> ParseResult {
    if (parser.parseAttribute(names.emplace_back()))
      return failure();

    // Explicit device type for this entry.
    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(types.emplace_back()) || parser.parseRSquare())
        return failure();
      return success();
    }

    // No bracket: the entry applies to DeviceType::None.
    types.push_back(
        mlir::acc::DeviceTypeAttr::get(ctx, mlir::acc::DeviceType::None));
    return success();
  };

  if (failed(parser.parseCommaSeparatedList(parseEntry)))
    return failure();

  bindName = ArrayAttr::get(ctx, names);
  deviceTypes = ArrayAttr::get(ctx, types);
  return success();
}
|
|
|
|
/// Prints the bind-name entries as a comma-separated list, pairing each
/// element of \p bindName with the matching element of \p deviceTypes. Both
/// optionals are dereferenced unconditionally, so they must hold a value.
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
                          std::optional<mlir::ArrayAttr> bindName,
                          std::optional<mlir::ArrayAttr> deviceTypes) {
  bool isFirst = true;
  for (auto [name, deviceType] : llvm::zip(*bindName, *deviceTypes)) {
    if (!isFirst)
      p << ", ";
    isFirst = false;
    p << name;
    printSingleDeviceType(p, deviceType);
  }
}
|
|
|
|
/// Parses what follows the `gang` keyword of a routine.
///
/// Accepted shapes, as implemented below:
///   - nothing: \p gang becomes a one-element array holding DeviceType::None
///     and the other two outputs are left untouched;
///   - `(` [ `[` attr-list `]` `,` ] dim-list `)`, where each dim-list entry is
///     the gang-dim keyword, `:`, an attribute and an optional
///     `[` device-type-attribute `]`.
///
/// NOTE(review): once a `[...]` list has been parsed, a comma followed by the
/// dim list is mandatory — confirm this matches everything the printer can
/// emit when only device types are present.
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
                                          mlir::ArrayAttr &gang,
                                          mlir::ArrayAttr &gangDim,
                                          mlir::ArrayAttr &gangDimDeviceTypes) {
  mlir::MLIRContext *ctx = parser.getContext();
  llvm::SmallVector<mlir::Attribute> gangList;
  llvm::SmallVector<mlir::Attribute> dimList;
  llvm::SmallVector<mlir::Attribute> dimDeviceTypeList;

  // Bare `gang` keyword: nothing else to parse.
  if (failed(parser.parseOptionalLParen())) {
    gangList.push_back(
        mlir::acc::DeviceTypeAttr::get(ctx, mlir::acc::DeviceType::None));
    gang = ArrayAttr::get(ctx, gangList);
    return success();
  }

  // Optional bracketed list of attributes stored in `gang`.
  if (succeeded(parser.parseOptionalLSquare())) {
    auto parseGangEntry = [&]() -> ParseResult {
      return parser.parseAttribute(gangList.emplace_back());
    };
    if (failed(parser.parseCommaSeparatedList(parseGangEntry)))
      return failure();
    if (parser.parseRSquare())
      return failure();
    // The bracketed list must be separated from the dim list by a comma.
    if (failed(parser.parseComma()))
      return failure();
  }

  // Dim entries, each with an optional device type.
  auto parseDimEntry = [&]() -> ParseResult {
    if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
        parser.parseColon() || parser.parseAttribute(dimList.emplace_back()))
      return failure();

    if (failed(parser.parseOptionalLSquare())) {
      dimDeviceTypeList.push_back(
          mlir::acc::DeviceTypeAttr::get(ctx, mlir::acc::DeviceType::None));
      return success();
    }

    if (parser.parseAttribute(dimDeviceTypeList.emplace_back()) ||
        parser.parseRSquare())
      return failure();
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseDimEntry)))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  gang = ArrayAttr::get(ctx, gangList);
  gangDim = ArrayAttr::get(ctx, dimList);
  gangDimDeviceTypes = ArrayAttr::get(ctx, dimDeviceTypeList);
  return success();
}
|
|
|
|
/// Prints what follows the `gang` keyword of a routine.
///
/// Nothing is printed when there are no gang-dim device types and \p gang
/// holds exactly one element whose value is DeviceType::None. Otherwise the
/// output is `(`, the gang device types, a `, ` separator when both parts are
/// present, the `dim: value [device-type]` entries, and `)`.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
                            std::optional<mlir::ArrayAttr> gang,
                            std::optional<mlir::ArrayAttr> gangDim,
                            std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
  const bool hasGangTypes = hasDeviceTypeValues(gang);
  const bool hasDimTypes = hasDeviceTypeValues(gangDimDeviceTypes);

  if (!hasDimTypes && hasGangTypes && gang->size() == 1) {
    // The element is dereferenced unconditionally, so use cast<> (which
    // asserts on a mismatch) instead of an unchecked dyn_cast<>.
    auto deviceTypeAttr = mlir::cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  p << "(";

  printDeviceTypes(p, gang);

  if (hasGangTypes && hasDimTypes)
    p << ", ";

  if (hasDimTypes)
    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                          [&](const auto &pair) {
                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });

  p << ")";
}
|
|
|
|
/// Parses an optional `(` `[` attr-list `]` `)` suffix into \p deviceTypes.
/// When no `(` follows, \p deviceTypes becomes a one-element array holding
/// DeviceType::None.
///
/// NOTE(review): if `(` is present but `[` is not, the function succeeds with
/// an empty array and without consuming a `)` — confirm this is intended.
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
                                            mlir::ArrayAttr &deviceTypes) {
  mlir::MLIRContext *ctx = parser.getContext();
  llvm::SmallVector<mlir::Attribute> parsed;

  if (failed(parser.parseOptionalLParen())) {
    // Keyword only.
    parsed.push_back(
        mlir::acc::DeviceTypeAttr::get(ctx, mlir::acc::DeviceType::None));
  } else if (succeeded(parser.parseOptionalLSquare())) {
    // Bracketed attribute list followed by the closing parenthesis.
    auto parseEntry = [&]() -> ParseResult {
      return parser.parseAttribute(parsed.emplace_back());
    };
    if (failed(parser.parseCommaSeparatedList(parseEntry)))
      return failure();
    if (parser.parseRSquare() || parser.parseRParen())
      return failure();
  }

  deviceTypes = ArrayAttr::get(ctx, parsed);
  return success();
}
|
|
|
|
static void
|
|
printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
|
|
std::optional<mlir::ArrayAttr> deviceTypes) {
|
|
|
|
if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
|
|
auto deviceTypeAttr =
|
|
mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
|
|
if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
|
|
return;
|
|
}
|
|
|
|
if (!hasDeviceTypeValues(deviceTypes))
|
|
return;
|
|
|
|
p << "([";
|
|
llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
|
|
auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
|
|
p << dTypeAttr;
|
|
});
|
|
p << "])";
|
|
}
|
|
|
|
/// Convenience overload: queries `worker` for DeviceType::None.
bool RoutineOp::hasWorker() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return hasWorker(noDeviceType);
}
|
|
|
|
/// Forwards the `worker` attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
  auto workerDeviceTypes = getWorker();
  return hasDeviceType(workerDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: queries `vector` for DeviceType::None.
bool RoutineOp::hasVector() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return hasVector(noDeviceType);
}
|
|
|
|
/// Forwards the `vector` attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
  auto vectorDeviceTypes = getVector();
  return hasDeviceType(vectorDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: queries `seq` for DeviceType::None.
bool RoutineOp::hasSeq() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return hasSeq(noDeviceType);
}
|
|
|
|
/// Forwards the `seq` attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
  auto seqDeviceTypes = getSeq();
  return hasDeviceType(seqDeviceTypes, deviceType);
}
|
|
|
|
std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
|
|
return getBindNameValue(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
std::optional<llvm::StringRef>
|
|
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
|
|
if (!hasDeviceTypeValues(getBindNameDeviceType()))
|
|
return std::nullopt;
|
|
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
|
|
auto attr = (*getBindName())[*pos];
|
|
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
|
|
return stringAttr.getValue();
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Convenience overload: queries `gang` for DeviceType::None.
bool RoutineOp::hasGang() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return hasGang(noDeviceType);
}
|
|
|
|
/// Forwards the `gang` attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
  auto gangDeviceTypes = getGang();
  return hasDeviceType(gangDeviceTypes, deviceType);
}
|
|
|
|
std::optional<int64_t> RoutineOp::getGangDimValue() {
|
|
return getGangDimValue(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
std::optional<int64_t>
|
|
RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
|
|
if (!hasDeviceTypeValues(getGangDimDeviceType()))
|
|
return std::nullopt;
|
|
if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
|
|
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
|
|
return intAttr.getInt();
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// InitOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rejects the op when any ancestor operation, at any depth, satisfies
/// isComputeOperation.
LogicalResult acc::InitOp::verify() {
  for (Operation *ancestor = getOperation()->getParentOp(); ancestor;
       ancestor = ancestor->getParentOp()) {
    if (isComputeOperation(ancestor))
      return emitOpError("cannot be nested in a compute operation");
  }
  return success();
}
|
|
|
|
/// Appends \p deviceType to the op's device-types attribute, keeping any
/// entries that are already present and creating the attribute if needed.
void acc::InitOp::addDeviceType(MLIRContext *context,
                                mlir::acc::DeviceType deviceType) {
  llvm::SmallVector<mlir::Attribute> updated;
  if (mlir::ArrayAttr existing = getDeviceTypesAttr())
    updated.append(existing.begin(), existing.end());

  updated.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, updated));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShutdownOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rejects the op when any ancestor operation, at any depth, satisfies
/// isComputeOperation.
LogicalResult acc::ShutdownOp::verify() {
  for (Operation *ancestor = getOperation()->getParentOp(); ancestor;
       ancestor = ancestor->getParentOp()) {
    if (isComputeOperation(ancestor))
      return emitOpError("cannot be nested in a compute operation");
  }
  return success();
}
|
|
|
|
/// Appends \p deviceType to the op's device-types attribute, keeping any
/// entries that are already present and creating the attribute if needed.
void acc::ShutdownOp::addDeviceType(MLIRContext *context,
                                    mlir::acc::DeviceType deviceType) {
  llvm::SmallVector<mlir::Attribute> updated;
  if (mlir::ArrayAttr existing = getDeviceTypesAttr())
    updated.append(existing.begin(), existing.end());

  updated.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, updated));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SetOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rejects the op when it is nested (at any depth) inside an operation that
/// satisfies isComputeOperation, or when none of device_type, default_async
/// and device_num is present.
LogicalResult acc::SetOp::verify() {
  for (Operation *ancestor = getOperation()->getParentOp(); ancestor;
       ancestor = ancestor->getParentOp()) {
    if (isComputeOperation(ancestor))
      return emitOpError("cannot be nested in a compute operation");
  }

  const bool hasAnySetting =
      getDeviceTypeAttr() || getDefaultAsync() || getDeviceNum();
  if (!hasAnySetting)
    return emitOpError("at least one default_async, device_num, or device_type "
                       "operand must appear");
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UpdateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies the op:
///   - at least one data clause operand is present;
///   - async operands agree with their device-type attribute;
///   - wait operands agree with their segment and device-type attributes;
///   - wait and async do not conflict (checkWaitAndAsyncConflict);
///   - every data clause operand is defined by acc.update_device,
///     acc.update_host or acc.getdeviceptr.
LogicalResult acc::UpdateOp::verify() {
  // At least one of host or device should have a value.
  if (getDataClauseOperands().empty())
    return emitError("at least one value must be present in dataOperands");

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
    return failure();

  for (mlir::Value operand : getDataClauseOperands()) {
    // A value without a defining op (e.g. a block argument) yields a null
    // pointer here. isa_and_nonnull turns that case into the regular
    // diagnostic instead of tripping the null assertion inside isa<>.
    if (!mlir::isa_and_nonnull<acc::UpdateDeviceOp, acc::UpdateHostOp,
                               acc::GetDevicePtrOp>(operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");
  }

  return success();
}
|
|
|
|
unsigned UpdateOp::getNumDataOperands() {
|
|
return getDataClauseOperands().size();
|
|
}
|
|
|
|
/// Returns the operand at offset \p i past the wait operands, the async
/// operands and the optional if-condition.
Value UpdateOp::getDataOperand(unsigned i) {
  const unsigned ifCondCount = getIfCond() ? 1 : 0;
  unsigned asyncAndIfCount = getAsyncOperands().size();
  asyncAndIfCount += ifCondCount;
  return getOperand(getWaitOperands().size() + asyncAndIfCount + i);
}
|
|
|
|
/// Registers the RemoveConstantIfCondition pattern for this op in
/// \p results.
void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  using IfConditionPattern = RemoveConstantIfCondition<UpdateOp>;
  results.add<IfConditionPattern>(context);
}
|
|
|
|
bool UpdateOp::hasAsyncOnly() {
|
|
return hasAsyncOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Forwards the async-only attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
  auto asyncOnlyDeviceTypes = getAsyncOnly();
  return hasDeviceType(asyncOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: looks up the async operand for DeviceType::None.
mlir::Value UpdateOp::getAsyncValue() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getAsyncValue(noDeviceType);
}
|
|
|
|
/// Returns the async operand paired with \p deviceType, or a null Value when
/// there are no async device-type values or findSegment does not locate
/// \p deviceType among them.
mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncOperandsDeviceType();
  if (hasDeviceTypeValues(asyncDeviceTypes)) {
    // The position found among the device types indexes the async operands.
    auto segment = findSegment(*asyncDeviceTypes, deviceType);
    if (segment)
      return getAsyncOperands()[*segment];
  }
  return mlir::Value();
}
|
|
|
|
bool UpdateOp::hasWaitOnly() {
|
|
return hasWaitOnly(mlir::acc::DeviceType::None);
|
|
}
|
|
|
|
/// Forwards the wait-only attribute and \p deviceType to hasDeviceType and
/// returns its answer.
bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
  auto waitOnlyDeviceTypes = getWaitOnly();
  return hasDeviceType(waitOnlyDeviceTypes, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait values for DeviceType::None.
mlir::Operation::operand_range UpdateOp::getWaitValues() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getWaitValues(noDeviceType);
}
|
|
|
|
/// Returns the wait operands for \p deviceType as computed by
/// getWaitValuesWithoutDevnum from this op's wait device types, wait
/// operands, wait segments and has-wait-devnum attribute.
mlir::Operation::operand_range
UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitValuesWithoutDevnum(waitDeviceTypes, getWaitOperands(),
                                    waitSegments, hasDevnum, deviceType);
}
|
|
|
|
/// Convenience overload: returns the wait devnum for DeviceType::None.
mlir::Value UpdateOp::getWaitDevnum() {
  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
  return getWaitDevnum(noDeviceType);
}
|
|
|
|
/// Returns the wait devnum for \p deviceType as computed by
/// getWaitDevnumValue from this op's wait device types, wait operands, wait
/// segments and has-wait-devnum attribute.
mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto waitSegments = getWaitOperandsSegments();
  auto hasDevnum = getHasWaitDevnum();
  return getWaitDevnumValue(waitDeviceTypes, getWaitOperands(), waitSegments,
                            hasDevnum, deviceType);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WaitOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that the async attribute and the async operand are not both
/// present, and that wait_devnum only appears together with wait operands.
LogicalResult acc::WaitOp::verify() {
  // The async attribute represents the async clause without a value, so it is
  // mutually exclusive with the async operand.
  if (getAsyncOperand()) {
    if (getAsync())
      return emitError("async attribute cannot appear with asyncOperand");
  }

  if (getWaitDevnum()) {
    if (getWaitOperands().empty())
      return emitError("wait_devnum cannot appear without waitOperands");
  }

  return success();
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
|
|
|
|
#define GET_TYPEDEF_CLASSES
|
|
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// acc dialect utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the result of getVarPtr() for data entry operations and for
/// acc.copyout / acc.update_host; a null value for every other operation.
mlir::TypedValue<mlir::acc::PointerLikeType>
mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
  using PointerValue = mlir::TypedValue<mlir::acc::PointerLikeType>;
  return llvm::TypeSwitch<mlir::Operation *, PointerValue>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS>(
          [](auto entryOp) { return entryOp.getVarPtr(); })
      .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
          [](auto exitOp) { return exitOp.getVarPtr(); })
      .Default([](mlir::Operation *) { return PointerValue(); });
}
|
|
|
|
/// Returns the result of getVar() for data entry operations; a null Value
/// for every other operation.
mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS>([](auto entryOp) { return entryOp.getVar(); })
      .Default([](mlir::Operation *) { return mlir::Value(); });
}
|
|
|
|
/// Returns the result of getVarType() for data entry operations and for
/// acc.copyout / acc.update_host; a null Type for every other operation.
mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS>(
          [](auto entryOp) { return entryOp.getVarType(); })
      .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
          [](auto exitOp) { return exitOp.getVarType(); })
      .Default([](mlir::Operation *) { return mlir::Type(); });
}
|
|
|
|
/// Returns the result of getAccPtr() for data entry and data exit
/// operations; a null value for every other operation.
mlir::TypedValue<mlir::acc::PointerLikeType>
mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
  using PointerValue = mlir::TypedValue<mlir::acc::PointerLikeType>;
  return llvm::TypeSwitch<mlir::Operation *, PointerValue>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
          [](auto clauseOp) { return clauseOp.getAccPtr(); })
      .Default([](mlir::Operation *) { return PointerValue(); });
}
|
|
|
|
/// Returns the result of getAccVar() for data entry and data exit
/// operations; a null Value for every other operation.
mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
          [](auto clauseOp) { return clauseOp.getAccVar(); })
      .Default([](mlir::Operation *) { return mlir::Value(); });
}
|
|
|
|
/// Returns the result of getVarPtrPtr() for data entry operations; a null
/// Value for every other operation.
mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS>(
          [](auto entryOp) { return entryOp.getVarPtrPtr(); })
      .Default([](mlir::Operation *) { return mlir::Value(); });
}
|
|
|
|
/// Returns a copy of the bounds operands of a data entry or data exit
/// operation; an empty vector for every other operation.
mlir::SmallVector<mlir::Value>
mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
  using ValueVector = mlir::SmallVector<mlir::Value>;
  return llvm::TypeSwitch<mlir::Operation *, ValueVector>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([](auto clauseOp) {
        return ValueVector(clauseOp.getBounds().begin(),
                           clauseOp.getBounds().end());
      })
      .Default(
          [](mlir::Operation *) { return mlir::SmallVector<mlir::Value, 0>(); });
}
|
|
|
|
/// Returns a copy of the async operands of a data entry or data exit
/// operation; an empty vector for every other operation.
mlir::SmallVector<mlir::Value>
mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
  using ValueVector = mlir::SmallVector<mlir::Value>;
  return llvm::TypeSwitch<mlir::Operation *, ValueVector>(accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([](auto clauseOp) {
        return ValueVector(clauseOp.getAsyncOperands().begin(),
                           clauseOp.getAsyncOperands().end());
      })
      .Default(
          [](mlir::Operation *) { return mlir::SmallVector<mlir::Value, 0>(); });
}
|
|
|
|
/// Returns the async-operands device-type attribute of a data entry or data
/// exit operation; a null ArrayAttr for every other operation.
mlir::ArrayAttr
mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
  llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr> dispatch(
      accDataClauseOp);
  return dispatch
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([](auto clauseOp) {
        return clauseOp.getAsyncOperandsDeviceTypeAttr();
      })
      .Default([](mlir::Operation *) { return mlir::ArrayAttr(); });
}
|
|
|
|
/// Returns the async-only attribute of a data entry or data exit operation;
/// a null ArrayAttr for every other operation.
mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
  llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr> dispatch(
      accDataClauseOp);
  return dispatch
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
          [](auto clauseOp) { return clauseOp.getAsyncOnlyAttr(); })
      .Default([](mlir::Operation *) { return mlir::ArrayAttr(); });
}
|
|
|
|
/// Returns the result of getName() for data entry operations; std::nullopt
/// for every other operation.
std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
  using OptionalName = std::optional<llvm::StringRef>;
  return llvm::TypeSwitch<mlir::Operation *, OptionalName>(accOp)
      .Case<ACC_DATA_ENTRY_OPS>([](auto entryOp) { return entryOp.getName(); })
      .Default(
          [](mlir::Operation *) -> OptionalName { return std::nullopt; });
}
|
|
|
|
std::optional<mlir::acc::DataClause>
|
|
mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
|
|
auto dataClause{
|
|
llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
|
|
accDataEntryOp)
|
|
.Case<ACC_DATA_ENTRY_OPS>(
|
|
[&](auto entry) { return entry.getDataClause(); })
|
|
.Default([&](mlir::Operation *) { return std::nullopt; })};
|
|
return dataClause;
|
|
}
|
|
|
|
/// Returns the result of getImplicit() for data entry operations; false for
/// every other operation.
bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
  return llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
      .Case<ACC_DATA_ENTRY_OPS>(
          [](auto entryOp) { return entryOp.getImplicit(); })
      .Default([](mlir::Operation *) { return false; });
}
|
|
|
|
/// Returns the data clause operands of a compute or data construct
/// operation; an empty range for every other operation.
mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
      .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
          [](auto constructOp) { return constructOp.getDataClauseOperands(); })
      .Default([](mlir::Operation *) { return mlir::ValueRange(); });
}
|
|
|
|
/// Returns the mutable data clause operand range of a compute or data
/// construct operation.
///
/// NOTE(review): for every other operation the fallback converts `nullptr`
/// into a MutableOperandRange — confirm that this conversion is safe to
/// evaluate, or that callers never pass an unsupported operation.
mlir::MutableOperandRange
mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
  return llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
      .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>([](auto constructOp) {
        return constructOp.getDataClauseOperandsMutable();
      })
      .Default([](mlir::Operation *) { return nullptr; });
}
|
|
|
|
/// Walks outward from the operation that owns \p region and returns the
/// first operation (possibly the owner itself) that is one of the compute
/// construct operations, or nullptr when no such ancestor exists.
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
  for (mlir::Operation *ancestor = region.getParentOp(); ancestor;
       ancestor = ancestor->getParentOp()) {
    if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(ancestor))
      return ancestor;
  }
  return nullptr;
}
|