//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the dataflow analysis class for integer range inference // which is used in transformations over the `arith` dialect such as // branch elimination or signed->unsigned rewriting // //===----------------------------------------------------------------------===// #include "mlir/Analysis/IntRangeAnalysis.h" #include "mlir/Analysis/DataFlowAnalysis.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "int-range-analysis" using namespace mlir; namespace { /// A wrapper around ConstantIntRanges that provides the lattice functions /// expected by dataflow analysis. struct IntRangeLattice { IntRangeLattice(const ConstantIntRanges &value) : value(value){}; IntRangeLattice(ConstantIntRanges &&value) : value(value){}; bool operator==(const IntRangeLattice &other) const { return value == other.value; } /// wrapper around rangeUnion() static IntRangeLattice join(const IntRangeLattice &a, const IntRangeLattice &b) { return a.value.rangeUnion(b.value); } /// Creates a range with bitwidth 0 to represent that we don't know if the /// value being marked overdefined is even an integer. static IntRangeLattice getPessimisticValueState(MLIRContext *context) { APInt noIntValue = APInt::getZeroWidth(); return ConstantIntRanges::range(noIntValue, noIntValue); } /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) /// range that is used to mark the value v as unable to be analyzed further, /// where t is the type of v. static IntRangeLattice getPessimisticValueState(Value v) { unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType()); APInt umin = APInt::getMinValue(width); APInt umax = APInt::getMaxValue(width); APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; return ConstantIntRanges{umin, umax, smin, smax}; } ConstantIntRanges value; }; } // end anonymous namespace namespace mlir { namespace detail { class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis { using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; public: /// Define bounds on the results or block arguments of the operation /// based on the bounds on the arguments given in `operands` ChangeResult visitOperation(Operation *op, ArrayRef *> operands) final; /// Skip regions of branch ops when we can statically infer constant /// values for operands to the branch op and said op tells us it's safe to do /// so. LogicalResult getSuccessorsForOperands(BranchOpInterface branch, ArrayRef *> operands, SmallVectorImpl &successors) final; /// Skip regions of branch or loop ops when we can statically infer constant /// values for operands to the branch op and said op tells us it's safe to do /// so. void getSuccessorsForOperands(RegionBranchOpInterface branch, Optional sourceIndex, ArrayRef *> operands, SmallVectorImpl &successors) final; /// Call the InferIntRangeInterface implementation for region-using ops /// that implement it, and infer the bounds of loop induction variables /// for ops that implement LoopLikeOPInterface. ChangeResult visitNonControlFlowArguments( Operation *op, const RegionSuccessor ®ion, ArrayRef *> operands) final; }; } // end namespace detail } // end namespace mlir /// Given the results of getConstant{Lower,Upper}Bound() /// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for /// that result if possible. static APInt getLoopBoundFromFold(Optional loopBound, Type boundType, detail::IntRangeAnalysisImpl &analysis, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); if (loopBound.hasValue()) { if (loopBound->is()) { if (auto bound = loopBound->get().dyn_cast_or_null()) return bound.getValue(); } else if (loopBound->is()) { LatticeElement *lattice = analysis.lookupLatticeElement(loopBound->get()); if (lattice != nullptr) return getUpper ? lattice->getValue().value.smax() : lattice->getValue().value.smin(); } } return getUpper ? APInt::getSignedMaxValue(width) : APInt::getSignedMinValue(width); } ChangeResult detail::IntRangeAnalysisImpl::visitOperation( Operation *op, ArrayRef *> operands) { ChangeResult result = ChangeResult::NoChange; // Ignore non-integer outputs - return early if the op has no scalar // integer results bool hasIntegerResult = false; for (Value v : op->getResults()) { if (v.getType().isIntOrIndex()) hasIntegerResult = true; else result |= markAllPessimisticFixpoint(v); } if (!hasIntegerResult) return result; if (auto inferrable = dyn_cast(op)) { LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); LLVM_DEBUG(inferrable->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); SmallVector argRanges( llvm::map_range(operands, [](LatticeElement *val) { return val->getValue().value; })); auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); LatticeElement &lattice = getLatticeElement(v); Optional oldRange; if (!lattice.isUninitialized()) oldRange = lattice.getValue(); result |= lattice.join(IntRangeLattice(attrs)); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedResult && oldRange.hasValue() && !(lattice.getValue() == *oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); result |= lattice.markPessimisticFixpoint(); } }; inferrable.inferResultRanges(argRanges, joinCallback); for (Value opResult : op->getResults()) { LatticeElement &lattice = getLatticeElement(opResult); // setResultRange() not called, make pessimistic. if (lattice.isUninitialized()) result |= lattice.markPessimisticFixpoint(); } } else if (op->getNumRegions() == 0) { // No regions + no result inference method -> unbounded results (ex. memory // ops) result |= markAllPessimisticFixpoint(op->getResults()); } return result; } LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands( BranchOpInterface branch, ArrayRef *> operands, SmallVectorImpl &successors) { auto toConstantAttr = [&branch](auto enumPair) -> Attribute { Optional maybeConstValue = enumPair.value()->getValue().value.getConstantValue(); if (maybeConstValue) { return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), *maybeConstValue); } return {}; }; SmallVector inferredConsts( llvm::map_range(llvm::enumerate(operands), toConstantAttr)); if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) { successors.push_back(singleSucc); return success(); } return failure(); } void detail::IntRangeAnalysisImpl::getSuccessorsForOperands( RegionBranchOpInterface branch, Optional sourceIndex, ArrayRef *> operands, SmallVectorImpl &successors) { auto toConstantAttr = [&branch](auto enumPair) -> Attribute { Optional maybeConstValue = enumPair.value()->getValue().value.getConstantValue(); if (maybeConstValue) { return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), *maybeConstValue); } return {}; }; SmallVector inferredConsts( llvm::map_range(llvm::enumerate(operands), toConstantAttr)); branch.getSuccessorRegions(sourceIndex, inferredConsts, successors); } ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments( Operation *op, const RegionSuccessor ®ion, ArrayRef *> operands) { if (auto inferrable = dyn_cast(op)) { LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); LLVM_DEBUG(inferrable->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); SmallVector argRanges( llvm::map_range(operands, [](LatticeElement *val) { return val->getValue().value; })); ChangeResult result = ChangeResult::NoChange; auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); LatticeElement &lattice = getLatticeElement(v); Optional oldRange; if (!lattice.isUninitialized()) oldRange = lattice.getValue(); result |= lattice.join(IntRangeLattice(attrs)); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedValue && oldRange.hasValue() && !(lattice.getValue() == *oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); result |= lattice.markPessimisticFixpoint(); } }; inferrable.inferResultRanges(argRanges, joinCallback); for (Value regionArg : region.getSuccessor()->getArguments()) { LatticeElement &lattice = getLatticeElement(regionArg); // setResultRange() not called, make pessimistic. if (lattice.isUninitialized()) result |= lattice.markPessimisticFixpoint(); } return result; } // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast(op)) { Optional iv = loop.getSingleInductionVar(); if (!iv.hasValue()) { return ForwardDataFlowAnalysis< IntRangeLattice>::visitNonControlFlowArguments(op, region, operands); } Optional lowerBound = loop.getSingleLowerBound(); Optional upperBound = loop.getSingleUpperBound(); Optional step = loop.getSingleStep(); APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this, /*getUpper=*/false); APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this, /*getUpper=*/true); // Assume positivity for uniscoverable steps by way of getUpper = true. APInt stepVal = getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true); if (stepVal.isNegative()) { std::swap(min, max); } else { // Correct the upper bound by subtracting 1 so that it becomes a <= bound, // because loops do not generally include their upper bound. max -= 1; } LatticeElement &ivEntry = getLatticeElement(*iv); return ivEntry.join(ConstantIntRanges::fromSigned(min, max)); } return ForwardDataFlowAnalysis::visitNonControlFlowArguments( op, region, operands); } IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) { impl = std::make_unique( topLevelOperation->getContext()); impl->run(topLevelOperation); } IntRangeAnalysis::~IntRangeAnalysis() = default; IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default; Optional IntRangeAnalysis::getResult(Value v) { LatticeElement *result = impl->lookupLatticeElement(v); if (result == nullptr || result->isUninitialized()) return llvm::None; return result->getValue().value; }