Add `Idempotent` trait to `arith.{andi,ori}`.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D114574
1153 lines
41 KiB
C++
1153 lines
41 KiB
C++
//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/CommonFolders.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arith;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns an IntegerAttr of `res`'s type holding `lhs + rhs`.
///
/// The sum is computed on the attributes' APInt values, so it wraps modulo
/// 2^bitwidth like the APInt-based folders in this file. The previous version
/// went through `IntegerAttr::getInt()` (int64_t): adding two large i64 values
/// was signed-overflow undefined behavior, and constants wider than 64 bits
/// could not be extracted at all.
/// Assumes both attributes are IntegerAttrs of the same type as `res`.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return builder.getIntegerAttr(res.getType(),
                                lhs.cast<IntegerAttr>().getValue() +
                                    rhs.cast<IntegerAttr>().getValue());
}
|
|
|
|
/// Returns an IntegerAttr of `res`'s type holding `lhs - rhs`.
///
/// The difference is computed on the attributes' APInt values, so it wraps
/// modulo 2^bitwidth instead of going through `int64_t`. The int64_t route
/// could hit signed-overflow undefined behavior and could not handle integers
/// wider than 64 bits.
/// Assumes both attributes are IntegerAttrs of the same type as `res`.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return builder.getIntegerAttr(res.getType(),
                                lhs.cast<IntegerAttr>().getValue() -
                                    rhs.cast<IntegerAttr>().getValue());
}
|
|
|
|
/// Invert an integer comparison predicate.
|
|
static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
|
|
switch (pred) {
|
|
case arith::CmpIPredicate::eq:
|
|
return arith::CmpIPredicate::ne;
|
|
case arith::CmpIPredicate::ne:
|
|
return arith::CmpIPredicate::eq;
|
|
case arith::CmpIPredicate::slt:
|
|
return arith::CmpIPredicate::sge;
|
|
case arith::CmpIPredicate::sle:
|
|
return arith::CmpIPredicate::sgt;
|
|
case arith::CmpIPredicate::sgt:
|
|
return arith::CmpIPredicate::sle;
|
|
case arith::CmpIPredicate::sge:
|
|
return arith::CmpIPredicate::slt;
|
|
case arith::CmpIPredicate::ult:
|
|
return arith::CmpIPredicate::uge;
|
|
case arith::CmpIPredicate::ule:
|
|
return arith::CmpIPredicate::ugt;
|
|
case arith::CmpIPredicate::ugt:
|
|
return arith::CmpIPredicate::ule;
|
|
case arith::CmpIPredicate::uge:
|
|
return arith::CmpIPredicate::ult;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Attribute-level counterpart of invertPredicate(CmpIPredicate). Builds a new
/// predicate attribute, in the same context, that holds the negated predicate.
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
  arith::CmpIPredicate negated = invertPredicate(pred.getValue());
  MLIRContext *ctx = pred.getContext();
  return arith::CmpIPredicateAttr::get(ctx, negated);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd canonicalization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Pull in the TableGen-generated canonicalization patterns. Presumably these
// define the pattern classes registered by the getCanonicalizationPatterns
// hooks below (AddIAddConstant, XOrINotCmpI, ...); confirm against the .td.
// The anonymous namespace gives the generated classes internal linkage.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Suggests a readable SSA name for the constant's result:
///   * `true` / `false` for i1 integer constants,
///   * `c<value>` for other integer attributes, suffixed with `_<type>` when
///     the result type is an IntegerType,
///   * `cst` for every non-integer attribute.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Type resultType = getType();
  auto intAttr = getValue().dyn_cast<IntegerAttr>();
  if (!intAttr) {
    setNameFn(getResult(), "cst");
    return;
  }

  auto integerType = resultType.dyn_cast<IntegerType>();
  // Booleans read better as `true` / `false` than as numeric names.
  if (integerType && integerType.getWidth() == 1) {
    setNameFn(getResult(), intAttr.getInt() ? "true" : "false");
    return;
  }

  // Otherwise, build a complex name from the value and (if integer) the type.
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameStream(nameBuffer);
  nameStream << 'c' << intAttr.getInt();
  if (integerType)
    nameStream << '_' << resultType;
  setNameFn(getResult(), nameStream.str());
}
|
|
|
|
/// TODO: disallow arith.constant to return anything other than signless integer
|
|
/// or float like.
|
|
static LogicalResult verify(arith::ConstantOp op) {
|
|
auto type = op.getType();
|
|
// The value's type must match the return type.
|
|
if (op.getValue().getType() != type) {
|
|
return op.emitOpError() << "value type " << op.getValue().getType()
|
|
<< " must match return type: " << type;
|
|
}
|
|
// Integer values must be signless.
|
|
if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
|
|
return op.emitOpError("integer return type must be signless");
|
|
// Any float or elements attribute are acceptable.
|
|
if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
|
|
return op.emitOpError(
|
|
"value must be an integer, float, or elements attribute");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Returns true when an `arith.constant` of type `type` can be built from
/// `value`. This requires three things:
///   * the attribute's type equals `type`,
///   * an integer `type` is signless,
///   * the attribute is an integer, float or elements attribute.
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  if (value.getType() != type)
    return false;
  if (auto intType = type.dyn_cast<IntegerType>()) {
    if (!intType.isSignless())
      return false;
  }
  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}
