Files
clang-p2996/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Luke Hutton 2e7aa7ead6 [mlir][tosa] Add custom operand getters for select op (#145921)
The select op has 3 inputs (input1, input2 and input3) according to the
TOSA specification. However, use of getInput1(), getInput2() and
getInput3() in the codebase can be confusing and hinder readability.
This commit adds custom getters to help improve readability:
  - input1 -> getPred()
  - input2 -> getOnTrue()
  - input3 -> getOnFalse()

They should be preferred as they are more descriptive, however, the ODS
generated getters (getInputX()) may still be used.

Unfortunately the custom getters don't propagate to Adaptors such as
`FoldAdaptor`, so the ODS generated getters must be used.
2025-06-30 10:11:09 +01:00

619 lines
21 KiB
C++

//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::tosa;
// Construct the compliance checker.
//
// No visible statement uses the TypeInfo constants declared below. Presumably
// the generated TosaComplianceData.h.inc, which is textually included into
// this constructor body, refers to them by name when it fills
// profileComplianceMap and extensionComplianceMap. Confirm against the .inc.
// Because of that textual include, renaming these locals would likely break
// the build.
TosaProfileCompliance::TosaProfileCompliance() {
// Integer element types: IntegerType's TypeID plus a bit width
// (boolT is the 1-bit integer).
const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
// Floating-point element types: their own TypeID plus the bit width.
const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
// End of auto-generated metadata
}
// Profile specialization: hands back the profileComplianceMap member.
// The map is returned by value, so every call yields a fresh copy.
template <>
OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() {
  return this->profileComplianceMap;
}
// Extension specialization: hands back the extensionComplianceMap member.
// The map is returned by value, so every call yields a fresh copy.
template <>
OperationExtensionComplianceMap
TosaProfileCompliance::getProfileComplianceMap() {
  return this->extensionComplianceMap;
}
// Generic populating function: records the type information of every value
// in `operands` (in order) and then that of `output`. Never fails.
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
                                                    Value output) {
  llvm::for_each(operands, [this](Value value) { addValue(value); });
  addValue(output);
  return success();
}
// Concat: only the first of the variadic inputs is recorded, followed by the
// result. The recording order must line up position-by-position with the
// generated compliance metadata.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
  Value firstInput = op.getInput1().front();
  Value result = op.getOutput();
  addValue(firstInput);
  addValue(result);
  return success();
}
// AvgPool2d: records input, input zero point, output zero point, accumulator
// type and output. This order must line up position-by-position with the
// generated compliance metadata, so it is kept fixed.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
  Value input = op.getInput();
  Value inputZp = op.getInputZp();
  Value outputZp = op.getOutputZp();
  Type accType = op.getAccType();
  Value output = op.getOutput();
  addValue(input);
  addValue(inputZp);
  addValue(outputZp);
  addType(accType);
  addValue(output);
  return success();
}
// Shared helper for the convolution-like ops (Conv2D, Conv3D, TransposeConv2D,
// DepthwiseConv2D). It records input, weight, bias, input zero point, weight
// zero point, accumulator type and output. This order must line up
// position-by-position with the generated compliance metadata.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
  Value input = op.getInput();
  Value weight = op.getWeight();
  Value bias = op.getBias();
  Value inputZp = op.getInputZp();
  Value weightZp = op.getWeightZp();
  Type accType = op.getAccType();
  Value output = op.getOutput();
  addValue(input);
  addValue(weight);
  addValue(bias);
  addValue(inputZp);
  addValue(weightZp);
  addType(accType);
  addValue(output);
  return success();
}
// Conv2D: delegates to the shared convolution helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
  return populateProfileInfoConv<tosa::Conv2DOp>(op);
}
// Conv3D: delegates to the shared convolution helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
  return populateProfileInfoConv<tosa::Conv3DOp>(op);
}
// TransposeConv2D: delegates to the shared convolution helper.
template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
  return populateProfileInfoConv<tosa::TransposeConv2DOp>(op);
}
// DepthwiseConv2D: delegates to the shared convolution helper.
template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
  return populateProfileInfoConv<tosa::DepthwiseConv2DOp>(op);
}
// Pad: records input1, the pad constant and the output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
  Value input = op.getInput1();
  Value padConst = op.getPadConst();
  Value output = op.getOutput();
  addValue(input);
  addValue(padConst);
  addValue(output);
  return success();
}
// Shared helper for the data-layout ops (reshape, slice, tile, transpose).
// Only input1 and the output are recorded, in that order.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
  Value source = op.getInput1();
  Value result = op.getOutput();
  addValue(source);
  addValue(result);
  return success();
}
// Reshape: delegates to the data-layout helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
  return populateProfileInfoDataLayout<tosa::ReshapeOp>(op);
}
// Slice: delegates to the data-layout helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
  return populateProfileInfoDataLayout<tosa::SliceOp>(op);
}
// Tile: delegates to the data-layout helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
  return populateProfileInfoDataLayout<tosa::TileOp>(op);
}
// Transpose: delegates to the data-layout helper.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
  return populateProfileInfoDataLayout<tosa::TransposeOp>(op);
}
// Gather: records values, indices and output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
  Value values = op.getValues();
  Value indices = op.getIndices();
  Value output = op.getOutput();
  addValue(values);
  addValue(indices);
  addValue(output);
  return success();
}
// Scatter: records values_in, indices, input and values_out, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
  Value valuesIn = op.getValuesIn();
  Value indices = op.getIndices();
  Value input = op.getInput();
  Value valuesOut = op.getValuesOut();
  addValue(valuesIn);
  addValue(indices);
  addValue(input);
  addValue(valuesOut);
  return success();
}
// Mul: only input1, input2 and the output are recorded, in that order. Any
// other operand of the op does not take part in the compliance check.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
  Value lhs = op.getInput1();
  Value rhs = op.getInput2();
  Value product = op.getOutput();
  addValue(lhs);
  addValue(rhs);
  addValue(product);
  return success();
}
// Resize: only the input and the output are recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
  Value input = op.getInput();
  Value output = op.getOutput();
  addValue(input);
  addValue(output);
  return success();
}
// FFT2d: records the real/imaginary inputs, then the real/imaginary outputs.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
  Value inReal = op.getInputReal();
  Value inImag = op.getInputImag();
  Value outReal = op.getOutputReal();
  Value outImag = op.getOutputImag();
  addValue(inReal);
  addValue(inImag);
  addValue(outReal);
  addValue(outImag);
  return success();
}
// RFFT2d: records the real input, then the real/imaginary outputs.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
  Value inReal = op.getInputReal();
  Value outReal = op.getOutputReal();
  Value outImag = op.getOutputImag();
  addValue(inReal);
  addValue(outReal);
  addValue(outImag);
  return success();
}
// Select: the predicate (getPred) is not recorded. Only the on-true value,
// the on-false value and the output are, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
  Value trueValue = op.getOnTrue();
  Value falseValue = op.getOnFalse();
  Value selected = op.getOutput();
  addValue(trueValue);
  addValue(falseValue);
  addValue(selected);
  return success();
}
// Rescale: records input, input zero point, output zero point and output.
// Other operands of the op are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
  Value input = op.getInput();
  Value inputZp = op.getInputZp();
  Value outputZp = op.getOutputZp();
  Value output = op.getOutput();
  addValue(input);
  addValue(inputZp);
  addValue(outputZp);
  addValue(output);
  return success();
}
// MatMul: records A, B, their zero points and the output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
  Value lhs = op.getA();
  Value rhs = op.getB();
  Value lhsZp = op.getAZp();
  Value rhsZp = op.getBZp();
  Value output = op.getOutput();
  addValue(lhs);
  addValue(rhs);
  addValue(lhsZp);
  addValue(rhsZp);
  addValue(output);
  return success();
}
// Variable: only the type returned by getType() is recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
  Type declaredType = op.getType();
  addType(declaredType);
  return success();
}
// VariableWrite: only the written value (input1) is recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
  Value written = op.getInput1();
  addValue(written);
  return success();
}
// If: only the branch condition is recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
  Value cond = op.getCondition();
  addValue(cond);
  return success();
}
// While: records the first operand of the terminator of the first block of
// the condition region.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
  Operation *condTerminator = op.getCondGraph().front().getTerminator();
  Value condResult = condTerminator->getOperands().front();
  addValue(condResult);
  return success();
}
// Record the type information relevant to profile/extension compliance for
// `op`, dispatching on its concrete TOSA operation.
// Returns the result of the op-specific population for handled ops, success()
// for explicitly skipped ops, and failure() for any op not listed below.
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Use the op-specific populateProfileInfo specialization, which only records
// the customised subset of operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
  if (auto concreteOp = dyn_cast<tosa::tosaOp##Op>(op)) {                      \
    return populateProfileInfo(concreteOp);                                    \
  }
