Files
clang-p2996/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
Ivan Butygin f54cdc5d6e [mlir] IntegerRangeAnalysis: add support for vector type (#112292)
Treat the integer range of a vector type as the union of the ranges of its
individual elements. With these semantics, most arith ops on vectors work out
of the box; the only special handling needed is for constants and
vector-element manipulation ops.

The end goal of these changes is to be able to optimize vectorized index
calculations.
2024-11-01 23:58:16 +03:00

357 lines
14 KiB
C++

//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "int-range-analysis"
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::intrange;
/// Maps arith's `nsw` / `nuw` overflow flags onto the corresponding
/// `intrange::OverflowFlags` bits consumed by the shared inference helpers.
/// Flags other than `nsw` and `nuw` are not translated.
static intrange::OverflowFlags
convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
  const bool hasNsw =
      bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw);
  const bool hasNuw =
      bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw);

  intrange::OverflowFlags converted = intrange::OverflowFlags::None;
  if (hasNuw)
    converted |= intrange::OverflowFlags::Nuw;
  if (hasNsw)
    converted |= intrange::OverflowFlags::Nsw;
  return converted;
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
/// Infers the range of an integer constant.
///
/// - A scalar `IntegerAttr` yields the single-point range of its value.
/// - A `DenseIntElementsAttr` (vector/tensor constant) yields the union of the
///   single-point ranges of all its elements.
/// - Any other attribute, or a dense attribute with zero elements, leaves the
///   result range unset.
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                          SetIntRangeFn setResultRange) {
  if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
    const APInt &value = scalarCstAttr.getValue();
    setResultRange(getResult(), ConstantIntRanges::constant(value));
    return;
  }

  if (auto arrayCstAttr =
          llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
    // Zero-element attributes are legal for tensor-typed constants
    // (e.g. `dense<> : tensor<0xi32>`). They have no value to build a range
    // from, so bail out instead of dereferencing an empty optional below.
    if (arrayCstAttr.getNumElements() == 0)
      return;

    // A splat holds one distinct value; avoid walking every element.
    if (arrayCstAttr.isSplat()) {
      setResultRange(getResult(), ConstantIntRanges::constant(
                                      arrayCstAttr.getSplatValue<APInt>()));
      return;
    }

    std::optional<ConstantIntRanges> result;
    for (const APInt &val : arrayCstAttr) {
      auto range = ConstantIntRanges::constant(val);
      result = (result ? result->rangeUnion(range) : range);
    }
    assert(result && "non-empty elements attribute must yield a range");
    setResultRange(getResult(), *result);
  }
}
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
/// Range of an integer addition, computed by the shared `inferAdd` helper.
/// The op's `nsw`/`nuw` flags are forwarded so the helper can use them.
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  intrange::OverflowFlags flags = convertArithOverflowFlags(getOverflowFlags());
  ConstantIntRanges sum = inferAdd(argRanges, flags);
  setResultRange(getResult(), sum);
}
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
/// Range of an integer subtraction, computed by the shared `inferSub` helper.
/// The op's `nsw`/`nuw` flags are forwarded so the helper can use them.
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  intrange::OverflowFlags flags = convertArithOverflowFlags(getOverflowFlags());
  ConstantIntRanges difference = inferSub(argRanges, flags);
  setResultRange(getResult(), difference);
}
//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//
/// Range of an integer multiplication, computed by the shared `inferMul`
/// helper. The op's `nsw`/`nuw` flags are forwarded so the helper can use them.
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  intrange::OverflowFlags flags = convertArithOverflowFlags(getOverflowFlags());
  ConstantIntRanges product = inferMul(argRanges, flags);
  setResultRange(getResult(), product);
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned division, delegated to the shared `inferDivU` helper.
void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges quotient = inferDivU(argRanges);
  setResultRange(getResult(), quotient);
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed division, delegated to the shared `inferDivS` helper.
void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges quotient = inferDivS(argRanges);
  setResultRange(getResult(), quotient);
}
//===----------------------------------------------------------------------===//
// CeilDivUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned ceiling division, delegated to `inferCeilDivU`.
void arith::CeilDivUIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  Value result = getResult();
  setResultRange(result, inferCeilDivU(argRanges));
}
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed ceiling division, delegated to `inferCeilDivS`.
void arith::CeilDivSIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  Value result = getResult();
  setResultRange(result, inferCeilDivS(argRanges));
}
//===----------------------------------------------------------------------===//
// FloorDivSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed floor division, delegated to `inferFloorDivS`.
// Calls setResultRange directly instead of `return`ing the result of a void
// call, matching every other op in this file.
void arith::FloorDivSIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), inferFloorDivS(argRanges));
}
//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned remainder, delegated to the shared `inferRemU` helper.
void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges remainder = inferRemU(argRanges);
  setResultRange(getResult(), remainder);
}
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed remainder, delegated to the shared `inferRemS` helper.
void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges remainder = inferRemS(argRanges);
  setResultRange(getResult(), remainder);
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
/// Range of a bitwise AND, delegated to the shared `inferAnd` helper.
void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  Value result = getResult();
  setResultRange(result, inferAnd(argRanges));
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
/// Range of a bitwise OR, delegated to the shared `inferOr` helper.
void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                     SetIntRangeFn setResultRange) {
  Value result = getResult();
  setResultRange(result, inferOr(argRanges));
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
/// Range of a bitwise XOR, delegated to the shared `inferXor` helper.
void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  Value result = getResult();
  setResultRange(result, inferXor(argRanges));
}
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed maximum, delegated to the shared `inferMaxS` helper.
void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges largest = inferMaxS(argRanges);
  setResultRange(getResult(), largest);
}
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned maximum, delegated to the shared `inferMaxU` helper.
void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges largest = inferMaxU(argRanges);
  setResultRange(getResult(), largest);
}
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
/// Range of a signed minimum, delegated to the shared `inferMinS` helper.
void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges smallest = inferMinS(argRanges);
  setResultRange(getResult(), smallest);
}
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned minimum, delegated to the shared `inferMinU` helper.
void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges smallest = inferMinU(argRanges);
  setResultRange(getResult(), smallest);
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
/// Zero-extends the operand's range to the storage bitwidth of the result type.
void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  Value result = getResult();
  const unsigned width =
      ConstantIntRanges::getStorageBitwidth(result.getType());
  const ConstantIntRanges &input = argRanges[0];
  setResultRange(result, extUIRange(input, width));
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
/// Sign-extends the operand's range to the storage bitwidth of the result type.
void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  Value result = getResult();
  const unsigned width =
      ConstantIntRanges::getStorageBitwidth(result.getType());
  const ConstantIntRanges &input = argRanges[0];
  setResultRange(result, extSIRange(input, width));
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
/// Truncates the operand's range to the storage bitwidth of the result type.
void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
  Value result = getResult();
  const unsigned width =
      ConstantIntRanges::getStorageBitwidth(result.getType());
  const ConstantIntRanges &input = argRanges[0];
  setResultRange(result, truncRange(input, width));
}
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
/// Range of a signed index cast, derived by comparing operand and result
/// storage bitwidths:
/// - wider result: the range is sign-extended;
/// - narrower result: the range is truncated;
/// - equal widths: the operand's range is forwarded unchanged.
void arith::IndexCastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  const ConstantIntRanges &input = argRanges[0];
  const unsigned fromBits =
      ConstantIntRanges::getStorageBitwidth(getOperand().getType());
  const unsigned toBits =
      ConstantIntRanges::getStorageBitwidth(getResult().getType());

  if (fromBits == toBits) {
    setResultRange(getResult(), input);
    return;
  }
  if (fromBits < toBits) {
    setResultRange(getResult(), extSIRange(input, toBits));
    return;
  }
  setResultRange(getResult(), truncRange(input, toBits));
}
//===----------------------------------------------------------------------===//
// IndexCastUIOp
//===----------------------------------------------------------------------===//
/// Range of an unsigned index cast, derived by comparing operand and result
/// storage bitwidths:
/// - wider result: the range is zero-extended;
/// - narrower result: the range is truncated;
/// - equal widths: the operand's range is forwarded unchanged.
void arith::IndexCastUIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  const ConstantIntRanges &input = argRanges[0];
  const unsigned fromBits =
      ConstantIntRanges::getStorageBitwidth(getOperand().getType());
  const unsigned toBits =
      ConstantIntRanges::getStorageBitwidth(getResult().getType());

  if (fromBits == toBits) {
    setResultRange(getResult(), input);
    return;
  }
  if (fromBits < toBits) {
    setResultRange(getResult(), extUIRange(input, toBits));
    return;
  }
  setResultRange(getResult(), truncRange(input, toBits));
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
/// Range of an integer comparison's 1-bit result.
///
/// Uses `intrange::evaluatePred` to check whether the predicate is decided by
/// the operand ranges alone:
/// - decided true: the range is [1, 1];
/// - decided false: the range is [0, 0];
/// - undecided: the range is [0, 1].
void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  // Relies on arith::CmpIPredicate and intrange::CmpPredicate sharing values.
  auto pred = static_cast<intrange::CmpPredicate>(getPredicate());
  const ConstantIntRanges &lhs = argRanges[0];
  const ConstantIntRanges &rhs = argRanges[1];
  std::optional<bool> outcome = intrange::evaluatePred(pred, lhs, rhs);

  const bool alwaysTrue = outcome.has_value() && *outcome;
  const bool alwaysFalse = outcome.has_value() && !*outcome;
  APInt lo = alwaysTrue ? APInt::getAllOnes(1) : APInt::getZero(1);
  APInt hi = alwaysFalse ? APInt::getZero(1) : APInt::getAllOnes(1);
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(lo, hi));
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
/// Range of a select.
///
/// - If the condition's range is initialized and pins a single constant, the
///   result takes the range of the chosen operand: false-case when that
///   constant is zero, true-case otherwise.
/// - Otherwise the result is the join of both operands' ranges.
void arith::SelectOp::inferResultRangesFromOptional(
    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
  const IntegerValueRange &cond = argRanges[0];
  const IntegerValueRange &trueCase = argRanges[1];
  const IntegerValueRange &falseCase = argRanges[2];

  std::optional<APInt> knownCond;
  if (!cond.isUninitialized())
    knownCond = cond.getValue().getConstantValue();

  if (!knownCond) {
    setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
    return;
  }
  setResultRange(getResult(), knownCond->isZero() ? falseCase : trueCase);
}
//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//
/// Range of a left shift, computed by the shared `inferShl` helper.
/// The op's `nsw`/`nuw` flags are forwarded so the helper can use them.
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  intrange::OverflowFlags flags = convertArithOverflowFlags(getOverflowFlags());
  ConstantIntRanges shifted = inferShl(argRanges, flags);
  setResultRange(getResult(), shifted);
}
//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//
/// Range of a logical right shift, delegated to the shared `inferShrU` helper.
void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges shifted = inferShrU(argRanges);
  setResultRange(getResult(), shifted);
}
//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//
/// Range of an arithmetic right shift, delegated to the shared `inferShrS`
/// helper.
void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges shifted = inferShrS(argRanges);
  setResultRange(getResult(), shifted);
}