|
|
|
|
/// A constant always folds to its own value attribute; `operands` is unused.
OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
  Attribute value = getValue();
  return value;
}
|
|
|
|
/// Builds an integer constant of a `width`-bit integer type holding `value`.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  IntegerType intType = builder.getIntegerType(width);
  IntegerAttr valueAttr = builder.getIntegerAttr(intType, value);
  arith::ConstantOp::build(builder, result, intType, valueAttr);
}
|
|
|
|
/// Builds an integer constant of the given signless integer `type` holding
/// `value`. Asserts (in debug builds) that `type` is a signless integer.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  IntegerAttr valueAttr = builder.getIntegerAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// An operation is a ConstantIntOp when it is an `arith.constant` whose
/// result type is a signless integer. Accepts null.
bool arith::ConstantIntOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  return constOp && constOp.getType().isSignlessInteger();
}
|
|
|
|
/// Builds a float constant of float type `type` holding `value`.
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  FloatAttr valueAttr = builder.getFloatAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// An operation is a ConstantFloatOp when it is an `arith.constant` whose
/// result type is a FloatType. Accepts null.
bool arith::ConstantFloatOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  return constOp && constOp.getType().isa<FloatType>();
}
|
|
|
|
/// Builds an `index`-typed constant holding `value`.
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  IndexType indexType = builder.getIndexType();
  IntegerAttr valueAttr = builder.getIndexAttr(value);
  arith::ConstantOp::build(builder, result, indexType, valueAttr);
}
|
|
|
|
/// An operation is a ConstantIndexOp when it is an `arith.constant` whose
/// result type is `index`. Accepts null.
bool arith::ConstantIndexOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  return constOp && constOp.getType().isIndex();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.addi`. Adding a constant zero on the right is the identity.
/// Two constant operands are folded with wrapping APInt addition.
OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
  Value lhs = getLhs();
  // addi(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return lhs;

  auto add = [](APInt a, APInt b) { return a + b; };
  return constFoldBinaryOp<IntegerAttr>(operands, add);
}
|
|
|
|
/// Registers the TableGen-generated canonicalizations for `arith.addi`. Judging
/// by their names, they combine an addi with a nested addi/subi of a constant
/// (TODO confirm against the .td).
void arith::AddIOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<AddIAddConstant>(context);
  patterns.insert<AddISubConstantRHS>(context);
  patterns.insert<AddISubConstantLHS>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.subi`:
///   * `x - x` becomes zero,
///   * `x - 0` becomes `x`,
///   * two constants fold with wrapping APInt subtraction.
OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // subi(x,x) -> 0
  if (lhs == rhs)
    return Builder(getContext()).getZeroAttr(getType());

  // subi(x,0) -> x
  if (matchPattern(rhs, m_Zero()))
    return lhs;

  auto subtract = [](APInt a, APInt b) { return a - b; };
  return constFoldBinaryOp<IntegerAttr>(operands, subtract);
}
|
|
|
|
/// Registers the TableGen-generated canonicalizations for `arith.subi`. Judging
/// by their names, they combine a subi whose LHS or RHS is itself an addi/subi
/// with a constant (TODO confirm against the .td).
void arith::SubIOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Outer subi whose RHS is an add/sub with a constant.
  patterns.insert<SubIRHSAddConstant>(context);
  patterns.insert<SubILHSAddConstant>(context);
  patterns.insert<SubIRHSSubConstantRHS>(context);
  patterns.insert<SubIRHSSubConstantLHS>(context);
  patterns.insert<SubILHSSubConstantRHS>(context);
  patterns.insert<SubILHSSubConstantLHS>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.muli`:
///   * `x * 0` becomes the zero operand,
///   * `x * 1` becomes `x`,
///   * two constants fold with wrapping APInt multiplication.
OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
  Value rhs = getRhs();

  // muli(x, 0) -> 0
  if (matchPattern(rhs, m_Zero()))
    return rhs;

  // muli(x, 1) -> x
  if (matchPattern(rhs, m_One()))
    return getLhs();

  // TODO: Handle the overflow case.
  auto multiply = [](APInt a, APInt b) { return a * b; };
  return constFoldBinaryOp<IntegerAttr>(operands, multiply);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.divui`. Division by a constant one (scalar or splat) yields the
