[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)
This commit makes the following changes:
- Expose `map` and `mapOperands` in
`ValueBoundsConstraintSet::Variable`, so that the class can be used by
subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot
access those members.
- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar
to `ValueBoundsConstraintSet::compare` except that it returns false when
the inverse comparison holds, and `llvm::failure()` if neither the
relation nor its inverse relation could be proven.
- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and
`simplifyAffineMinMaxOps` to simplify those operations using
`ValueBoundsConstraintSet`.
- Adds the `SimplifyMinMaxAffineOpsOp` transform op that uses
`simplifyAffineMinMaxOps`.
- Add the `test.value_with_bounds` op to test unknown values with a min
max range using `ValueBoundsOpInterface`.
- Adds tests verifying the transform.
Example:
```mlir
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
%1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
%2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
%3 = affine.min #map1()[%0, %1]
%4 = affine.max #map1()[%1, %2]
return %3, %4 : index, index
}
```
---------
Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
This commit is contained in:
@@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp
|
||||
}];
|
||||
}
|
||||
|
||||
def SimplifyMinMaxAffineOpsOp :
|
||||
Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
|
||||
]> {
|
||||
let description = [{
|
||||
Simplify the targeted `affine.min` / `affine.max` ops using the
|
||||
`mlir::affine::simplifyAffineMinMaxOps` transform.
|
||||
|
||||
Example:
|
||||
```
|
||||
%0 = transform.structured.match ops{["affine.max"]} in %arg1
|
||||
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
|
||||
```
|
||||
|
||||
#### Return modes
|
||||
|
||||
This transform consumes the target handle and does not produce any results.
|
||||
This transforms definitely fails if any of the targeted operations is not an
|
||||
`affine.min` or `affine.max` operation, or if the canonicalization patterns
|
||||
failed to converge.
|
||||
This transform silently fails if none of the operations were simplified.
|
||||
Otherwise, it succeeds.
|
||||
}];
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs);
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:` type($target)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // Affine_TRANSFORM_OPS
|
||||
|
||||
@@ -34,6 +34,8 @@ namespace affine {
|
||||
class AffineApplyOp;
|
||||
class AffineDelinearizeIndexOp;
|
||||
class AffineLinearizeIndexOp;
|
||||
class AffineMaxOp;
|
||||
class AffineMinOp;
|
||||
|
||||
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
|
||||
/// operations.
|
||||
@@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound(
|
||||
OpBuilder &b, Location loc, AffineMap boundMap,
|
||||
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
|
||||
|
||||
/// This transform tries to simplify the affine min operation `op`, by finding a
|
||||
/// common lower bound for a set of expressions in the affine map results. It
|
||||
/// returns whether the transform updated `op`'s affine map.
|
||||
///
|
||||
/// In concrete terms, given an operation like:
|
||||
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
|
||||
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
|
||||
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
|
||||
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
|
||||
|
||||
/// This transform tries to simplify the affine max operation `op`, by finding a
|
||||
/// common upper bound for a set of expressions in the affine map results. It
|
||||
/// returns whether the transform updated `op`'s affine map.
|
||||
///
|
||||
/// In concrete terms, given an operation like:
|
||||
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
|
||||
/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
|
||||
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
|
||||
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
|
||||
|
||||
/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
|
||||
/// all the `affine.min` or `affine.max` operations in `ops`. After
|
||||
/// simplification, it invokes the `affine.min/max` canonicalization patterns on
|
||||
/// `ops`.
|
||||
///
|
||||
/// This transform returns failure if the greedy pattern rewriter failed to
|
||||
/// converge during canonicalization, otherwise it returns success. If provided,
|
||||
/// `modified` is set to `true` if the IR was modified in any way.
|
||||
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
|
||||
ArrayRef<Operation *> ops,
|
||||
bool *modified = nullptr);
|
||||
} // namespace affine
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -135,10 +135,17 @@ public:
|
||||
|
||||
/// Construct a variable for a map and its operands.
|
||||
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
|
||||
Variable(AffineMap map, ArrayRef<Value> mapOperands);
|
||||
Variable(AffineMap map, ValueRange mapOperands);
|
||||
|
||||
MLIRContext *getContext() const { return map.getContext(); }
|
||||
|
||||
/// Returns the affine map.
|
||||
AffineMap getMap() const { return map; }
|
||||
|
||||
/// Returns the map operands.
|
||||
ValueDimList &getOperands() { return mapOperands; }
|
||||
const ValueDimList &getOperands() const { return mapOperands; }
|
||||
|
||||
private:
|
||||
friend class ValueBoundsConstraintSet;
|
||||
AffineMap map;
|
||||
@@ -254,6 +261,12 @@ public:
|
||||
/// prove the relation or until it ran out of IR.
|
||||
static bool compare(const Variable &lhs, ComparisonOperator cmp,
|
||||
const Variable &rhs);
|
||||
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
|
||||
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
|
||||
/// relation nor its inverse relation could be proven.
|
||||
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
|
||||
ComparisonOperator cmp,
|
||||
const Variable &rhs);
|
||||
|
||||
/// Compute whether the given variables are equal. Return "failure" if
|
||||
/// equality could not be determined.
|
||||
@@ -327,6 +340,16 @@ protected:
|
||||
/// constraints.
|
||||
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
|
||||
|
||||
/// Return "true" if, based on the current state of the constraint system,
|
||||
/// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
|
||||
/// can be proven. Otherwise, it returns `failure` if neither the relation nor
|
||||
/// its inverse relation could be proven.
|
||||
///
|
||||
/// This function does not analyze any IR and does not populate any additional
|
||||
/// constraints.
|
||||
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
|
||||
int64_t rhsPos);
|
||||
|
||||
/// Given an affine map with a single result (and map operands), add a new
|
||||
/// column to the constraint set that represents the result of the map.
|
||||
/// Traverse additional IR starting from the map operands as needed (as long
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/Affine/LoopUtils.h"
|
||||
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
|
||||
}
|
||||
if (boundedOps.contains(target)) {
|
||||
auto diag = emitDefiniteFailure()
|
||||
<< "target op result must not be constrainted";
|
||||
<< "target op result must not be constrained";
|
||||
diag.attachNote(target->getLoc()) << "target/constrained op";
|
||||
return diag;
|
||||
}
|
||||
@@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects(
|
||||
modifiesPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SimplifyMinMaxAffineOpsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
DiagnosedSilenceableFailure
|
||||
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
|
||||
TransformResults &results,
|
||||
TransformState &state) {
|
||||
SmallVector<Operation *> targets;
|
||||
for (Operation *target : state.getPayloadOps(getTarget())) {
|
||||
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
|
||||
auto diag = emitDefiniteFailure()
|
||||
<< "target must be affine.min or affine.max";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
return diag;
|
||||
}
|
||||
targets.push_back(target);
|
||||
}
|
||||
bool modified = false;
|
||||
if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
|
||||
&modified))) {
|
||||
return emitDefiniteFailure()
|
||||
<< "affine.min/max simplification did not converge";
|
||||
}
|
||||
if (!modified) {
|
||||
return emitSilenceableError()
|
||||
<< "the transform failed to simplify any of the target operations";
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void SimplifyMinMaxAffineOpsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
consumesHandle(getTargetMutable(), effects);
|
||||
modifiesPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transform op registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
|
||||
ReifyValueBounds.cpp
|
||||
SuperVectorize.cpp
|
||||
SimplifyAffineStructures.cpp
|
||||
SimplifyAffineMinMax.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
|
||||
|
||||
174
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
Normal file
174
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a transform to simplify mix/max affine operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/IntEqClasses.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "affine-min-max"
|
||||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::affine;
|
||||
|
||||
/// Simplifies an affine min/max operation by proving there's a lower or upper
|
||||
/// bound.
|
||||
template <typename AffineOp>
|
||||
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
|
||||
using Variable = ValueBoundsConstraintSet::Variable;
|
||||
using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
|
||||
|
||||
AffineMap affineMap = affineOp.getMap();
|
||||
ValueRange operands = affineOp.getOperands();
|
||||
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
|
||||
|
||||
LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
|
||||
|
||||
// Create a `Variable` list with values corresponding to each of the results
|
||||
// in the affine affineMap.
|
||||
SmallVector<Variable> variables = llvm::map_to_vector(
|
||||
llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
|
||||
[&](unsigned i) {
|
||||
return Variable(affineMap.getSliceMap(i, 1), operands);
|
||||
});
|
||||
|
||||
// Get the comparison operation.
|
||||
ComparisonOperator cmpOp =
|
||||
isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
|
||||
|
||||
// Find disjoint sets bounded by a common value.
|
||||
llvm::IntEqClasses boundedClasses(variables.size());
|
||||
DenseMap<unsigned, Variable *> bounds;
|
||||
for (auto &&[i, v] : llvm::enumerate(variables)) {
|
||||
unsigned eqClass = boundedClasses.findLeader(i);
|
||||
|
||||
// If the class already has a bound continue.
|
||||
if (bounds.contains(eqClass))
|
||||
continue;
|
||||
|
||||
// Initialize the bound.
|
||||
Variable *bound = &v;
|
||||
|
||||
LLVM_DEBUG({
|
||||
DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
|
||||
<< "`\n";
|
||||
});
|
||||
|
||||
// Check against the other variables.
|
||||
for (size_t j = i + 1; j < variables.size(); ++j) {
|
||||
unsigned jEqClass = boundedClasses.findLeader(j);
|
||||
// Skip if the class is the same.
|
||||
if (jEqClass == eqClass)
|
||||
continue;
|
||||
|
||||
// Get the bound of the equivalence class or itself.
|
||||
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
|
||||
|
||||
LLVM_DEBUG({
|
||||
DBGS() << "- comparing with variable: #" << jEqClass
|
||||
<< ", with map: " << nv->getMap() << "\n";
|
||||
});
|
||||
|
||||
// Compare the variables.
|
||||
FailureOr<bool> cmpResult =
|
||||
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
|
||||
|
||||
// The variables cannot be compared.
|
||||
if (failed(cmpResult)) {
|
||||
LLVM_DEBUG({
|
||||
DBGS() << "-- classes: #" << i << ", #" << jEqClass
|
||||
<< " cannot be merged\n";
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Join the equivalent classes and update the bound if necessary.
|
||||
LLVM_DEBUG({
|
||||
DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
|
||||
<< ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
|
||||
});
|
||||
if (*cmpResult) {
|
||||
boundedClasses.join(eqClass, jEqClass);
|
||||
} else {
|
||||
// In this case we have lhs > rhs if isMin == true, or lhs < rhs if
|
||||
// isMin == false.
|
||||
bound = nv;
|
||||
boundedClasses.join(eqClass, jEqClass);
|
||||
}
|
||||
}
|
||||
bounds[boundedClasses.findLeader(i)] = bound;
|
||||
}
|
||||
|
||||
// Return if there's no simplification.
|
||||
if (bounds.size() >= affineMap.getNumResults()) {
|
||||
LLVM_DEBUG(
|
||||
{ DBGS() << "- the affine operation couldn't get simplified\n"; });
|
||||
return false;
|
||||
}
|
||||
|
||||
// Construct the new affine affineMap.
|
||||
SmallVector<AffineExpr> results;
|
||||
results.reserve(bounds.size());
|
||||
for (auto [k, bound] : bounds)
|
||||
results.push_back(bound->getMap().getResult(0));
|
||||
|
||||
affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
|
||||
results, rewriter.getContext());
|
||||
|
||||
// Update the affine op.
|
||||
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
|
||||
LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
|
||||
return true;
|
||||
}
|
||||
|
||||
bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
|
||||
return simplifyAffineMinMaxOp(rewriter, op);
|
||||
}
|
||||
|
||||
bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
|
||||
return simplifyAffineMinMaxOp(rewriter, op);
|
||||
}
|
||||
|
||||
LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
|
||||
ArrayRef<Operation *> ops,
|
||||
bool *modified) {
|
||||
bool changed = false;
|
||||
for (Operation *op : ops) {
|
||||
if (auto minOp = dyn_cast<AffineMinOp>(op))
|
||||
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
|
||||
else if (auto maxOp = cast<AffineMaxOp>(op))
|
||||
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
|
||||
}
|
||||
RewritePatternSet patterns(rewriter.getContext());
|
||||
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
|
||||
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
if (modified)
|
||||
*modified = changed;
|
||||
// Canonicalize to a fixpoint.
|
||||
if (failed(applyOpPatternsGreedily(
|
||||
ops, frozenPatterns,
|
||||
GreedyRewriteConfig()
|
||||
.setListener(
|
||||
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
|
||||
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
|
||||
&changed))) {
|
||||
return failure();
|
||||
}
|
||||
if (modified)
|
||||
*modified = changed;
|
||||
return success();
|
||||
}
|
||||
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
|
||||
}
|
||||
|
||||
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
|
||||
ArrayRef<Value> mapOperands)
|
||||
ValueRange mapOperands)
|
||||
: Variable(map, llvm::map_to_vector(mapOperands,
|
||||
[](Value v) { return Variable(v); })) {}
|
||||
|
||||
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
|
||||
return isEmpty;
|
||||
}
|
||||
|
||||
FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
|
||||
int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
|
||||
auto strongCmp = [&](ComparisonOperator cmp,
|
||||
ComparisonOperator negCmp) -> FailureOr<bool> {
|
||||
if (comparePos(lhsPos, cmp, rhsPos))
|
||||
return true;
|
||||
if (comparePos(lhsPos, negCmp, rhsPos))
|
||||
return false;
|
||||
return failure();
|
||||
};
|
||||
switch (cmp) {
|
||||
case ComparisonOperator::LT:
|
||||
return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
|
||||
case ComparisonOperator::LE:
|
||||
return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
|
||||
case ComparisonOperator::GT:
|
||||
return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
|
||||
case ComparisonOperator::GE:
|
||||
return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
|
||||
case ComparisonOperator::EQ: {
|
||||
std::optional<bool> le =
|
||||
strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
|
||||
if (!le)
|
||||
return failure();
|
||||
if (!*le)
|
||||
return false;
|
||||
std::optional<bool> ge =
|
||||
strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
|
||||
if (!ge)
|
||||
return failure();
|
||||
if (!*ge)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
llvm_unreachable("invalid comparison operator");
|
||||
}
|
||||
|
||||
bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
|
||||
ComparisonOperator cmp,
|
||||
const Variable &rhs) {
|
||||
@@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
|
||||
return cstr.comparePos(lhsPos, cmp, rhsPos);
|
||||
}
|
||||
|
||||
FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
|
||||
ComparisonOperator cmp,
|
||||
const Variable &rhs) {
|
||||
int64_t lhsPos = -1, rhsPos = -1;
|
||||
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
|
||||
ValueBoundsConstraintSet &cstr) {
|
||||
// Keep processing as long as lhs/rhs were not processed.
|
||||
if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
|
||||
size_t(rhsPos) >= cstr.positionToValueDim.size())
|
||||
return false;
|
||||
// Keep processing as long as the strong relation cannot be proven.
|
||||
FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
|
||||
return failed(ordered) ? true : false;
|
||||
};
|
||||
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
|
||||
lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
|
||||
rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
|
||||
return cstr.strongComparePos(lhsPos, cmp, rhsPos);
|
||||
}
|
||||
|
||||
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
|
||||
const Variable &var2) {
|
||||
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
|
||||
return true;
|
||||
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
|
||||
ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
|
||||
return false;
|
||||
return failure();
|
||||
return strongCompare(var1, ComparisonOperator::EQ, var2);
|
||||
}
|
||||
|
||||
FailureOr<bool>
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
|
||||
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
|
||||
// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
|
||||
|
||||
// CHECK: @min_max_full_simplify
|
||||
func.func @min_max_full_simplify() -> (index, index) {
|
||||
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
|
||||
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
|
||||
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
// CHECK-NOT: affine.min
|
||||
// CHECK-NOT: affine.max
|
||||
// CHECK: return %[[V0]], %[[V1]]
|
||||
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
|
||||
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
|
||||
return %r0, %r1 : index, index
|
||||
}
|
||||
|
||||
// CHECK: @min_only_simplify
|
||||
func.func @min_only_simplify() -> (index, index) {
|
||||
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
|
||||
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
// CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
|
||||
// CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
|
||||
%0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
|
||||
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
|
||||
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
|
||||
return %r0, %r1 : index, index
|
||||
}
|
||||
|
||||
// CHECK: @max_only_simplify
|
||||
func.func @max_only_simplify() -> (index, index) {
|
||||
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
|
||||
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
|
||||
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
|
||||
// CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
|
||||
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
|
||||
%1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
|
||||
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
|
||||
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
|
||||
return %r0, %r1 : index, index
|
||||
}
|
||||
|
||||
// CHECK: @overlapping_constraints
|
||||
func.func @overlapping_constraints() -> (index, index) {
|
||||
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
|
||||
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
|
||||
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
|
||||
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
|
||||
// CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
|
||||
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
|
||||
// CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
|
||||
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
|
||||
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
|
||||
return %r0, %r1 : index, index
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
@@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This test checks that by using `simplify_min_max_affine_ops` after padding
|
||||
// and tiling, it's possible to recover static tiled slices.
|
||||
|
||||
// CHECK-LABEL: @dyn_pad_tiling
|
||||
// CHECK: %[[LHS:.*]] = tensor.pad
|
||||
// CHECK: %[[RHS:.*]] = tensor.pad
|
||||
// CHECK: scf.for
|
||||
// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
|
||||
// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
|
||||
func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
|
||||
%tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %2 {
|
||||
transform.apply_patterns.canonicalization
|
||||
} {apply_cse} : !transform.any_op
|
||||
%3 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.affine.simplify_min_max_affine_ops %3 : !transform.any_op
|
||||
transform.apply_patterns to %2 {
|
||||
transform.apply_patterns.canonicalization
|
||||
} {apply_cse} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) {
|
||||
getArgAttrsAttrName(), getResAttrsAttrName());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestValueWithBoundsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TestValueWithBoundsOp::populateBoundsForIndexValue(
|
||||
Value v, ValueBoundsConstraintSet &cstr) {
|
||||
cstr.bound(v) >= getMin().getSExtValue();
|
||||
cstr.bound(v) <= getMax().getSExtValue();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReifyBoundOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/MemorySlotInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ValueBoundsOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
|
||||
|
||||
// Include the attribute definitions.
|
||||
@@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
|
||||
// Test ValueBoundsOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [
|
||||
DeclareOpInterfaceMethods<ValueBoundsOpInterface, ["populateBoundsForIndexValue"]>
|
||||
]> {
|
||||
let description = [{
|
||||
Creates a value with specified [min, max] range for value bounds analysis.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = test.value_with_bounds { min = 4 : index, max = 5 : index}
|
||||
```
|
||||
}];
|
||||
let arguments = (ins IndexAttr:$min, IndexAttr:$max);
|
||||
let results = (outs Index:$result);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
|
||||
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
|
||||
let description = [{
|
||||
Reify a bound for the given index-typed value or dimension size of a shaped
|
||||
|
||||
Reference in New Issue
Block a user