Files
clang-p2996/mlir/lib/Interfaces/InferIntRangeInterface.cpp
Krzysztof Drewniak 1350c9887d [mlir] Add integer range inference analysis
This commit defines a dataflow analysis for integer ranges, which
uses a newly-added InferIntRangeInterface to compute the lower and
upper bounds on the results of an operation from the bounds on the
arguments. The range inference is a flow-insensitive dataflow analysis
that can be used to simplify code, such as by statically identifying
bounds checks that cannot fail in order to eliminate them.

The InferIntRangeInterface has one method, inferResultRanges(), which
takes a vector of inferred ranges for each argument to an op
implementing the interface and a callback allowing the implementation
to define the ranges for each result. These ranges are stored as
ConstantIntRanges, which hold the lower and upper bounds for a
value. Bounds are tracked separately for the signed and unsigned
interpretations of a value, which ensures that the impact of
arithmetic overflows is correctly tracked during the analysis.

The commit also adds a -test-int-range-inference pass to test the
analysis until it is integrated into SCCP or otherwise exposed.

Finally, this commit fixes some bugs relating to the handling of
region iteration arguments and terminators in the data flow analysis
framework.

Depends on D124020

Depends on D124021

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D124023
2022-06-02 20:24:11 +00:00

100 lines
3.5 KiB
C++

//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
using namespace mlir;
// Two ranges compare equal when their bounds share a storage width and all
// four bounds (unsigned min/max, signed min/max) match. The width test comes
// first so that the APInt comparisons below only ever see operands of equal
// width.
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
  if (umin().getBitWidth() != other.umin().getBitWidth())
    return false;
  bool unsignedMatch = umin() == other.umin() && umax() == other.umax();
  bool signedMatch = smin() == other.smin() && smax() == other.smax();
  return unsignedMatch && signedMatch;
}
// Accessors for the stored bounds. Each returns a reference to the
// corresponding member, so no APInt copy is made.
const APInt &ConstantIntRanges::umin() const {
  return uminVal;
}
const APInt &ConstantIntRanges::umax() const {
  return umaxVal;
}
const APInt &ConstantIntRanges::smin() const {
  return sminVal;
}
const APInt &ConstantIntRanges::smax() const {
  return smaxVal;
}
// Returns the APInt width used to hold range bounds for values of `type`:
// IndexType's internal storage width for `index`, the declared width for
// integer types, and 0 for any other type.
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
  if (type.isIndex())
    return IndexType::kInternalStorageBitWidth;
  auto intTy = type.dyn_cast<IntegerType>();
  // Non-integer types have their bounds stored in width 0 `APInt`s.
  return intTy ? intTy.getWidth() : 0;
}
// Builds a range whose unsigned bounds and signed bounds are both
// [min, max].
// NOTE(review): nothing here checks that [min, max] is well-ordered under
// both the signed and the unsigned interpretation; callers presumably
// guarantee that. Otherwise use fromSigned/fromUnsigned.
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
  return ConstantIntRanges(min, max, min, max);
}
// Builds a range from signed bounds [smin, smax] and derives the unsigned
// bounds from them.
// - If both endpoints have the same sign bit, the unsigned bounds are the
//   endpoints ordered by unsigned comparison.
// - Otherwise the unsigned bounds are conservatively the full range
//   [0, 2^width - 1].
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
                                                const APInt &smax) {
  if (smin.isNonNegative() != smax.isNonNegative()) {
    unsigned int width = smin.getBitWidth();
    return {APInt::getMinValue(width), APInt::getMaxValue(width), smin, smax};
  }
  const APInt &lowUnsigned = smin.ult(smax) ? smin : smax;
  const APInt &highUnsigned = smin.ugt(smax) ? smin : smax;
  return {lowUnsigned, highUnsigned, smin, smax};
}
// Builds a range from unsigned bounds [umin, umax] and derives the signed
// bounds from them.
// - If both endpoints have the same sign bit, the signed bounds are the
//   endpoints ordered by signed comparison.
// - Otherwise the signed bounds are conservatively the full signed range for
//   the width.
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
                                                  const APInt &umax) {
  if (umin.isNonNegative() != umax.isNonNegative()) {
    unsigned int width = umin.getBitWidth();
    return {umin, umax, APInt::getSignedMinValue(width),
            APInt::getSignedMaxValue(width)};
  }
  const APInt &lowSigned = umin.slt(umax) ? umin : umax;
  const APInt &highSigned = umin.sgt(umax) ? umin : umax;
  return {umin, umax, lowSigned, highSigned};
}
// Returns a range covering both `*this` and `other`. Each of the four bounds
// is widened independently: the lower of the two minima and the higher of the
// two maxima, using unsigned comparison for the unsigned bounds and signed
// comparison for the signed bounds.
ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
  // "Not an integer" poisons everything and also cannot be fed to comparison
  // operators.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  const APInt &lowU = umin().ult(other.umin()) ? umin() : other.umin();
  const APInt &highU = umax().ugt(other.umax()) ? umax() : other.umax();
  const APInt &lowS = smin().slt(other.smin()) ? smin() : other.smin();
  const APInt &highS = smax().sgt(other.smax()) ? smax() : other.smax();
  return {lowU, highU, lowS, highS};
}
// Returns the single value the range admits, if it collapses to one. The
// unsigned bounds are tried first, then the signed bounds; otherwise the
// result is None.
Optional<APInt> ConstantIntRanges::getConstantValue() const {
  // Width 0 bounds are trivially equal and must not be reported as a
  // constant, so the width test guards each equality check.
  bool unsignedSingleton = umin().getBitWidth() != 0 && umin() == umax();
  if (unsignedSingleton)
    return umin();
  bool signedSingleton = smin().getBitWidth() != 0 && smin() == smax();
  if (signedSingleton)
    return smin();
  return None;
}
// Prints the range as
// "unsigned : [umin, umax] signed : [smin, smax]" and returns the stream so
// that insertions can be chained.
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
  os << "unsigned : [" << range.umin() << ", " << range.umax();
  os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
  return os;
}