/// dividend. Two constant operands are folded, unless some element would divide
/// by zero.
OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
  // Remember whether any element divides by zero; such a fold is discarded.
  bool hitZeroDivisor = false;
  Attribute folded = constFoldBinaryOp<IntegerAttr>(
      operands, [&](APInt dividend, APInt divisor) {
        if (hitZeroDivisor || !divisor) {
          hitZeroDivisor = true;
          return dividend;
        }
        return dividend.udiv(divisor);
      });

  // x / 1 -> x. Assumes all tensors of all ones are splats.
  Attribute divisorAttr = operands[1];
  bool divisorIsOne = false;
  if (auto scalar = divisorAttr.dyn_cast_or_null<IntegerAttr>())
    divisorIsOne = scalar.getValue() == 1;
  else if (auto splat = divisorAttr.dyn_cast_or_null<SplatElementsAttr>())
    divisorIsOne = splat.getSplatValue<IntegerAttr>().getValue() == 1;
  if (divisorIsOne)
    return getLhs();

  if (hitZeroDivisor)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.divsi`. Division by a constant one (scalar or splat) yields the
/// dividend. Two constant operands are folded, unless some element would
/// divide by zero or overflow (e.g. MININT / -1).
OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
  // Set when any element overflows or divides by zero; the fold is discarded.
  bool unsafe = false;
  Attribute folded = constFoldBinaryOp<IntegerAttr>(
      operands, [&](APInt dividend, APInt divisor) {
        if (unsafe || !divisor) {
          unsafe = true;
          return dividend;
        }
        return dividend.sdiv_ov(divisor, unsafe);
      });

  // x / 1 -> x. Assumes all tensors of all ones are splats.
  Attribute divisorAttr = operands[1];
  bool divisorIsOne = false;
  if (auto scalar = divisorAttr.dyn_cast_or_null<IntegerAttr>())
    divisorIsOne = scalar.getValue() == 1;
  else if (auto splat = divisorAttr.dyn_cast_or_null<SplatElementsAttr>())
    divisorIsOne = splat.getSplatValue<IntegerAttr>().getValue() == 1;
  if (divisorIsOne)
    return getLhs();

  if (unsafe)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Ceil and floor division folding helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns ceil(a / b), computed as (a - 1) / b + 1.
///
/// Precondition: `a` > 0 and `b` > 0 (signed). The formula is wrong for
/// a == 0 (it yields 1 whenever b > 1), so callers must handle a zero dividend
/// themselves.
///
/// On return, `overflow` is true if any intermediate step overflowed and false
/// otherwise. This matches the assign-style contract of APInt's *_ov helpers.
static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
  APInt one(a.getBitWidth(), 1, /*isSigned=*/true);
  // Each *_ov call overwrites (does not accumulate into) its flag. Chaining all
  // steps through one flag let a later, non-overflowing step erase an earlier
  // overflow, so each step gets its own flag.
  bool subOverflow = false;
  bool divOverflow = false;
  bool addOverflow = false;
  APInt val = a.ssub_ov(one, subOverflow).sdiv_ov(b, divOverflow);
  APInt res = val.sadd_ov(one, addOverflow);
  overflow = subOverflow || divOverflow || addOverflow;
  return res;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.ceildivui` (unsigned division rounding the quotient up).
