mlir/Presburger: reinstate use of LogicalResult (#97415)
Follow up on a desire post-landingd0fee98(mlir/Presburger: strip dependency on MLIRSupport) to reinstate the use of LogicalResult in Presburger. Sincedb791b2(mlir/LogicalResult: move into llvm), LogicalResult is in LLVM, and fulfilling this desire is possible while still maintaining the goal of stripping the Presburger library of mlir dependencies.
This commit is contained in:
committed by
GitHub
parent
915ee0b823
commit
f819302a09
@@ -21,13 +21,17 @@
|
||||
#include "mlir/Analysis/Presburger/Utils.h"
|
||||
#include "llvm/ADT/DynamicAPInt.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
namespace presburger {
|
||||
using llvm::DynamicAPInt;
|
||||
using llvm::failure;
|
||||
using llvm::int64fromDynamicAPInt;
|
||||
using llvm::LogicalResult;
|
||||
using llvm::SmallVectorImpl;
|
||||
using llvm::success;
|
||||
|
||||
class IntegerRelation;
|
||||
class IntegerPolyhedron;
|
||||
@@ -478,7 +482,7 @@ public:
|
||||
/// equality detection; if successful, the constant is substituted for the
|
||||
/// variable everywhere in the constraint system and then removed from the
|
||||
/// system.
|
||||
bool constantFoldVar(unsigned pos);
|
||||
LogicalResult constantFoldVar(unsigned pos);
|
||||
|
||||
/// This method calls `constantFoldVar` for the specified range of variables,
|
||||
/// `num` variables starting at position `pos`.
|
||||
@@ -501,7 +505,7 @@ public:
|
||||
/// 3) this = {0 <= d0 <= 5, 1 <= d1 <= 9}
|
||||
/// other = {2 <= d0 <= 6, 5 <= d1 <= 15},
|
||||
/// output = {0 <= d0 <= 6, 1 <= d1 <= 15}
|
||||
bool unionBoundingBox(const IntegerRelation &other);
|
||||
LogicalResult unionBoundingBox(const IntegerRelation &other);
|
||||
|
||||
/// Returns the smallest known constant bound for the extent of the specified
|
||||
/// variable (pos^th), i.e., the smallest known constant that is greater
|
||||
@@ -774,8 +778,8 @@ protected:
|
||||
/// Eliminates a single variable at `position` from equality and inequality
|
||||
/// constraints. Returns `success` if the variable was eliminated, and
|
||||
/// `failure` otherwise.
|
||||
inline bool gaussianEliminateVar(unsigned position) {
|
||||
return gaussianEliminateVars(position, position + 1) == 1;
|
||||
inline LogicalResult gaussianEliminateVar(unsigned position) {
|
||||
return success(gaussianEliminateVars(position, position + 1) == 1);
|
||||
}
|
||||
|
||||
/// Removes local variables using equalities. Each equality is checked if it
|
||||
|
||||
@@ -445,7 +445,7 @@ protected:
|
||||
/// lexicopositivity of the basis transform. The row must have a non-positive
|
||||
/// sample value. If this is not possible, return failure. This occurs when
|
||||
/// the constraints have no solution or the sample value is zero.
|
||||
bool moveRowUnknownToColumn(unsigned row);
|
||||
LogicalResult moveRowUnknownToColumn(unsigned row);
|
||||
|
||||
/// Given a row that has a non-integer sample value, add an inequality to cut
|
||||
/// away this fractional sample value from the polytope without removing any
|
||||
@@ -459,7 +459,7 @@ protected:
|
||||
///
|
||||
/// Return failure if the tableau became empty, and success if it didn't.
|
||||
/// Failure status indicates that the polytope was integer empty.
|
||||
bool addCut(unsigned row);
|
||||
LogicalResult addCut(unsigned row);
|
||||
|
||||
/// Undo the addition of the last constraint. This is only called while
|
||||
/// rolling back.
|
||||
@@ -511,7 +511,7 @@ private:
|
||||
MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;
|
||||
|
||||
/// Make the tableau configuration consistent.
|
||||
bool restoreRationalConsistency();
|
||||
LogicalResult restoreRationalConsistency();
|
||||
|
||||
/// Return whether the specified row is violated;
|
||||
bool rowIsViolated(unsigned row) const;
|
||||
@@ -626,7 +626,7 @@ private:
|
||||
/// Return failure if the tableau became empty, indicating that the polytope
|
||||
/// is always integer empty in the current symbol domain.
|
||||
/// Return success otherwise.
|
||||
bool doNonBranchingPivots();
|
||||
LogicalResult doNonBranchingPivots();
|
||||
|
||||
/// Get a row that is always violated in the current domain, if one exists.
|
||||
std::optional<unsigned> maybeGetAlwaysViolatedRow();
|
||||
@@ -647,7 +647,7 @@ private:
|
||||
/// at the time of the call. (This function may modify the symbol domain, but
|
||||
/// failure statu indicates that the polytope was empty for all symbol values
|
||||
/// in the initial domain.)
|
||||
bool addSymbolicCut(unsigned row);
|
||||
LogicalResult addSymbolicCut(unsigned row);
|
||||
|
||||
/// Get the numerator of the symbolic sample of the specific row.
|
||||
/// This is an affine expression in the symbols with integer coefficients.
|
||||
@@ -820,7 +820,7 @@ private:
|
||||
///
|
||||
/// Returns success if the unknown was successfully restored to a non-negative
|
||||
/// sample value, failure otherwise.
|
||||
bool restoreRow(Unknown &u);
|
||||
LogicalResult restoreRow(Unknown &u);
|
||||
|
||||
/// Find a pivot to change the sample value of row in the specified
|
||||
/// direction while preserving tableau consistency, except that if the
|
||||
|
||||
@@ -1247,10 +1247,10 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
|
||||
if (!areVarsAligned(*this, otherCst)) {
|
||||
FlatLinearValueConstraints otherCopy(otherCst);
|
||||
mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
|
||||
return success(IntegerPolyhedron::unionBoundingBox(otherCopy));
|
||||
return IntegerPolyhedron::unionBoundingBox(otherCopy);
|
||||
}
|
||||
|
||||
return success(IntegerPolyhedron::unionBoundingBox(otherCst));
|
||||
return IntegerPolyhedron::unionBoundingBox(otherCst);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -1552,22 +1553,22 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool IntegerRelation::constantFoldVar(unsigned pos) {
|
||||
LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
|
||||
assert(pos < getNumVars() && "invalid position");
|
||||
int rowIdx;
|
||||
if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
// atEq(rowIdx, pos) is either -1 or 1.
|
||||
assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
|
||||
DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
|
||||
setAndEliminate(pos, constVal);
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
|
||||
for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
|
||||
if (!constantFoldVar(t))
|
||||
if (constantFoldVar(t).failed())
|
||||
t++;
|
||||
}
|
||||
}
|
||||
@@ -1944,9 +1945,9 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
|
||||
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
|
||||
if (atEq(r, pos) != 0) {
|
||||
// Use Gaussian elimination here (since we have an equality).
|
||||
bool ret = gaussianEliminateVar(pos);
|
||||
LogicalResult ret = gaussianEliminateVar(pos);
|
||||
(void)ret;
|
||||
assert(ret && "Gaussian elimination guaranteed to succeed");
|
||||
assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed");
|
||||
LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
|
||||
LLVM_DEBUG(dump());
|
||||
return;
|
||||
@@ -2173,7 +2174,8 @@ static void getCommonConstraints(const IntegerRelation &a,
|
||||
|
||||
// Computes the bounding box with respect to 'other' by finding the min of the
|
||||
// lower bounds and the max of the upper bounds along each of the dimensions.
|
||||
bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
LogicalResult
|
||||
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
|
||||
assert(getNumLocalVars() == 0 && "local ids not supported yet here");
|
||||
|
||||
@@ -2201,13 +2203,13 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
if (!extent.has_value())
|
||||
// TODO: symbolic extents when necessary.
|
||||
// TODO: handle union if a dimension is unbounded.
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
auto otherExtent = otherCst.getConstantBoundOnDimSize(
|
||||
d, &otherLb, &otherLbFloorDivisor, &otherUb);
|
||||
if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
|
||||
// TODO: symbolic extents when necessary.
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
|
||||
|
||||
@@ -2227,7 +2229,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
auto constLb = getConstantBound(BoundType::LB, d);
|
||||
auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
|
||||
if (!constLb.has_value() || !constOtherLb.has_value())
|
||||
return false;
|
||||
return failure();
|
||||
std::fill(minLb.begin(), minLb.end(), 0);
|
||||
minLb.back() = std::min(*constLb, *constOtherLb);
|
||||
}
|
||||
@@ -2243,7 +2245,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
auto constUb = getConstantBound(BoundType::UB, d);
|
||||
auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
|
||||
if (!constUb.has_value() || !constOtherUb.has_value())
|
||||
return false;
|
||||
return failure();
|
||||
std::fill(maxUb.begin(), maxUb.end(), 0);
|
||||
maxUb.back() = std::max(*constUb, *constOtherUb);
|
||||
}
|
||||
@@ -2281,7 +2283,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
|
||||
// union (since the above are just the union along dimensions); we shouldn't
|
||||
// be discarding any other constraints on the symbols.
|
||||
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
bool IntegerRelation::isColZero(unsigned pos) const {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
@@ -753,18 +754,18 @@ private:
|
||||
/// \___\|/ \_____/
|
||||
///
|
||||
///
|
||||
bool coalescePairCutCase(unsigned i, unsigned j);
|
||||
LogicalResult coalescePairCutCase(unsigned i, unsigned j);
|
||||
|
||||
/// Types the inequality `ineq` according to its `IneqType` for `simp` into
|
||||
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
|
||||
/// inequalities were encountered. Otherwise, returns failure.
|
||||
bool typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
|
||||
LogicalResult typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
|
||||
|
||||
/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
|
||||
/// -`eq` >= 0 according to their `IneqType` for `simp` into
|
||||
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
|
||||
/// inequalities were encountered. Otherwise, returns failure.
|
||||
bool typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
|
||||
LogicalResult typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
|
||||
|
||||
/// Replaces the element at position `i` with the last element and erases
|
||||
/// the last element for both `disjuncts` and `simplices`.
|
||||
@@ -775,7 +776,7 @@ private:
|
||||
/// successfully coalesced. The simplices in `simplices` need to be the ones
|
||||
/// constructed from `disjuncts`. At this point, there are no empty
|
||||
/// disjuncts in `disjuncts` left.
|
||||
bool coalescePair(unsigned i, unsigned j);
|
||||
LogicalResult coalescePair(unsigned i, unsigned j);
|
||||
};
|
||||
|
||||
/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
|
||||
@@ -818,7 +819,7 @@ PresburgerRelation SetCoalescer::coalesce() {
|
||||
cuttingIneqsB.clear();
|
||||
if (i == j)
|
||||
continue;
|
||||
if (coalescePair(i, j)) {
|
||||
if (coalescePair(i, j).succeeded()) {
|
||||
broken = true;
|
||||
break;
|
||||
}
|
||||
@@ -902,7 +903,7 @@ void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
|
||||
/// \___\|/ \_____/
|
||||
///
|
||||
///
|
||||
bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
|
||||
LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
|
||||
/// All inequalities of `b` need to be redundant. We already know that the
|
||||
/// redundant ones are, so only the cutting ones remain to be checked.
|
||||
Simplex &simp = simplices[i];
|
||||
@@ -910,7 +911,7 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
|
||||
if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<DynamicAPInt> curr) {
|
||||
return !isFacetContained(curr, simp);
|
||||
}))
|
||||
return false;
|
||||
return failure();
|
||||
IntegerRelation newSet(disjunct.getSpace());
|
||||
|
||||
for (ArrayRef<DynamicAPInt> curr : redundantIneqsA)
|
||||
@@ -920,23 +921,25 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
|
||||
newSet.addInequality(curr);
|
||||
|
||||
addCoalescedDisjunct(i, j, newSet);
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
bool SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp) {
|
||||
LogicalResult SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq,
|
||||
Simplex &simp) {
|
||||
Simplex::IneqType type = simp.findIneqType(ineq);
|
||||
if (type == Simplex::IneqType::Redundant)
|
||||
redundantIneqsB.push_back(ineq);
|
||||
else if (type == Simplex::IneqType::Cut)
|
||||
cuttingIneqsB.push_back(ineq);
|
||||
else
|
||||
return false;
|
||||
return true;
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
bool SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp) {
|
||||
if (!typeInequality(eq, simp))
|
||||
return false;
|
||||
LogicalResult SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq,
|
||||
Simplex &simp) {
|
||||
if (typeInequality(eq, simp).failed())
|
||||
return failure();
|
||||
negEqs.push_back(getNegatedCoeffs(eq));
|
||||
ArrayRef<DynamicAPInt> inv(negEqs.back());
|
||||
return typeInequality(inv, simp);
|
||||
@@ -951,7 +954,7 @@ void SetCoalescer::eraseDisjunct(unsigned i) {
|
||||
simplices.pop_back();
|
||||
}
|
||||
|
||||
bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
|
||||
LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {
|
||||
|
||||
IntegerRelation &a = disjuncts[i];
|
||||
IntegerRelation &b = disjuncts[j];
|
||||
@@ -959,7 +962,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
|
||||
/// skipped.
|
||||
/// TODO: implement local id support.
|
||||
if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0)
|
||||
return false;
|
||||
return failure();
|
||||
Simplex &simpA = simplices[i];
|
||||
Simplex &simpB = simplices[j];
|
||||
|
||||
@@ -969,34 +972,34 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
|
||||
// inequality is encountered during typing, the two IntegerRelations
|
||||
// cannot be coalesced.
|
||||
for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
|
||||
if (!typeInequality(a.getInequality(k), simpB))
|
||||
return false;
|
||||
if (typeInequality(a.getInequality(k), simpB).failed())
|
||||
return failure();
|
||||
|
||||
for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
|
||||
if (!typeEquality(a.getEquality(k), simpB))
|
||||
return false;
|
||||
if (typeEquality(a.getEquality(k), simpB).failed())
|
||||
return failure();
|
||||
|
||||
std::swap(redundantIneqsA, redundantIneqsB);
|
||||
std::swap(cuttingIneqsA, cuttingIneqsB);
|
||||
|
||||
for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
|
||||
if (!typeInequality(b.getInequality(k), simpA))
|
||||
return false;
|
||||
if (typeInequality(b.getInequality(k), simpA).failed())
|
||||
return failure();
|
||||
|
||||
for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
|
||||
if (!typeEquality(b.getEquality(k), simpA))
|
||||
return false;
|
||||
if (typeEquality(b.getEquality(k), simpA).failed())
|
||||
return failure();
|
||||
|
||||
// If there are no cutting inequalities of `a`, `b` is contained
|
||||
// within `a`.
|
||||
if (cuttingIneqsA.empty()) {
|
||||
eraseDisjunct(j);
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Try to apply the cut case
|
||||
if (coalescePairCutCase(i, j))
|
||||
return true;
|
||||
if (coalescePairCutCase(i, j).succeeded())
|
||||
return success();
|
||||
|
||||
// Swap the vectors to compare the pair (j,i) instead of (i,j).
|
||||
std::swap(redundantIneqsA, redundantIneqsB);
|
||||
@@ -1006,7 +1009,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
|
||||
// within `a`.
|
||||
if (cuttingIneqsA.empty()) {
|
||||
eraseDisjunct(i);
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Try to apply the cut case
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
@@ -229,7 +230,7 @@ Direction flippedDirection(Direction direction) {
|
||||
/// add these to the set of ignored columns and continue to the next row. If we
|
||||
/// run out of rows, then A*y is zero and we are done.
|
||||
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
|
||||
if (!restoreRationalConsistency()) {
|
||||
if (restoreRationalConsistency().failed()) {
|
||||
markEmpty();
|
||||
return OptimumKind::Empty;
|
||||
}
|
||||
@@ -274,7 +275,7 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
|
||||
///
|
||||
/// The constraint is violated when added (it would be useless otherwise)
|
||||
/// so we immediately try to move it to a column.
|
||||
bool LexSimplexBase::addCut(unsigned row) {
|
||||
LogicalResult LexSimplexBase::addCut(unsigned row) {
|
||||
DynamicAPInt d = tableau(row, 0);
|
||||
unsigned cutRow = addZeroRow(/*makeRestricted=*/true);
|
||||
tableau(cutRow, 0) = d;
|
||||
@@ -301,7 +302,7 @@ std::optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
|
||||
|
||||
MaybeOptimum<SmallVector<DynamicAPInt, 8>> LexSimplex::findIntegerLexMin() {
|
||||
// We first try to make the tableau consistent.
|
||||
if (!restoreRationalConsistency())
|
||||
if (restoreRationalConsistency().failed())
|
||||
return OptimumKind::Empty;
|
||||
|
||||
// Then, if the sample value is integral, we are done.
|
||||
@@ -316,9 +317,9 @@ MaybeOptimum<SmallVector<DynamicAPInt, 8>> LexSimplex::findIntegerLexMin() {
|
||||
//
|
||||
// Failure indicates that the tableau became empty, which occurs when the
|
||||
// polytope is integer empty.
|
||||
if (!addCut(*maybeRow))
|
||||
if (addCut(*maybeRow).failed())
|
||||
return OptimumKind::Empty;
|
||||
if (!restoreRationalConsistency())
|
||||
if (restoreRationalConsistency().failed())
|
||||
return OptimumKind::Empty;
|
||||
}
|
||||
|
||||
@@ -411,7 +412,7 @@ bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
|
||||
/// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0
|
||||
/// This constraint is violated when added so we immediately try to move it to a
|
||||
/// column.
|
||||
bool SymbolicLexSimplex::addSymbolicCut(unsigned row) {
|
||||
LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
|
||||
DynamicAPInt d = tableau(row, 0);
|
||||
if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) {
|
||||
// The coefficients of symbols in the symbol numerator are divisible
|
||||
@@ -523,11 +524,11 @@ std::optional<unsigned> SymbolicLexSimplex::maybeGetNonIntegralVarRow() {
|
||||
|
||||
/// The non-branching pivots are just the ones moving the rows
|
||||
/// that are always violated in the symbol domain.
|
||||
bool SymbolicLexSimplex::doNonBranchingPivots() {
|
||||
LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
|
||||
while (std::optional<unsigned> row = maybeGetAlwaysViolatedRow())
|
||||
if (!moveRowUnknownToColumn(*row))
|
||||
return false;
|
||||
return true;
|
||||
if (moveRowUnknownToColumn(*row).failed())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
|
||||
@@ -567,7 +568,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!doNonBranchingPivots()) {
|
||||
if (doNonBranchingPivots().failed()) {
|
||||
// Could not find pivots for violated constraints; return.
|
||||
--level;
|
||||
continue;
|
||||
@@ -627,7 +628,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
|
||||
// The tableau is rationally consistent for the current domain.
|
||||
// Now we look for non-integral sample values and add cuts for them.
|
||||
if (std::optional<unsigned> row = maybeGetNonIntegralVarRow()) {
|
||||
if (!addSymbolicCut(*row)) {
|
||||
if (addSymbolicCut(*row).failed()) {
|
||||
// No integral points; return.
|
||||
--level;
|
||||
continue;
|
||||
@@ -661,7 +662,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
|
||||
SmallVector<DynamicAPInt, 8> splitIneq =
|
||||
getComplementIneq(getSymbolicSampleIneq(u.pos));
|
||||
normalizeRange(splitIneq);
|
||||
if (!moveRowUnknownToColumn(u.pos)) {
|
||||
if (moveRowUnknownToColumn(u.pos).failed()) {
|
||||
// The unknown can't be made non-negative; return.
|
||||
--level;
|
||||
continue;
|
||||
@@ -699,13 +700,13 @@ std::optional<unsigned> LexSimplex::maybeGetViolatedRow() const {
|
||||
/// We simply look for violated rows and keep trying to move them to column
|
||||
/// orientation, which always succeeds unless the constraints have no solution
|
||||
/// in which case we just give up and return.
|
||||
bool LexSimplex::restoreRationalConsistency() {
|
||||
LogicalResult LexSimplex::restoreRationalConsistency() {
|
||||
if (empty)
|
||||
return false;
|
||||
return failure();
|
||||
while (std::optional<unsigned> maybeViolatedRow = maybeGetViolatedRow())
|
||||
if (!moveRowUnknownToColumn(*maybeViolatedRow))
|
||||
return false;
|
||||
return true;
|
||||
if (moveRowUnknownToColumn(*maybeViolatedRow).failed())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Move the row unknown to column orientation while preserving lexicopositivity
|
||||
@@ -770,7 +771,7 @@ bool LexSimplex::restoreRationalConsistency() {
|
||||
// which is in contradiction to the fact that B.col(j) / B(i,j) must be
|
||||
// lexicographically smaller than B.col(k) / B(i,k), since it lexicographically
|
||||
// minimizes the change in sample value.
|
||||
bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
|
||||
LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
|
||||
std::optional<unsigned> maybeColumn;
|
||||
for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) {
|
||||
if (tableau(row, col) <= 0)
|
||||
@@ -780,10 +781,10 @@ bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
|
||||
}
|
||||
|
||||
if (!maybeColumn)
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
pivot(row, *maybeColumn);
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
|
||||
@@ -986,7 +987,7 @@ void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) {
|
||||
/// Perform pivots until the unknown has a non-negative sample value or until
|
||||
/// no more upward pivots can be performed. Return success if we were able to
|
||||
/// bring the row to a non-negative sample value, and failure otherwise.
|
||||
bool Simplex::restoreRow(Unknown &u) {
|
||||
LogicalResult Simplex::restoreRow(Unknown &u) {
|
||||
assert(u.orientation == Orientation::Row &&
|
||||
"unknown should be in row position");
|
||||
|
||||
@@ -997,9 +998,9 @@ bool Simplex::restoreRow(Unknown &u) {
|
||||
|
||||
pivot(*maybePivot);
|
||||
if (u.orientation == Orientation::Column)
|
||||
return true; // the unknown is unbounded above.
|
||||
return success(); // the unknown is unbounded above.
|
||||
}
|
||||
return tableau(u.pos, 1) >= 0;
|
||||
return success(tableau(u.pos, 1) >= 0);
|
||||
}
|
||||
|
||||
/// Find a row that can be used to pivot the column in the specified direction.
|
||||
@@ -1105,8 +1106,8 @@ void SimplexBase::markEmpty() {
|
||||
/// empty and we mark it as such.
|
||||
void Simplex::addInequality(ArrayRef<DynamicAPInt> coeffs) {
|
||||
unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true);
|
||||
bool result = restoreRow(con[conIndex]);
|
||||
if (!result)
|
||||
LogicalResult result = restoreRow(con[conIndex]);
|
||||
if (result.failed())
|
||||
markEmpty();
|
||||
}
|
||||
|
||||
@@ -1384,7 +1385,7 @@ MaybeOptimum<Fraction> Simplex::computeOptimum(Direction direction,
|
||||
MaybeOptimum<Fraction> optimum = computeRowOptimum(direction, row);
|
||||
if (u.restricted && direction == Direction::Down &&
|
||||
(optimum.isUnbounded() || *optimum < Fraction(0, 1))) {
|
||||
if (!restoreRow(u))
|
||||
if (restoreRow(u).failed())
|
||||
llvm_unreachable("Could not restore row!");
|
||||
}
|
||||
return optimum;
|
||||
@@ -1453,7 +1454,7 @@ void Simplex::detectRedundant(unsigned offset, unsigned count) {
|
||||
if (minimum.isUnbounded() || *minimum < Fraction(0, 1)) {
|
||||
// Constraint is unbounded below or can attain negative sample values and
|
||||
// hence is not redundant.
|
||||
if (!restoreRow(u))
|
||||
if (restoreRow(u).failed())
|
||||
llvm_unreachable("Could not restore non-redundant row!");
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
@@ -95,10 +96,10 @@ static void normalizeDivisionByGCD(MutableArrayRef<DynamicAPInt> dividend,
|
||||
/// If successful, `expr` is set to dividend of the division and `divisor` is
|
||||
/// set to the denominator of the division, which will be positive.
|
||||
/// The final division expression is normalized by GCD.
|
||||
static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
unsigned ubIneq, unsigned lbIneq,
|
||||
MutableArrayRef<DynamicAPInt> expr,
|
||||
DynamicAPInt &divisor) {
|
||||
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
unsigned ubIneq, unsigned lbIneq,
|
||||
MutableArrayRef<DynamicAPInt> expr,
|
||||
DynamicAPInt &divisor) {
|
||||
|
||||
assert(pos <= cst.getNumVars() && "Invalid variable position");
|
||||
assert(ubIneq <= cst.getNumInequalities() &&
|
||||
@@ -120,7 +121,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
break;
|
||||
|
||||
if (i < e)
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
// Then, check if the constant term is of the proper form.
|
||||
// Due to the form of the upper/lower bound inequalities, the sum of their
|
||||
@@ -132,7 +133,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
// Check if `c` satisfies the condition `0 <= c <= divisor - 1`.
|
||||
// This also implictly checks that `divisor` is positive.
|
||||
if (!(0 <= c && c <= divisor - 1)) // NOLINT
|
||||
return false;
|
||||
return failure();
|
||||
|
||||
// The inequality pair can be used to extract the division.
|
||||
// Set `expr` to the dividend of the division except the constant term, which
|
||||
@@ -147,7 +148,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
|
||||
normalizeDivisionByGCD(expr, divisor);
|
||||
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Check if the pos^th variable can be represented as a division using
|
||||
@@ -161,9 +162,10 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
/// If successful, `expr` is set to dividend of the division and `divisor` is
|
||||
/// set to the denominator of the division. The final division expression is
|
||||
/// normalized by GCD.
|
||||
static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
|
||||
MutableArrayRef<DynamicAPInt> expr,
|
||||
DynamicAPInt &divisor) {
|
||||
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
|
||||
unsigned eqInd,
|
||||
MutableArrayRef<DynamicAPInt> expr,
|
||||
DynamicAPInt &divisor) {
|
||||
|
||||
assert(pos <= cst.getNumVars() && "Invalid variable position");
|
||||
assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
|
||||
@@ -174,7 +176,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
|
||||
// Equality must involve the pos-th variable and hence `tempDiv` != 0.
|
||||
DynamicAPInt tempDiv = cst.atEq(eqInd, pos);
|
||||
if (tempDiv == 0)
|
||||
return false;
|
||||
return failure();
|
||||
int signDiv = tempDiv < 0 ? -1 : 1;
|
||||
|
||||
// The divisor is always a positive integer.
|
||||
@@ -187,7 +189,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
|
||||
expr.back() = -signDiv * cst.atEq(eqInd, cst.getNumCols() - 1);
|
||||
normalizeDivisionByGCD(expr, divisor);
|
||||
|
||||
return true;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Returns `false` if the constraints depends on a variable for which an
|
||||
@@ -238,7 +240,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
|
||||
for (unsigned ubPos : ubIndices) {
|
||||
for (unsigned lbPos : lbIndices) {
|
||||
// Attempt to get divison representation from ubPos, lbPos.
|
||||
if (!getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor))
|
||||
if (getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor).failed())
|
||||
continue;
|
||||
|
||||
if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
|
||||
@@ -251,7 +253,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
|
||||
}
|
||||
for (unsigned eqPos : eqIndices) {
|
||||
// Attempt to get divison representation from eqPos.
|
||||
if (!getDivRepr(cst, pos, eqPos, dividend, divisor))
|
||||
if (getDivRepr(cst, pos, eqPos, dividend, divisor).failed())
|
||||
continue;
|
||||
|
||||
if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
|
||||
|
||||
Reference in New Issue
Block a user