// Record nothing for this op and report success.
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)                                     \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return success();
// Record every operand, followed by the first result.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    return populateProfileInfo(op->getOperands(), op->getResult(0));           \
  }

  // Skip irrelevant operands when they are independent and not tied to any
  // specific profile/extension.
  POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Mul)
  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(Concat)
  POPULATE_PROFILE_INFO_CUSTOM(Pad)
  POPULATE_PROFILE_INFO_CUSTOM(Reshape)
  POPULATE_PROFILE_INFO_CUSTOM(Slice)
  POPULATE_PROFILE_INFO_CUSTOM(Tile)
  POPULATE_PROFILE_INFO_CUSTOM(Transpose)
  POPULATE_PROFILE_INFO_CUSTOM(Gather)
  POPULATE_PROFILE_INFO_CUSTOM(Scatter)
  POPULATE_PROFILE_INFO_CUSTOM(Resize)
  POPULATE_PROFILE_INFO_CUSTOM(Select)
  POPULATE_PROFILE_INFO_CUSTOM(Rescale)
  POPULATE_PROFILE_INFO_CUSTOM(MatMul)
  POPULATE_PROFILE_INFO_CUSTOM(Variable)
  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
  POPULATE_PROFILE_INFO_CUSTOM(If)
  POPULATE_PROFILE_INFO_CUSTOM(While)

  // For most TOSA operators, all operands are profile/extension related and
  // hence are all considered in this profile-based compliance check.
  POPULATE_PROFILE_INFO_COMMON(Cast)
  POPULATE_PROFILE_INFO_COMMON(Const)
  POPULATE_PROFILE_INFO_COMMON(ArgMax)
  POPULATE_PROFILE_INFO_COMMON(Sub)
  POPULATE_PROFILE_INFO_COMMON(Maximum)
  POPULATE_PROFILE_INFO_COMMON(Minimum)
  POPULATE_PROFILE_INFO_COMMON(MaxPool2d)
  POPULATE_PROFILE_INFO_COMMON(Clamp)
  POPULATE_PROFILE_INFO_COMMON(Erf)
  POPULATE_PROFILE_INFO_COMMON(Sigmoid)
  POPULATE_PROFILE_INFO_COMMON(Tanh)
  POPULATE_PROFILE_INFO_COMMON(Add)
  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
  POPULATE_PROFILE_INFO_COMMON(BitwiseOr)
  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
  POPULATE_PROFILE_INFO_COMMON(LogicalOr)
  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
  POPULATE_PROFILE_INFO_COMMON(IntDiv)
  POPULATE_PROFILE_INFO_COMMON(Pow)
  POPULATE_PROFILE_INFO_COMMON(Table)
  POPULATE_PROFILE_INFO_COMMON(Abs)
  POPULATE_PROFILE_INFO_COMMON(Ceil)
  POPULATE_PROFILE_INFO_COMMON(Clz)
  POPULATE_PROFILE_INFO_COMMON(Sin)
  POPULATE_PROFILE_INFO_COMMON(Cos)
  POPULATE_PROFILE_INFO_COMMON(Exp)
  POPULATE_PROFILE_INFO_COMMON(Floor)
  POPULATE_PROFILE_INFO_COMMON(Log)
  POPULATE_PROFILE_INFO_COMMON(Negate)
  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
  POPULATE_PROFILE_INFO_COMMON(Rsqrt)
  POPULATE_PROFILE_INFO_COMMON(ReduceAll)
  POPULATE_PROFILE_INFO_COMMON(ReduceAny)
  POPULATE_PROFILE_INFO_COMMON(ReduceMax)
  POPULATE_PROFILE_INFO_COMMON(ReduceMin)
  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
  POPULATE_PROFILE_INFO_COMMON(ReduceSum)
  POPULATE_PROFILE_INFO_COMMON(Equal)
  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
  POPULATE_PROFILE_INFO_COMMON(Greater)
  POPULATE_PROFILE_INFO_COMMON(Reverse)
  POPULATE_PROFILE_INFO_COMMON(Identity)
  POPULATE_PROFILE_INFO_COMMON(VariableRead)

  // Type Invariant Extension, a capability extension that is independent
  // of the data type, meaning any compatible type can be used. No type
  // constraint for those operations.
  POPULATE_PROFILE_INFO_SKIP(ConstShape)
  POPULATE_PROFILE_INFO_SKIP(Yield)

  // These helpers are local to this function; keep them from leaking into the
  // rest of the translation unit.