/// Division by a constant one (scalar or splat) yields the dividend. Two
/// constant operands are folded, unless some element divides by zero or the
/// round-up overflows.
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
  bool overflowOrDiv0 = false;
  Attribute folded = constFoldBinaryOp<IntegerAttr>(
      operands, [&](APInt dividend, APInt divisor) {
        if (overflowOrDiv0 || !divisor) {
          overflowOrDiv0 = true;
          return dividend;
        }
        APInt quotient = dividend.udiv(divisor);
        APInt remainder = dividend.urem(divisor);
        // Exact division needs no rounding.
        if (!remainder)
          return quotient;
        // Inexact: round the truncated quotient up by one.
        APInt one(dividend.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  // ceildivui(x, 1) -> x. Assumes all tensors of all ones are splats.
  Attribute divisorAttr = operands[1];
  bool divisorIsOne = false;
  if (auto scalar = divisorAttr.dyn_cast_or_null<IntegerAttr>())
    divisorIsOne = scalar.getValue() == 1;
  else if (auto splat = divisorAttr.dyn_cast_or_null<SplatElementsAttr>())
    divisorIsOne = splat.getSplatValue<IntegerAttr>().getValue() == 1;
  if (divisorIsOne)
    return getLhs();

  if (overflowOrDiv0)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.ceildivsi` (signed division rounding toward +infinity).
/// Division by a constant one (scalar or splat) yields the dividend. Two
/// constant operands are folded, unless some element divides by zero or any
/// intermediate step overflows (e.g. negating MININT).
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      return a;
    }
    // ceil(0 / b) == 0 for every non-zero b.
    if (!a)
      return a;

    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getZero(bits);

    // APInt's *_ov helpers overwrite (not accumulate into) their overflow flag.
    // Each step therefore records into its own flag, and the flags are OR'ed.
    // Sharing one flag let a later non-overflowing step hide an earlier
    // overflow, e.g. ceildivsi(-128 : i8, 2) folded to 64 instead of -64.
    bool negAOverflow = false;
    bool negBOverflow = false;
    bool divOverflow = false;
    bool negResOverflow = false;

    if (a.sgt(zero) && b.sgt(zero)) {
      // Both positive, return ceil(a, b).
      APInt res = signedCeilNonnegInputs(a, b, divOverflow);
      overflowOrDiv0 = divOverflow;
      return res;
    }
    if (a.slt(zero) && b.slt(zero)) {
      // Both negative, return ceil(-a, -b).
      APInt posA = zero.ssub_ov(a, negAOverflow);
      APInt posB = zero.ssub_ov(b, negBOverflow);
      APInt res = signedCeilNonnegInputs(posA, posB, divOverflow);
      overflowOrDiv0 = negAOverflow || negBOverflow || divOverflow;
      return res;
    }
    if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ( -a / b).
      APInt posA = zero.ssub_ov(a, negAOverflow);
      APInt div = posA.sdiv_ov(b, divOverflow);
      APInt res = zero.ssub_ov(div, negResOverflow);
      overflowOrDiv0 = negAOverflow || divOverflow || negResOverflow;
      return res;
    }
    // A is positive, b is negative, return - (a / -b).
    APInt posB = zero.ssub_ov(b, negBOverflow);
    APInt div = a.sdiv_ov(posB, divOverflow);
    APInt res = zero.ssub_ov(div, negResOverflow);
    overflowOrDiv0 = negBOverflow || divOverflow || negResOverflow;
    return res;
  });

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return getLhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FloorDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.floordivsi` (signed division rounding toward -infinity).
/// Division by a constant one (scalar or splat) yields the dividend. Two
/// constant operands are folded, unless some element divides by zero or any
/// intermediate step overflows (e.g. negating MININT).
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      return a;
    }
    // floor(0 / b) == 0 for every non-zero b.
    if (!a)
      return a;

    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getZero(bits);

    // APInt's *_ov helpers overwrite (not accumulate into) their overflow flag.
    // Each step therefore records into its own flag, and the flags are OR'ed.
    // Sharing one flag let a later non-overflowing step hide an earlier
    // overflow, e.g. floordivsi(-128 : i8, -2) folded to -64 instead of 64.
    bool negAOverflow = false;
    bool negBOverflow = false;
    bool divOverflow = false;
    bool negResOverflow = false;

    if (a.sgt(zero) && b.sgt(zero)) {
      // Both positive, return a / b.
      APInt res = a.sdiv_ov(b, divOverflow);
      overflowOrDiv0 = divOverflow;
      return res;
    }
    if (a.slt(zero) && b.slt(zero)) {
      // Both negative, return -a / -b.
      APInt posA = zero.ssub_ov(a, negAOverflow);
      APInt posB = zero.ssub_ov(b, negBOverflow);
      APInt res = posA.sdiv_ov(posB, divOverflow);
      overflowOrDiv0 = negAOverflow || negBOverflow || divOverflow;
      return res;
    }
    if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ceil(-a, b).
      APInt posA = zero.ssub_ov(a, negAOverflow);
      APInt ceil = signedCeilNonnegInputs(posA, b, divOverflow);
      APInt res = zero.ssub_ov(ceil, negResOverflow);
      overflowOrDiv0 = negAOverflow || divOverflow || negResOverflow;
      return res;
    }
    // A is positive, b is negative, return - ceil(a, -b).
    APInt posB = zero.ssub_ov(b, negBOverflow);
    APInt ceil = signedCeilNonnegInputs(a, posB, divOverflow);
    APInt res = zero.ssub_ov(ceil, negResOverflow);
    overflowOrDiv0 = negBOverflow || divOverflow || negResOverflow;
    return res;
  });

  // Fold out floor division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return getLhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.remui` when the divisor is a scalar integer constant:
