//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <functional>
using namespace mlir;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
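
// Replaces a single-operand tosa.concat with its input; when the concat only
// refines the type, a tensor.cast is inserted instead. Illustrative IR
// (shapes arbitrary, attribute syntax abbreviated):
//   %0 = "tosa.concat"(%arg0) {axis = 0} : (tensor<2x3xf32>) -> tensor<2x3xf32>
// canonicalizes to a direct use of %arg0.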
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ConcatOp op,
PatternRewriter &rewriter) const override {
if (op.getInput1().size() != 1)
return failure();
if (op.getInput1().front().getType() != op.getType()) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
op.getInput1().front());
return success();
}
rewriter.replaceOp(op, op.getInput1().front());
return success();
}
};
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConcatOptimization>(context);
}
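
// Collapses reshape-of-reshape into a single reshape; only the outermost
// new_shape matters. Illustrative IR (attribute syntax abbreviated):
//   %0 = "tosa.reshape"(%x) {new_shape = [6]}
//   %1 = "tosa.reshape"(%0) {new_shape = [2, 3]}
// canonicalizes to:
//   %1 = "tosa.reshape"(%x) {new_shape = [2, 3]}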
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput1();
Operation *definingOp = input.getDefiningOp();
if (!definingOp)
return failure();
if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
return success();
}
return failure();
}
};
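
// Folds tosa.reshape of a constant into a constant with the new shape,
// removing the runtime reshape entirely.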
struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput1();
ShapedType inputTy = input.getType().cast<ShapedType>();
ShapedType resultTy = op.getType().cast<ShapedType>();
if (inputTy.getElementType() != resultTy.getElementType())
return rewriter.notifyMatchFailure(op, "element type does not match.");
// Check if input is constant
DenseElementsAttr inputAttr;
if (!matchPattern(input, m_Constant(&inputAttr)))
return rewriter.notifyMatchFailure(op, "Non-constant input.");
// Bail out if the constant has multiple users and is not a splat; rewriting
// would duplicate a potentially large payload for each user.
if (!input.hasOneUse() && !inputAttr.isSplat())
return rewriter.notifyMatchFailure(op,
"Used more than once or not-splat");
// Build a new const op with the correct output shape.
DenseElementsAttr outputAttr =
inputAttr.reshape(inputTy.clone(op.getNewShape()));
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
outputAttr);
return success();
}
};
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ReshapeReshapeOptimization>(context);
results.add<ReshapeConstOptimization>(context);
}
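
// Absorbs a tosa.logical_not feeding the predicate of tosa.select by swapping
// the branch operands. Illustrative IR:
//   %p = "tosa.logical_not"(%c)
//   %r = "tosa.select"(%p, %a, %b)
// canonicalizes to:
//   %r = "tosa.select"(%c, %b, %a)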
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.updateRootInPlace(op, [&]() {
op.getOperation()->setOperands(
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
});
return success();
}
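
// Erases a tosa.transpose whose permutation is the identity
// [0, 1, ..., rank - 1]; such a transpose is a no-op.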
struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
auto perm = op.getPerms();
DenseIntElementsAttr permAttr;
if (!matchPattern(perm, m_Constant(&permAttr))) {
return failure();
}
SmallVector<int64_t> permValues = llvm::to_vector<6>(
llvm::map_range(permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
for (int i = 0, s = permValues.size(); i < s; i++) {
if (i != permValues[i]) {
return failure();
}
}
rewriter.replaceOp(op, op.getInput1());
return success();
}
};
// Rewrites tosa.transpose as tosa.reshape when the permutation only moves
// unit dimensions and therefore does not change the element order in memory.
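// Illustrative IR: transposing tensor<1x5xf32> with perms [1, 0] produces
// tensor<5x1xf32> without reordering elements, so it becomes
//   %r = "tosa.reshape"(%x) {new_shape = [5, 1]}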
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return rewriter.notifyMatchFailure(op, "Non-constant permutation");
auto input = op.getInput1();
auto inputTy = input.getType().cast<ShapedType>();
if (!inputTy.hasRank())
return rewriter.notifyMatchFailure(op, "Unranked input.");
int64_t numDynDims = 0;
for (int i = 0; i < inputTy.getRank(); ++i)
if (inputTy.isDynamicDim(i))
numDynDims++;
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
SmallVector<int64_t> permValues = llvm::to_vector<6>(
llvm::map_range(permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
for (auto idx : permValues) {
auto sz = inputTy.getDimSize(idx);
if (sz != 1)
nonZeroPerms.push_back(idx);
}
for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
if (nonZeroPerms[i - 1] > nonZeroPerms[i])
return rewriter.notifyMatchFailure(op,
"Transpose changes memeory layout.");
SmallVector<int64_t> newShape;
newShape.reserve(inputTy.getRank());
for (int i = 0, s = inputTy.getRank(); i < s; ++i)
newShape.push_back(inputTy.getDimSize(permValues[i]));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
rewriter.getDenseI64ArrayAttr(newShape));
return success();
}
};
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<TransposeNoOp, TransposeIsReshape>(context);
}
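
// Drops an integer addition with a splat zero operand (x + 0 == x) when the
// other operand already has the result type. The float case is handled by
// AddOp::fold below.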
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::AddOp op,
PatternRewriter &rewriter) const override {
auto input1 = op.getInput1();
auto input2 = op.getInput2();
DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
input1Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.getInput2());
return success();
}
}
DenseElementsAttr input2Attr;
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
input1.getType() == op.getType()) {
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
input2Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.getInput1());
return success();
}
}
return failure();
}
};
void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<AddZeroOptimization>(context);
}
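
// Drops a multiplication by a splat one (float 1.0 or integer 1), i.e.
// x * 1 == x, when the other operand already has the result type.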
struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::MulOp op,
PatternRewriter &rewriter) const override {
auto input1 = op.getInput1();
auto input2 = op.getInput2();
DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
if (input1Attr.getType().getElementType().isa<FloatType>() &&
input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
rewriter.replaceOp(op, op.getInput2());
return success();
}
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
matchPattern(input1, m_One())) {
rewriter.replaceOp(op, op.getInput2());
return success();
}
}
DenseElementsAttr input2Attr;
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
input1.getType() == op.getType()) {
if (input2Attr.getType().getElementType().isa<FloatType>() &&
input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
rewriter.replaceOp(op, op.getInput1());
return success();
}
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
matchPattern(input2, m_One())) {
rewriter.replaceOp(op, op.getInput1());
return success();
}
}
return failure();
}
};
void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MulOneOptimization>(context);
}
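
// Materializes the implicit pad value of tosa.pad as an explicit pad_const
// operand: 0 for float and integer inputs, or the input zero point for
// quantized inputs, so later lowerings never see the implicit default.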
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::PadOp op,
PatternRewriter &rewriter) const override {
if (op.getPadConst())
return failure();
auto input = op.getInput1();
auto padding = op.getPadding();
ShapedType inputTy = input.getType().cast<ShapedType>();
Type elementTy = inputTy.getElementType();
Attribute constantAttr;
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
} else if (elementTy.isa<IntegerType>() && !op.getQuantizationInfo()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
} else if (elementTy.isa<IntegerType>() && op.getQuantizationInfo()) {
auto value = op.getQuantizationInfo()->getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (!constantAttr) {
return rewriter.notifyMatchFailure(
op,
"tosa.pad to linalg lowering encountered an unknown element type");
}
auto denseAttr = DenseElementsAttr::get(
RankedTensorType::get({}, elementTy), constantAttr);
auto constantVal = rewriter.create<tosa::ConstOp>(
op.getLoc(), denseAttr.getType(), denseAttr);
rewriter.replaceOpWithNewOp<tosa::PadOp>(
op, op.getType(), ValueRange{input, padding, constantVal},
op->getAttrs());
return success();
}
};
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MaterializePadValue>(context);
}
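
// Erases a tosa.max_pool2d whose input and output both have 1x1 spatial
// dimensions (NHWC dims 1 and 2): pooling a single element returns it.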
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value output = op.getOutput();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType outputType = output.getType().cast<ShapedType>();
if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
return failure();
}
// The pool is a no-op only if both the input and the output have 1x1
// spatial dimensions.
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputShape[1] != 1 || outputShape[2] != 1) {
return failure();
}
ArrayRef<int64_t> inputShape = inputType.getShape();
if (inputShape[1] != 1 || inputShape[2] != 1) {
return failure();
}
rewriter.replaceOp(op, input);
return success();
}
};
void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MaxPool2dIsNoOp>(context);
}
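
// Erases a tosa.clamp whose bounds cannot exclude any representable value of
// the element type, e.g. an i8 clamp to [-128, 127] or an f32 clamp to
// [-inf, +inf] (the float check also treats +/- largest-finite as unbounded).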
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
auto inputType = op.getInput().getType().dyn_cast<RankedTensorType>();
if (!inputType || !inputType.hasStaticShape()) {
return failure();
}
auto inputElementType = inputType.getElementType();
if (inputElementType.isF32()) {
auto minClamp = op.getMinFp();
auto maxClamp = op.getMaxFp();
bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
minClamp.isNegative();
bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
!maxClamp.isNegative();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
if (inputElementType.isUnsignedInteger()) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
.getZExtValue();
int64_t intMax =
APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
.getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
if (inputElementType.isa<IntegerType>()) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
.getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
return failure();
}
};
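
// Merges two stacked tosa.clamp ops into one by intersecting their ranges:
// clamp(clamp(x, a, b), c, d) == clamp(x, max(a, c), min(b, d)).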
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Operation *definingOp = input.getDefiningOp();
if (!definingOp)
return failure();
if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(),
rewriter.getI64IntegerAttr(minInt),
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
rewriter.getF32FloatAttr(maxFp));
return success();
}
return failure();
}
};
void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ClampIsNoOp>(context);
results.add<ClampClampOptimization>(context);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
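
// Folds an elementwise binary op whose operands are both splat constants of
// the same element type, applying IntFolder to APInt values and FloatFolder
// to APFloat values. E.g. binaryFolder<std::plus<APInt>, std::plus<APFloat>>
// (used by AddOp::fold below) performs constant addition.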
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = lhs.getType().cast<ShapedType>().getElementType();
auto rETy = rhs.getType().cast<ShapedType>().getElementType();
if (lETy != rETy)
return {};
if (lETy.isa<IntegerType>()) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
auto result = IntFolder()(l, r);
return DenseElementsAttr::get(returnTy, result);
}
if (lETy.isa<FloatType>()) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
auto result = FloatFolder()(l, r);
return DenseElementsAttr::get(returnTy, result);
}
}
return {};
}
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (lhsAttr.getSplatValue<APFloat>().isZero())
return getInput2();
}
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (rhsAttr.getSplatValue<APFloat>().isZero())
return getInput1();
}
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (lhsAttr.getSplatValue<APInt>().isZero())
return getInput2();
}
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (rhsAttr.getSplatValue<APInt>().isZero())
return getInput1();
}
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
lhsTy);
}
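
// Integer tosa.div folds: 0 / x == 0, x / 1 == x, and two splat constants
// fold to their signed division.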
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat()) {
if (resultETy.isa<IntegerType>() && lhsAttr.getSplatValue<APInt>().isZero())
return lhsAttr;
}
if (rhsAttr && rhsAttr.isSplat()) {
if (resultETy.isa<IntegerType>() && rhsAttr.getSplatValue<APInt>().isOne())
return getInput1();
}
if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
if (resultETy.isa<IntegerType>()) {
APInt l = lhsAttr.getSplatValue<APInt>();
APInt r = rhsAttr.getSplatValue<APInt>();
APInt result = l.sdiv(r);
return DenseElementsAttr::get(resultTy, result);
}
}
return {};
}
namespace {
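
// Folds a splat-splat tosa.mul. For integers the result is (l * r) >> shift;
// the product is computed at double bitwidth so it cannot overflow before the
// shift, then truncated back to the original width.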
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
if (ty.getElementType().isa<IntegerType>()) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
if (shift == 0) {
return DenseElementsAttr::get(ty, l * r);
}
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
l = l.sext(bitwidth * 2);
r = r.sext(bitwidth * 2);
auto result = l * r;
result.lshrInPlace(shift);
result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
if (ty.getElementType().isa<FloatType>()) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
APFloat result = l * r;
return DenseElementsAttr::get(ty, result);
}
}
return {};
}
} // namespace
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
auto lhs = getInput1();
auto rhs = getInput2();
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
auto val = lhsAttr.getSplatValue<APFloat>();
if (val.isZero())
return lhsAttr;
if (val.isExactlyValue(1.0))
return rhs;
}
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
auto val = rhsAttr.getSplatValue<APFloat>();
if (val.isZero())
return rhsAttr;
if (val.isExactlyValue(1.0))
return lhs;
}
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
auto val = lhsAttr.getSplatValue<APInt>();
if (val.isZero())
return lhsAttr;
const int64_t shift = getShift();
const int64_t shifted = 1LL << shift;
if (val.getSExtValue() == shifted)
return rhs;
}
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
auto val = rhsAttr.getSplatValue<APInt>();
const int64_t shift = getShift();
const int64_t shifted = 1LL << shift;
if (val.isZero())
return rhsAttr;
if (val.getSExtValue() == shifted)
return lhs;
}
return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
}
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (rhsAttr.getSplatValue<APFloat>().isZero())
return getInput1();
}
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (rhsAttr.getSplatValue<APInt>().isZero())
return getInput1();
}
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
lhsTy);
}
namespace {
template <typename Cmp> struct ComparisonFold {
ComparisonFold() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, Cmp()(l, r));
}
APInt operator()(const APFloat &l, const APFloat &r) {
return APInt(1, Cmp()(l, r));
}
};
struct APIntFoldGreater {
APIntFoldGreater() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, l.sgt(r));
}
};
struct APIntFoldGreaterEqual {
APIntFoldGreaterEqual() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, l.sge(r));
}
};
} // namespace
OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = lhs.getType().cast<ShapedType>();
// Comparing an integer value to itself is always true. We cannot do the
// same for floats, since NaN compares unequal to itself.
if (lhsTy.getElementType().isa<IntegerType>() && resultTy.hasStaticShape() &&
lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
}
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
resultTy);
}
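
// Folds tosa.cast of a splat constant by converting the splat value directly:
// float-to-float rounds to nearest-even, int-to-float converts using the
// source signedness, float-to-int truncates toward zero, and int-to-int
// truncates or sign-/zero-extends depending on the source signedness.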
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
if (getInput().getType() == getType())
return getInput();
auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
if (!operand)
return {};
auto inTy = getInput().getType().cast<ShapedType>();
auto outTy = getType().cast<ShapedType>();
auto inETy = inTy.getElementType();
auto outETy = outTy.getElementType();
if (operand.isSplat()) {
if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
bool overflow;
auto splatVal = operand.getSplatValue<APFloat>();
auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
&overflow);
return SplatElementsAttr::get(outTy, splatVal);
}
if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
auto unsign = inETy.cast<IntegerType>().isUnsignedInteger();
APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
llvm::RoundingMode::NearestTiesToEven);
return SplatElementsAttr::get(outTy, splatVal);
}
if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
auto intVal =
APSInt(outETy.cast<IntegerType>().getIntOrFloatBitWidth(), unsign);
auto floatVal = operand.getSplatValue<APFloat>();
bool exact;
floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
return SplatElementsAttr::get(outTy, intVal);
}
if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
auto unsignIn = inETy.cast<IntegerType>().isUnsignedInteger();
bool trunc =
inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
if (trunc) {
intVal = intVal.trunc(bitwidth);
} else if (unsignIn) {
intVal = intVal.zext(bitwidth);
} else {
intVal = intVal.sext(bitwidth);
}
return SplatElementsAttr::get(outTy, intVal);
}
}
return {};
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
return getValueAttr();
}
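
// A reduction along an axis of size 1 is the identity, so each REDUCE_FOLDER
// instance below folds the corresponding tosa.reduce_* op to its input in
// that case.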
#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(ArrayRef<Attribute> operands) { \
ShapedType inputTy = getInput().getType().cast<ShapedType>(); \
if (!inputTy.hasRank()) \
return {}; \
if (inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
}
REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
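
// tosa.reshape folds to its input when the types already match; reshaping a
// splat constant simply retypes the splat.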
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
if (!inputTy || !outputTy)
return {};
if (inputTy == outputTy)
return getInput1();
auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
return {};
}
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
// If all padding amounts are zero, the pad is a no-op and folds away.
if (operands[1]) {
auto densePad = operands[1].cast<DenseElementsAttr>();
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
}
return {};
}
// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
ArrayRef<int64_t> offset = getOffset();
ArrayRef<int64_t> border = getBorder();
ArrayRef<int64_t> scale = getScale();
// The resize must use unit scaling (numerator equals denominator per axis).
if (scale[0] != scale[1] || scale[2] != scale[3]) {
return {};
}
// There should be no offset.
if (offset[0] != 0 || offset[1] != 0) {
return {};
}
// There should be no border.
if (border[0] != 0 || border[1] != 0) {
return {};
}
auto input = getInput();
auto inputTy = input.getType().cast<RankedTensorType>();
auto resultTy = getType().cast<RankedTensorType>();
if (inputTy != resultTy)
return {};
return input;
}
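
// tosa.reverse of a splat constant folds to the splat itself, since reversing
// cannot change any element of a splat.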
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
auto operand = getInput();
auto operandTy = operand.getType().cast<ShapedType>();
auto axis = getAxis();
auto operandAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>();
if (operandAttr)
return operandAttr;
// If the dim-length is 1, tosa.reverse is a no-op.
if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
return operand;
return {};
}
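
// tosa.slice folds when it is the identity slice, when the input is a splat
// constant, or when a single-element result can be read directly out of the
// constant input at the start indices.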
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
if (!inputTy || !outputTy)
return {};
if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput();
if (!operands[0])
return {};
auto operand = operands[0].cast<ElementsAttr>();
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
llvm::SmallVector<uint64_t> indices(getStart());
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}
return {};
}
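
// tosa.select folds when both branches are the same value, or when a splat
// constant predicate statically picks one branch.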
OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
if (getOnTrue() == getOnFalse())
return getOnTrue();
auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!predicate)
return {};
if (!predicate.isSplat())
return {};
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
: getOnFalse();
}
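
// tosa.tile with all multiples equal to 1 copies its input unchanged.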
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
if (allOnes && getInput1().getType() == getType())
return getInput1();
return {};
}
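
// Folds tosa.transpose of a splat constant into a reshaped splat, and erases
// an identity-permutation transpose whose result type matches its input.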
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (!operands[1])
return {};
auto inputTy = getInput1().getType().cast<ShapedType>();
auto resultTy = getType().cast<ShapedType>();
if (inputTy.getElementType() != resultTy.getElementType())
return {};
// Transposing splat values just means reshaping.
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
if (input.isSplat())
return input.reshape(getType().cast<ShapedType>());
}
auto perms = llvm::to_vector<6>(llvm::map_range(
operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
getInput1().getType() == getType())
return getInput1();
return {};
}