#undef POPULATE_PROFILE_INFO_CUSTOM
#undef POPULATE_PROFILE_INFO_SKIP
#undef POPULATE_PROFILE_INFO_COMMON

  // The operation is not handled by any of the cases above.
  return failure();
}
//===----------------------------------------------------------------------===//
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//
// Look up `op` in the profile map (T = Profile) or the extension map
// (T = Extension), and return the modes whose type signature matches the op.
// Fails when the op has no entry in that map. Otherwise returns the (possibly
// empty) result of findMatchedProfile, which sets `condition` on a match.
template <typename T>
FailureOr<SmallVector<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op,
                                             CheckCondition &condition) {
  const std::string opName = op->getName().getStringRef().str();
  const auto complianceMap = getProfileComplianceMap<T>();
  const auto it = complianceMap.find(opName);
  // Spell out the failure: a bare `return {}` on a FailureOr is easily misread
  // as "succeeded with an empty list", which callers treat very differently.
  if (it == complianceMap.end())
    return failure();
  return findMatchedProfile<T>(op, it->second, condition);
}
// Check that `op` satisfies, in `targetEnv`, the profiles (T = Profile) or
// extensions (T = Extension) it requires.
//
// `specRequiredModeSet` holds the requirement sets declared for the op.
// - If the op has no compliance entry, succeed when the target enables any
//   mode from any declared set.
// - Otherwise the modes matched from the op's operand types must be enabled:
//   all of them or any of them, depending on the matched CheckCondition.
// - For extensions, at least one cooperative profile must also be enabled.
// - Finally, every matched mode must appear in every declared set.
// An error is emitted on `op` and failure() returned when a check does not
// hold.
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
    Operation *op, const tosa::TargetEnv &targetEnv,
    const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
  // None of profile requirement is set in the specification.
  if (specRequiredModeSet.empty())
    return success();

  CheckCondition condition = CheckCondition::invalid;
  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
  if (failed(maybeOpRequiredMode)) {
    // Operators such as control-flow and shape ops do not have an operand type
    // restriction. When the profile compliance information of operation is not
    // found, confirm if the target have enabled the profile required from the
    // specification.
    size_t modeCount = 0;
    for (const auto &cands : specRequiredModeSet) {
      if (targetEnv.allowsAnyOf(cands))
        return success();
      modeCount += cands.size();
    }
    op->emitOpError() << "illegal: requires"
                      << (modeCount > 1 ? " any of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(specRequiredModeSet),
                                    ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  // Find the required profiles or extensions according to the operand type
  // combination. Bind by reference: there is no need to copy the vector.
  const auto &opRequiredMode = maybeOpRequiredMode.value();
  if (opRequiredMode.empty()) {
    // No matched restriction found.
    return success();
  }

  if (condition == CheckCondition::allOf &&
      !targetEnv.allowsAllOf(opRequiredMode)) {
    op->emitOpError() << "illegal: requires"
                      << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  if (condition == CheckCondition::anyOf &&
      !targetEnv.allowsAnyOf(opRequiredMode)) {
    op->emitOpError() << "illegal: requires"
                      << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  // Each extension can contain a list of profiles that it works with, usually
  // have the same data type.
  if constexpr (std::is_same_v<T, Extension>) {
    for (const auto &mode : opRequiredMode) {
      SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
      if (!targetEnv.allowsAnyOf(coProfs)) {
        op->emitOpError() << "illegal: requires ["
                          << llvm::join(stringifyProfile<Profile>(coProfs),
                                        ", ")
                          << "] to work with but not enabled in target\n";
        return failure();
      }
    }
  }

  // Ensure the profile inference match the profile knowledge of the
  // specification.
  for (const auto &cands : specRequiredModeSet) {
    for (const auto &mode : opRequiredMode) {
      if (!llvm::is_contained(cands, mode)) {
        op->emitOpError() << "illegal: requires ["
                          << llvm::join(stringifyProfile<T>(opRequiredMode),
                                        ", ")
                          << "] but not included in the profile compliance ["
                          << llvm::join(
                                 stringifyProfile<T>(specRequiredModeSet), ", ")
                          << "]\n";
        return failure();
      }
    }
  }

  return success();
}
// Profile compliance check for `op`. Ops that do not implement
// QueryProfileInterface have no declared profile requirement and trivially
// pass.
LogicalResult
TosaProfileCompliance::checkProfile(Operation *op,
                                    const tosa::TargetEnv &targetEnv) {
  auto profileQuery = dyn_cast<tosa::QueryProfileInterface>(op);
  if (!profileQuery)
    return success();
  return checkProfileOrExtension<Profile>(op, targetEnv,
                                          profileQuery.getProfiles());
}
// Extension compliance check for `op`. Ops that do not implement
// QueryExtensionInterface have no declared extension requirement and trivially
// pass.
LogicalResult
TosaProfileCompliance::checkExtension(Operation *op,
                                      const tosa::TargetEnv &targetEnv) {
  auto extensionQuery = dyn_cast<tosa::QueryExtensionInterface>(op);
  if (!extensionQuery)
    return success();
  return checkProfileOrExtension<Extension>(op, targetEnv,
                                            extensionQuery.getExtensions());
}
// Report `op` as illegal when it has entries in both the profile and the
// extension compliance maps, but its operand/result types match none of them.
// The diagnostic lists the op's recorded types and, when one exists, the
// closest signature from the compliance data (the one with the most
// position-wise equal types).
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
  // getOperatorDefinition requires a condition out-parameter; it is not
  // used here.
  CheckCondition condition = CheckCondition::invalid;
  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);

  // Nothing to report unless both lookups succeeded and neither matched.
  if (failed(maybeProfDef) || failed(maybeExtDef) ||
      !maybeProfDef.value().empty() || !maybeExtDef.value().empty())
    return success();

  std::string message;
  llvm::raw_string_ostream os(message);

  // Print a comma-separated type list. Unlike drop_end()/back(), this is
  // well defined for an empty list.
  const auto printTypeInfos = [&](ArrayRef<TypeInfo> typeInfos) {
    llvm::interleave(
        typeInfos, os,
        [&](const TypeInfo &typeInfo) { os << stringifyTypeInfo(typeInfo); },
        ",");
  };

  os << "illegal: operation operand/result data types did not align with any "
        "profile or extension, got (";
  ProfileInfoDepot depot(op);
  SmallVector<TypeInfo> current = depot.getInfo();
  printTypeInfos(current);
  os << ")";

  // Suggest only the single best candidate so the message stays short.
  const std::string opName = op->getName().getStringRef().str();
  int maxMatches = -1;
  SmallVector<TypeInfo> bestTypeInfo;
  const auto searchBestMatch = [&](const auto &map) {
    // Use find() rather than operator[]: a missing op simply has no
    // candidates.
    const auto it = map.find(opName);
    if (it == map.end())
      return;
    for (const auto &complianceInfos : it->second) {
      for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
        // zip_equal asserts on length mismatch. A signature with a different
        // arity cannot be compared position-wise, so skip it.
        if (typeInfos.size() != current.size())
          continue;
        const int matches = llvm::count_if(
            llvm::zip_equal(current, typeInfos), [&](const auto &zipType) {
              return isSameTypeInfo(std::get<0>(zipType),
                                    std::get<1>(zipType));
            });
        if (matches > maxMatches) {
          maxMatches = matches;
          bestTypeInfo = typeInfos;
        }
      }
    }
  };
  searchBestMatch(getProfileComplianceMap<Profile>());
  searchBestMatch(getProfileComplianceMap<Extension>());

  if (maxMatches >= 0) {
    os << ", did you mean (";
    printTypeInfos(bestTypeInfo);
    os << ")? Otherwise, please";
  } else {
    // No comparable candidate: skip the suggestion, but still point at the
    // specification.
    os << ". Please";
  }
  os << " refer to the 'supported data types' for '" << opName
     << "' in the specification.";
  op->emitOpError(message);
  return failure();
}
// Find the profiles or extensions required by `op` by comparing the type
// signature of its recorded operands against each entry of `compInfo`.
// The first entry containing a fully matching signature wins: `condition` is
// set to that entry's condition, and the entry's modes are returned.
// Returns an empty list when the op recorded no type information or when no
// signature matches.
template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
    Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
    CheckCondition &condition) {
  assert(!compInfo.empty() &&
         "profile-based compliance information is empty");

  // Populate the type of profile/extension relevant operands.
  ProfileInfoDepot depot(op);
  SmallVector<TypeInfo> present = depot.getInfo();
  if (present.empty())
    return {};

  // Iterate by const reference; the original copied every signature set and
  // every signature on each iteration.
  for (const auto &info : compInfo) {
    for (const auto &expected : info.operandTypeInfoSet) {
      assert(present.size() == expected.size() &&
             "the entries for profile-based compliance do not match between "
             "the generated metadata and the type definition retrieved from "
             "the operation");
      // Compare position-wise. In release builds zip_equal stops at the
      // shorter range instead of indexing past the end.
      const bool isFound = llvm::all_of(
          llvm::zip_equal(present, expected), [&](const auto &pair) {
            return isSameTypeInfo(std::get<0>(pair), std::get<1>(pair));
          });
      if (isFound) {
        condition = info.condition;
        return info.mode;
      }
    }
  }
  return {};
}
// Debug utilities.