///   * `x % 1` becomes 0,
///   * a zero divisor is never folded,
///   * a constant dividend folds to the unsigned remainder.
OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
  auto divisor = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisor)
    return {};
  APInt divisorValue = divisor.getValue();

  // x % 1 = 0
  if (divisorValue.isOneValue())
    return IntegerAttr::get(divisor.getType(),
                            APInt(divisorValue.getBitWidth(), 0));

  // Don't fold if it requires division by zero.
  if (divisorValue.isNullValue())
    return {};

  auto dividend = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (dividend)
    return IntegerAttr::get(dividend.getType(),
                            dividend.getValue().urem(divisorValue));
  return {};
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.remsi` when the divisor is a scalar integer constant:
///   * `x % 1` becomes 0,
///   * a zero divisor is never folded,
///   * a constant dividend folds to the signed remainder.
OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
  auto divisor = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisor)
    return {};
  APInt divisorValue = divisor.getValue();

  // x % 1 = 0
  if (divisorValue.isOneValue())
    return IntegerAttr::get(divisor.getType(),
                            APInt(divisorValue.getBitWidth(), 0));

  // Don't fold if it requires division by zero.
  if (divisorValue.isNullValue())
    return {};

  auto dividend = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (dividend)
    return IntegerAttr::get(dividend.getType(),
                            dividend.getValue().srem(divisorValue));
  return {};
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AndIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.andi`:
///   * `x & 0` becomes the zero operand,
///   * `x & <all ones>` becomes `x`,
///   * two constants fold bitwise.
/// `and(x, x) -> x` is not handled here; per this change's commit message it
/// comes from the op's Idempotent trait (declared in the .td).
OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
  Value rhs = getRhs();

  /// and(x, 0) -> 0
  if (matchPattern(rhs, m_Zero()))
    return rhs;

  /// and(x, allOnes) -> x
  APInt rhsBits;
  const bool rhsIsAllOnes =
      matchPattern(rhs, m_ConstantInt(&rhsBits)) && rhsBits.isAllOnes();
  if (rhsIsAllOnes)
    return getLhs();

  auto bitAnd = [](APInt a, APInt b) { return a & b; };
  return constFoldBinaryOp<IntegerAttr>(operands, bitAnd);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.ori`:
///   * `x | 0` becomes `x`,
///   * `x | <all ones>` becomes the all-ones attribute (scalar constants only),
///   * two constants fold bitwise.
/// `or(x, x) -> x` is not handled here; per this change's commit message it
/// comes from the op's Idempotent trait (declared in the .td).
OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
  /// or(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  /// or(x, <all ones>) -> <all ones>
  auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (rhsAttr && rhsAttr.getValue().isAllOnes())
    return rhsAttr;

  auto bitOr = [](APInt a, APInt b) { return a | b; };
  return constFoldBinaryOp<IntegerAttr>(operands, bitOr);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XOrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.xori`:
///   * `x ^ 0` becomes `x`,
///   * `x ^ x` becomes zero,
///   * two constants fold bitwise.
OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  /// xor(x, 0) -> x
  if (matchPattern(rhs, m_Zero()))
    return lhs;

  /// xor(x, x) -> 0
  if (lhs == rhs)
    return Builder(getContext()).getZeroAttr(getType());

  auto bitXor = [](APInt a, APInt b) { return a ^ b; };
  return constFoldBinaryOp<IntegerAttr>(operands, bitXor);
}
|
|
|
|
/// Registers the TableGen-generated canonicalization for `arith.xori`. Judging
/// by its name, it rewrites an xor that negates a cmpi into a cmpi with the
/// inverted predicate (TODO confirm against the .td).
void arith::XOrIOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  MLIRContext *ctx = context;
  patterns.insert<XOrINotCmpI>(ctx);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.addf` when both operands are float constants.
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
  auto sum = [](APFloat a, APFloat b) { return a + b; };
  return constFoldBinaryOp<FloatAttr>(operands, sum);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.subf` when both operands are float constants.
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
  auto difference = [](APFloat a, APFloat b) { return a - b; };
  return constFoldBinaryOp<FloatAttr>(operands, difference);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maxsi` (signed maximum):
///   * `max(x, x)` becomes `x`,
///   * a constant RHS equal to MAX_INT absorbs,
///   * a constant RHS equal to MIN_INT is the identity,
///   * two constants fold.
OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxsi(x,x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsConst;
  if (matchPattern(rhs, m_ConstantInt(&rhsConst))) {
    // maxsi(x,MAX_INT) -> MAX_INT
    if (rhsConst.isMaxSignedValue())
      return rhs;
    // maxsi(x, MIN_INT) -> x
    if (rhsConst.isMinSignedValue())
      return lhs;
  }

  auto signedMax = [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); };
  return constFoldBinaryOp<IntegerAttr>(operands, signedMax);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maxui` (unsigned maximum):
