//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Simplex.h" using namespace mlir; using namespace presburger; void MultiAffineFunction::assertIsConsistent() const { assert(space.getNumVars() - space.getNumRangeVars() + 1 == output.getNumColumns() && "Inconsistent number of output columns"); assert(space.getNumDomainVars() + space.getNumSymbolVars() == divs.getNumNonDivs() && "Inconsistent number of non-division variables in divs"); assert(space.getNumRangeVars() == output.getNumRows() && "Inconsistent number of output rows"); assert(space.getNumLocalVars() == divs.getNumDivs() && "Inconsistent number of divisions."); assert(divs.hasAllReprs() && "All divisions should have a representation"); } // Return the result of subtracting the two given vectors pointwise. // The vectors must be of the same size. // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. static SmallVector subtractExprs(ArrayRef vecA, ArrayRef vecB) { assert(vecA.size() == vecB.size() && "Cannot subtract vectors of differing lengths!"); SmallVector result; result.reserve(vecA.size()); for (unsigned i = 0, e = vecA.size(); i < e; ++i) result.push_back(vecA[i] - vecB[i]); return result; } PresburgerSet PWMAFunction::getDomain() const { PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace()); for (const Piece &piece : pieces) domain.unionInPlace(piece.domain); return domain; } void MultiAffineFunction::print(raw_ostream &os) const { space.print(os); os << "Division Representation:\n"; divs.print(os); os << "Output:\n"; output.print(os); } SmallVector MultiAffineFunction::valueAt(ArrayRef point) const { assert(point.size() == getNumDomainVars() + getNumSymbolVars() && "Point has incorrect dimensionality!"); SmallVector pointHomogenous{llvm::to_vector(point)}; // Get the division values at this point. SmallVector, 8> divValues = divs.divValuesAt(point); // The given point didn't include the values of the divs which the output is a // function of; we have computed one possible set of values and use them here. pointHomogenous.reserve(pointHomogenous.size() + divValues.size()); for (const Optional &divVal : divValues) pointHomogenous.push_back(*divVal); // The matrix `output` has an affine expression in the ith row, corresponding // to the expression for the ith value in the output vector. The last column // of the matrix contains the constant term. Let v be the input point with // a 1 appended at the end. We can see that output * v gives the desired // output vector. pointHomogenous.emplace_back(1); SmallVector result = output.postMultiplyWithColumn(pointHomogenous); assert(result.size() == getNumOutputs()); return result; } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { assert(space.isCompatible(other.space) && "Spaces should be compatible for equality check."); return getAsRelation().isEqual(other.getAsRelation()); } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, const IntegerPolyhedron &domain) const { assert(space.isCompatible(other.space) && "Spaces should be compatible for equality check."); IntegerRelation restrictedThis = getAsRelation(); restrictedThis.intersectDomain(domain); IntegerRelation restrictedOther = other.getAsRelation(); restrictedOther.intersectDomain(domain); return restrictedThis.isEqual(restrictedOther); } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, const PresburgerSet &domain) const { assert(space.isCompatible(other.space) && "Spaces should be compatible for equality check."); return llvm::all_of(domain.getAllDisjuncts(), [&](const IntegerRelation &disjunct) { return isEqual(other, IntegerPolyhedron(disjunct)); }); } void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { assert(end <= getNumOutputs() && "Invalid range"); if (start >= end) return; space.removeVarRange(VarKind::Range, start, end); output.removeRows(start, end - start); } void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { assert(space.isCompatible(other.space) && "Functions should be compatible"); unsigned nDivs = getNumDivs(); unsigned divOffset = divs.getDivOffset(); other.divs.insertDiv(0, nDivs); SmallVector div(other.divs.getNumVars() + 1); for (unsigned i = 0; i < nDivs; ++i) { // Zero fill. std::fill(div.begin(), div.end(), 0); // Fill div with dividend from `divs`. Do not fill the constant. std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1, div.begin()); // Fill constant. div.back() = divs.getDividend(i).back(); other.divs.setDiv(i, div, divs.getDenom(i)); } other.space.insertVar(VarKind::Local, 0, nDivs); other.output.insertColumns(divOffset, nDivs); auto merge = [&](unsigned i, unsigned j) { // We only merge from local at pos j to local at pos i, where j > i. if (i >= j) return false; // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we // do not want to merge duplicates in `this`, we ignore this call. if (j < nDivs) return false; // Merge things in space and output. other.space.removeVarRange(VarKind::Local, j, j + 1); other.output.addToColumn(divOffset + i, divOffset + j, 1); other.output.removeColumn(divOffset + j); return true; }; other.divs.removeDuplicateDivs(merge); unsigned newDivs = other.divs.getNumDivs() - nDivs; space.insertVar(VarKind::Local, nDivs, newDivs); output.insertColumns(divOffset + nDivs, newDivs); divs = other.divs; // Check consistency. assertIsConsistent(); other.assertIsConsistent(); } /// Two PWMAFunctions are equal if they have the same dimensionalities, /// the same domain, and take the same value at every point in the domain. bool PWMAFunction::isEqual(const PWMAFunction &other) const { if (!space.isCompatible(other.space)) return false; if (!this->getDomain().isEqual(other.getDomain())) return false; // Check if, whenever the domains of a piece of `this` and a piece of `other` // overlap, they take the same output value. If `this` and `other` have the // same domain (checked above), then this check passes iff the two functions // have the same output at every point in the domain. return llvm::all_of(this->pieces, [&other](const Piece &pieceA) { return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) { PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain); return pieceA.output.isEqual(pieceB.output, commonDomain); }); }); } void PWMAFunction::addPiece(const Piece &piece) { assert(piece.isConsistent() && "Piece should be consistent"); pieces.push_back(piece); } void PWMAFunction::print(raw_ostream &os) const { space.print(os); os << getNumPieces() << " pieces:\n"; for (const Piece &piece : pieces) { os << "Domain of piece:\n"; piece.domain.print(os); os << "Output of piece\n"; piece.output.print(os); } } void PWMAFunction::dump() const { print(llvm::errs()); } PWMAFunction PWMAFunction::unionFunction( const PWMAFunction &func, llvm::function_ref tiebreak) const { assert(getNumOutputs() == func.getNumOutputs() && "Ranges of functions should be same."); assert(getSpace().isCompatible(func.getSpace()) && "Space is not compatible."); // The algorithm used here is as follows: // - Add the output of pieceB for the part of the domain where both pieceA and // pieceB are defined, and `tiebreak` chooses the output of pieceB. // - Add the output of pieceA, where pieceB is not defined or `tiebreak` // chooses // pieceA over pieceB. // - Add the output of pieceB, where pieceA is not defined. // Add parts of the common domain where pieceB's output is used. Also // add all the parts where pieceA's output is used, both common and // non-common. PWMAFunction result(getSpace()); for (const Piece &pieceA : pieces) { PresburgerSet dom(pieceA.domain); for (const Piece &pieceB : func.pieces) { PresburgerSet better = tiebreak(pieceB, pieceA); // Add the output of pieceB, where it is better than output of pieceA. // The disjuncts in "better" will be disjoint as tiebreak should gurantee // that. result.addPiece({better, pieceB.output}); dom = dom.subtract(better); } // Add output of pieceA, where it is better than pieceB, or pieceB is not // defined. // // `dom` here is guranteed to be disjoint from already added pieces // because because the pieces added before are either: // - Subsets of the domain of other MAFs in `this`, which are guranteed // to be disjoint from `dom`, or // - They are one of the pieces added for `pieceB`, and we have been // subtracting all such pieces from `dom`, so `dom` is disjoint from those // pieces as well. result.addPiece({dom, pieceA.output}); } // Add parts of pieceB which are not shared with pieceA. PresburgerSet dom = getDomain(); for (const Piece &pieceB : func.pieces) result.addPiece({pieceB.domain.subtract(dom), pieceB.output}); return result; } /// A tiebreak function which breaks ties by comparing the outputs /// lexicographically. If `lexMin` is true, then the ties are broken by /// taking the lexicographically smaller output and otherwise, by taking the /// lexicographically larger output. template static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, const PWMAFunction::Piece &pieceB) { // TODO: Support local variables here. assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) && "Pieces should be compatible"); assert(pieceA.domain.getSpace().getNumLocalVars() == 0 && "Local variables are not supported yet."); PresburgerSpace compatibleSpace = pieceA.domain.getSpace(); const PresburgerSpace &space = pieceA.domain.getSpace(); // We first create the set `result`, corresponding to the set where output // of pieceA is lexicographically larger/smaller than pieceB. This is done by // creating a PresburgerSet with the following constraints: // // (outA[0] > outB[0]) U // (outA[0] = outB[0], outA[1] > outA[1]) U // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U // ... // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) // // where `n` is the number of outputs. // If `lexMin` is set, the complement inequality is used: // // (outA[0] < outB[0]) U // (outA[0] = outB[0], outA[1] < outA[1]) U // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U // ... // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace); IntegerPolyhedron levelSet( /*numReservedInequalities=*/1, /*numReservedEqualities=*/pieceA.output.getNumOutputs(), /*numReservedCols=*/space.getNumVars() + 1, space); for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) { // Create the expression `outA - outB` for this level. SmallVector subExpr = subtractExprs( pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level)); if (lexMin) { // For lexMin, we add an upper bound of -1: // outA - outB <= -1 // outA <= outB - 1 // outA < outB levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1)); } else { // For lexMax, we add a lower bound of 1: // outA - outB >= 1 // outA > outB + 1 // outA > outB levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1)); } // Union the set with the result. result.unionInPlace(levelSet); // There is only 1 inequality in `levelSet`, so the index is always 0. levelSet.removeInequality(0); // Add equality `outA - outB == 0` for this level for next iteration. levelSet.addEquality(subExpr); } // We then intersect `result` with the domain of pieceA and pieceB, to only // tiebreak on the domain where both are defined. result = result.intersect(pieceA.domain).intersect(pieceB.domain); return result; } PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { return unionFunction(func, tiebreakLex); } PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { return unionFunction(func, tiebreakLex); } void MultiAffineFunction::subtract(const MultiAffineFunction &other) { assert(space.isCompatible(other.space) && "Spaces should be compatible for subtraction."); MultiAffineFunction copyOther = other; mergeDivs(copyOther); for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) output.addToRow(i, copyOther.getOutputExpr(i), MPInt(-1)); // Check consistency. assertIsConsistent(); } /// Adds division constraints corresponding to local variables, given a /// relation and division representations of the local variables in the /// relation. static void addDivisionConstraints(IntegerRelation &rel, const DivisionRepr &divs) { assert(divs.hasAllReprs() && "All divisions in divs should have a representation"); assert(rel.getNumVars() == divs.getNumVars() && "Relation and divs should have the same number of vars"); assert(rel.getNumLocalVars() == divs.getNumDivs() && "Relation and divs should have the same number of local vars"); for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i), divs.getDivOffset() + i)); rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i), divs.getDivOffset() + i)); } } IntegerRelation MultiAffineFunction::getAsRelation() const { // Create a relation corressponding to the input space plus the divisions // used in outputs. IntegerRelation result(PresburgerSpace::getRelationSpace( space.getNumDomainVars(), 0, space.getNumSymbolVars(), space.getNumLocalVars())); // Add division constraints corresponding to divisions used in outputs. addDivisionConstraints(result, divs); // The outputs are represented as range variables in the relation. We add // range variables for the outputs. result.insertVar(VarKind::Range, 0, getNumOutputs()); // Add equalities such that the i^th range variable is equal to the i^th // output expression. SmallVector eq(result.getNumCols()); for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { // TODO: Add functions to get VarKind offsets in output in MAF and use them // here. // The output expression does not contain range variables, while the // equality does. So, we need to copy all variables and mark all range // variables as 0 in the equality. ArrayRef expr = getOutputExpr(i); // Copy domain variables in `expr` to domain variables in `eq`. std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin()); // Fill the range variables in `eq` as zero. std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range), eq.begin() + result.getVarKindEnd(VarKind::Range), 0); // Copy remaining variables in `expr` to the remaining variables in `eq`. std::copy(expr.begin() + getNumDomainVars(), expr.end(), eq.begin() + result.getVarKindEnd(VarKind::Range)); // Set the i^th range var to -1 in `eq` to equate the output expression to // this range var. eq[result.getVarKindOffset(VarKind::Range) + i] = -1; // Add the equality `rangeVar_i = output[i]`. result.addEquality(eq); } return result; } void PWMAFunction::removeOutputs(unsigned start, unsigned end) { space.removeVarRange(VarKind::Range, start, end); for (Piece &piece : pieces) piece.output.removeOutputs(start, end); } Optional> PWMAFunction::valueAt(ArrayRef point) const { assert(point.size() == getNumDomainVars() + getNumSymbolVars()); for (const Piece &piece : pieces) if (piece.domain.containsPoint(point)) return piece.output.valueAt(point); return std::nullopt; }