// Translate each profile (T = Profile) or extension (any other T) enumerant
// in `profiles` to its name, preserving order.
template <typename T>
SmallVector<StringRef>
TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) {
  SmallVector<StringRef> names;
  names.reserve(profiles.size());
  for (const T &entry : profiles) {
    StringRef name;
    if constexpr (std::is_same_v<T, Profile>)
      name = tosa::stringifyProfile(entry);
    else
      name = tosa::stringifyExtension(entry);
    names.push_back(name);
  }
  return names;
}
// Flatten a list of profile/extension groups into one list of names, keeping
// group order and element order.
template <typename T>
SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
    const SmallVector<ArrayRef<T>> &profileSet) {
  SmallVector<StringRef> flattened;
  for (ArrayRef<T> group : profileSet)
    llvm::append_range(flattened, stringifyProfile<T>(group));
  return flattened;
}
// Render a TypeInfo as a short type name: "i<bitWidth>" for integers, or
// "f16", "f32", "bf16", "fp8e4m3", "fp8e5m2" for the supported float types.
// Any other TypeID is treated as unreachable.
llvm::SmallString<7>
TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
  const auto id = typeInfo.typeID;
  if (id == mlir::IntegerType::getTypeID())
    return {"i" + llvm::utostr(typeInfo.bitWidth)};
  if (id == mlir::Float16Type::getTypeID())
    return {"f16"};
  if (id == mlir::Float32Type::getTypeID())
    return {"f32"};
  if (id == mlir::BFloat16Type::getTypeID())
    return {"bf16"};
  if (id == mlir::Float8E4M3FNType::getTypeID())
    return {"fp8e4m3"};
  if (id == mlir::Float8E5M2Type::getTypeID())
    return {"fp8e5m2"};
  llvm_unreachable("unknown type");
}