///   * `max(x, x)` becomes `x`,
///   * a constant RHS equal to the unsigned maximum absorbs,
///   * a constant RHS equal to zero is the identity,
///   * two constants fold.
OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxui(x,x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsConst;
  if (matchPattern(rhs, m_ConstantInt(&rhsConst))) {
    // maxui(x,MAX_INT) -> MAX_INT
    if (rhsConst.isMaxValue())
      return rhs;
    // maxui(x, MIN_INT) -> x
    if (rhsConst.isMinValue())
      return lhs;
  }

  auto unsignedMax = [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); };
  return constFoldBinaryOp<IntegerAttr>(operands, unsignedMax);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minsi` (signed minimum):
///   * `min(x, x)` becomes `x`,
///   * a constant RHS equal to MIN_INT absorbs,
///   * a constant RHS equal to MAX_INT is the identity,
///   * two constants fold.
OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minsi(x,x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsConst;
  if (matchPattern(rhs, m_ConstantInt(&rhsConst))) {
    // minsi(x,MIN_INT) -> MIN_INT
    if (rhsConst.isMinSignedValue())
      return rhs;
    // minsi(x, MAX_INT) -> x
    if (rhsConst.isMaxSignedValue())
      return lhs;
  }

  auto signedMin = [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); };
  return constFoldBinaryOp<IntegerAttr>(operands, signedMin);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minui` (unsigned minimum):
///   * `min(x, x)` becomes `x`,
///   * a constant RHS equal to zero absorbs,
///   * a constant RHS equal to the unsigned maximum is the identity,
///   * two constants fold.
OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minui(x,x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsConst;
  if (matchPattern(rhs, m_ConstantInt(&rhsConst))) {
    // minui(x,MIN_INT) -> MIN_INT
    if (rhsConst.isMinValue())
      return rhs;
    // minui(x, MAX_INT) -> x
    if (rhsConst.isMaxValue())
      return lhs;
  }

  auto unsignedMin = [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); };
  return constFoldBinaryOp<IntegerAttr>(operands, unsignedMin);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.mulf` when both operands are float constants.
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
  auto product = [](APFloat a, APFloat b) { return a * b; };
  return constFoldBinaryOp<FloatAttr>(operands, product);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.divf` when both operands are float constants.
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
  auto quotient = [](APFloat a, APFloat b) { return a / b; };
  return constFoldBinaryOp<FloatAttr>(operands, quotient);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions for verifying cast ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename... Types>
