//===- ConversionUtils.cpp ------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Utility functions for TOSA lowering // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" using namespace mlir; using namespace mlir::tosa; SmallVector mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { return SmallVector(nParallelLoops, utils::IteratorType::parallel); } SmallVector mlir::tosa::condenseValues(const SmallVector &values) { SmallVector condensedValues; for (auto value : values) if (value) condensedValues.push_back(value); return condensedValues; } Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter) { Value minValue = rewriter.create(loc, arg, max); return rewriter.create(loc, minValue, min); } Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, arith::CmpIPredicate::slt, arg, min); auto minOrArg = rewriter.create(loc, smallerThanMin, min, arg); auto largerThanMax = rewriter.create(loc, arith::CmpIPredicate::slt, max, arg); return rewriter.create(loc, largerThanMax, max, minOrArg); } bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { uint64_t bitwidth = ty.getIntOrFloatBitWidth(); if (ty.getSignedness() == IntegerType::Unsigned) { uint64_t uvalue = value; APInt intMin = APInt::getMinValue(bitwidth); APInt intMax = APInt::getMaxValue(bitwidth); return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue(); } APInt intMin = APInt::getSignedMinValue(bitwidth); APInt intMax = APInt::getSignedMaxValue(bitwidth); return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); }