|
|
using type_list = std::tuple<Types...> *;
|
|
|
|
/// Returns a non-null type only if the provided type is one of the allowed
|
|
/// types or one of the allowed shaped types of the allowed types. Returns the
|
|
/// element type if a valid shaped type is provided.
|
|
template <typename... ShapedTypes, typename... ElementTypes>
|
|
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
|
|
type_list<ElementTypes...>) {
|
|
if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
|
|
return {};
|
|
|
|
auto underlyingType = getElementTypeOrSelf(type);
|
|
if (!underlyingType.isa<ElementTypes...>())
|
|
return {};
|
|
|
|
return underlyingType;
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors and tensors.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLike(Type type) {
|
|
return getUnderlyingType(type, type_list<VectorType, TensorType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors, tensors, and memrefs.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLikeOrMemRef(Type type) {
|
|
return getUnderlyingType(type,
|
|
type_list<VectorType, TensorType, MemRefType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// A cast takes exactly one input and one output, and their shapes must be
/// compatible (per verifyCompatibleShapes).
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for integer and floating point extension/truncation ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Extend ops can only extend to a wider type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyExtOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be wider than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
// Truncate ops can only truncate to a shorter type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyTruncateOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be shorter than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Validate a cast that changes the width of a type.
|
|
template <template <typename> class WidthComparator, typename... ElementTypes>
|
|
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
|
|
auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
|
|
if (!srcType || !dstType)
|
|
return false;
|
|
|
|
return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
|
|
srcType.getIntOrFloatBitWidth());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.extui` on a scalar integer constant by zero-extending
/// it to the result width.
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!constIn)
    return {};

  unsigned resultWidth = getType().getIntOrFloatBitWidth();
  return IntegerAttr::get(getType(), constIn.getValue().zext(resultWidth));
}
|
|
|
|
/// `arith.extui` must strictly widen an integer (or vector/tensor of integers)
/// between types of compatible shape.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widens =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return widens;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.extsi` on a scalar integer constant by sign-extending
/// it to the result width.
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!constIn)
    return {};

  unsigned resultWidth = getType().getIntOrFloatBitWidth();
  return IntegerAttr::get(getType(), constIn.getValue().sext(resultWidth));
}
|
|
|
|
/// `arith.extsi` must strictly widen an integer (or vector/tensor of integers)
/// between types of compatible shape.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widens =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return widens;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.extf` must strictly widen a float (or vector/tensor of floats)
/// between types of compatible shape.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widens =
      checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
  return widens;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.trunci`:
///   * `trunci(extui(a))` and `trunci(extsi(a))` fold to `a`, but only when
///     `a` already has this op's result type;
///   * a scalar integer constant is truncated to the result width.
OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "unary operation takes one operand");

  // trunci(zexti(a)) -> a
  // trunci(sexti(a)) -> a
  // The extension may widen `a` to something other than what this truncation
  // narrows back to (e.g. i8 -> i32 -> i16). A fold result must have the op's
  // result type, so only fold when the types agree exactly.
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value extSource = getOperand().getDefiningOp()->getOperand(0);
    if (extSource.getType() == getType())
      return extSource;
  }

  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
    return IntegerAttr::get(
        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
  }

  return {};
}
|
|
|
|
/// `arith.trunci` must strictly narrow an integer (or vector/tensor of
/// integers) between types of compatible shape.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool narrows =
      checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
  return narrows;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Perform safe const propagation for truncf, i.e. only propagate if FP value
|
|
/// can be represented without precision loss or rounding.
|
|
OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
|
|
assert(operands.size() == 1 && "unary operation takes one operand");
|
|
|
|
auto constOperand = operands.front();
|
|
if (!constOperand || !constOperand.isa<FloatAttr>())
|
|
return {};
|
|
|
|
// Convert to target type via 'double'.
|
|
double sourceValue =
|
|
constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
|
|
auto targetAttr = FloatAttr::get(getType(), sourceValue);
|
|
|
|
// Propagate if constant's value does not change after truncation.
|
|
if (sourceValue == targetAttr.getValue().convertToDouble())
|
|
return targetAttr;
|
|
|
|
return {};
|
|
}
|
|
|
|
/// Cast verifier for `arith.truncf`: delegates to checkWidthChangeCast with
/// the `std::less` comparator over FloatType (NOTE(review): presumably this
/// requires the result float width to be below the source width).
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool compatible =
      checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
  return compatible;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for casts between integers and floats.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename From, typename To>
|
|
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<From>(inputs.front());
|
|
auto dstType = getTypeIfLike<To>(outputs.back());
|
|
|
|
return srcType && dstType;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.uitofp`: integer-like source, float-like result.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool compatible =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return compatible;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.sitofp`: integer-like source, float-like result.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool compatible =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return compatible;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.fptoui`: float-like source, integer-like result.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool compatible =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return compatible;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.fptosi`: float-like source, integer-like result.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool compatible =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return compatible;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.index_cast`: exactly one side must be `index` and
/// the other a signless integer (element types of like/memref types count).
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto from = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto to = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!from || !to)
    return false;

  // index -> signless integer
  if (from.isIndex())
    return to.isSignlessInteger();
  // signless integer -> index
  return from.isSignlessInteger() && to.isIndex();
}
|
|
|
|
/// Folds `arith.index_cast` of a constant integer into a constant of the
/// result type. The value is sign-extended or truncated to the result width;
/// `index` results use IndexType::kInternalStorageBitWidth bits.
OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "index_cast takes one operand");

  // index_cast(constant) -> constant
  auto value = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!value)
    return {};

  // Operate on the APInt directly. Going through IntegerAttr::getInt() (an
  // int64_t) asserts for integers wider than 64 bits whose value does not fit,
  // e.g. large i128 constants.
  Type resultType = getType();
  unsigned resultWidth = resultType.isIndex()
                             ? IndexType::kInternalStorageBitWidth
                             : resultType.getIntOrFloatBitWidth();
  return IntegerAttr::get(resultType,
                          value.getValue().sextOrTrunc(resultWidth));
}
|
|
|
|
/// Registers the index_cast canonicalization patterns, in order:
/// IndexCastOfIndexCast, then IndexCastOfExtSI (both defined elsewhere;
/// NOTE(review): by name, they simplify index_cast chains and index_cast of
/// extsi respectively).
void arith::IndexCastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexCastOfIndexCast>(context);
  patterns.insert<IndexCastOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BitcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Cast verifier for `arith.bitcast`: source and result (or their like/memref
/// element types) must both be integer or float types of equal bit width.
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
  auto dstType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  // `index` has no fixed bit width and Type::getIntOrFloatBitWidth() asserts
  // on it, so reject such casts instead of crashing in the query below.
  if (!srcType.isIntOrFloat() || !dstType.isIntOrFloat())
    return false;

  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
|
|
|
|
/// Constant-folds `arith.bitcast`.
///  - DenseElementsAttr operands are reinterpreted in the result element type.
///  - Other shaped results are not folded.
///  - Scalar integer/float constants have their bits reinterpreted as the
///    scalar result type.
OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "bitcast op expects 1 operand");

  Attribute input = operands[0];
  if (!input)
    return {};
  Type targetType = getType();

  // Dense elements: reinterpret in the result's element type.
  if (auto dense = input.dyn_cast_or_null<DenseElementsAttr>())
    return dense.bitcast(targetType.cast<ShapedType>().getElementType());
  // Any other shaped result is left unfolded.
  if (targetType.isa<ShapedType>())
    return {};

  // Scalar: take the raw bit pattern of the integer or float constant...
  APInt rawBits = [&] {
    if (auto floatInput = input.dyn_cast<FloatAttr>())
      return floatInput.getValue().bitcastToAPInt();
    return input.cast<IntegerAttr>().getValue();
  }();

  // ...and rebuild it in the result type.
  auto targetFloat = targetType.dyn_cast<FloatType>();
  if (!targetFloat)
    return IntegerAttr::get(targetType, rawBits);
  return FloatAttr::get(targetType,
                        APFloat(targetFloat.getFloatSemantics(), rawBits));
}
|
|
|
|
/// Registers the bitcast canonicalization pattern BitcastOfBitcast (defined
/// elsewhere; NOTE(review): by name, it presumably collapses bitcast chains).
void arith::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  auto &registered = patterns;
  registered.insert<BitcastOfBitcast>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helpers for compare ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
|
static Type getI1SameShape(Type type) {
|
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
|
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
|
if (type.isa<UnrankedTensorType>())
|
|
return UnrankedTensorType::get(i1Type);
|
|
if (auto vectorType = type.dyn_cast<VectorType>())
|
|
return VectorType::get(vectorType.getShape(), i1Type);
|
|
return i1Type;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
|
|
const APInt &lhs, const APInt &rhs) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
return lhs.eq(rhs);
|
|
case arith::CmpIPredicate::ne:
|
|
return lhs.ne(rhs);
|
|
case arith::CmpIPredicate::slt:
|
|
return lhs.slt(rhs);
|
|
case arith::CmpIPredicate::sle:
|
|
return lhs.sle(rhs);
|
|
case arith::CmpIPredicate::sgt:
|
|
return lhs.sgt(rhs);
|
|
case arith::CmpIPredicate::sge:
|
|
return lhs.sge(rhs);
|
|
case arith::CmpIPredicate::ult:
|
|
return lhs.ult(rhs);
|
|
case arith::CmpIPredicate::ule:
|
|
return lhs.ule(rhs);
|
|
case arith::CmpIPredicate::ugt:
|
|
return lhs.ugt(rhs);
|
|
case arith::CmpIPredicate::uge:
|
|
return lhs.uge(rhs);
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Returns true if the predicate is true for two equal operands.
|
|
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
case arith::CmpIPredicate::sle:
|
|
case arith::CmpIPredicate::sge:
|
|
case arith::CmpIPredicate::ule:
|
|
case arith::CmpIPredicate::uge:
|
|
return true;
|
|
case arith::CmpIPredicate::ne:
|
|
case arith::CmpIPredicate::slt:
|
|
case arith::CmpIPredicate::sgt:
|
|
case arith::CmpIPredicate::ult:
|
|
case arith::CmpIPredicate::ugt:
|
|
return false;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Folds `arith.cmpi`.
///  - cmpi(pred, x, x) folds to the predicate's value on equal operands. For
///    a statically shaped vector/tensor result this is a splat
///    DenseElementsAttr; dynamically shaped results are not folded.
///  - Two scalar integer constants fold to the evaluated predicate.
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpi takes two operands");

  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    auto boolAttr = BoolAttr::get(getContext(), val);
    // A scalar BoolAttr does not match a shaped (vector/tensor of i1) result
    // type, so splat the value over the result shape instead.
    if (auto shapedType = getType().dyn_cast<ShapedType>()) {
      // Dense splats require a static shape (e.g. not an unranked tensor).
      if (!shapedType.hasStaticShape())
        return {};
      return DenseElementsAttr::get(shapedType, boolAttr);
    }
    return boolAttr;
  }

  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
|
|
const APFloat &lhs, const APFloat &rhs) {
|
|
auto cmpResult = lhs.compare(rhs);
|
|
switch (predicate) {
|
|
case arith::CmpFPredicate::AlwaysFalse:
|
|
return false;
|
|
case arith::CmpFPredicate::OEQ:
|
|
return cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OGT:
|
|
return cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::OGE:
|
|
return cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OLT:
|
|
return cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::OLE:
|
|
return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ONE:
|
|
return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ORD:
|
|
return cmpResult != APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::UEQ:
|
|
return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UGT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::UGE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ULT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::ULE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNE:
|
|
return cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNO:
|
|
return cmpResult == APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::AlwaysTrue:
|
|
return true;
|
|
}
|
|
llvm_unreachable("unknown cmpf predicate kind");
|
|
}
|
|
|
|
/// Folds `arith.cmpf` when both operands are scalar float constants by
/// evaluating the predicate; otherwise leaves the op unfolded.
OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpf takes two operands");

  auto lhsConst = operands.front().dyn_cast_or_null<FloatAttr>();
  auto rhsConst = operands.back().dyn_cast_or_null<FloatAttr>();
  if (lhsConst && rhsConst) {
    bool result = applyCmpPredicate(getPredicate(), lhsConst.getValue(),
                                    rhsConst.getValue());
    return BoolAttr::get(getContext(), result);
  }

  return {};
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd enum attribute definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
|