[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)

This commit makes the following changes:

- Expose `map` and `mapOperands` in
`ValueBoundsConstraintSet::Variable`, so that the class can be used by
subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot
access those members.

- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar
to `ValueBoundsConstraintSet::compare` except that it returns false when
the inverse comparison holds, and `llvm::failure()` if neither the
relation nor its inverse relation could be proven.

- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and
`simplifyAffineMinMaxOps` to simplify those operations using
`ValueBoundsConstraintSet`.

- Adds the `SimplifyMinMaxAffineOpsOp` transform op that uses
`simplifyAffineMinMaxOps`.

- Add the `test.value_with_bounds` op to test unknown values with a min
max range using `ValueBoundsOpInterface`.

- Adds tests verifying the transform.

Example:

```mlir
func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
  %1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
  %2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
  %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
  %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
  %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
  %3 = affine.min #map1()[%0, %1]
  %4 = affine.max #map1()[%1, %2]
  return %3, %4 : index, index
}
```

---------

Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
This commit is contained in:
Fabian Mora
2025-06-23 00:05:20 -04:00
committed by GitHub
parent 89c61449e6
commit c7165587e4
11 changed files with 493 additions and 9 deletions

View File

@@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp
}];
}
def SimplifyMinMaxAffineOpsOp :
Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let description = [{
Simplify the targeted `affine.min` / `affine.max` ops using the
`mlir::affine::simplifyAffineMinMaxOps` transform.
Example:
```
%0 = transform.structured.match ops{["affine.max"]} in %arg1
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
```
#### Return modes
This transform consumes the target handle and does not produce any results.
This transforms definitely fails if any of the targeted operations is not an
`affine.min` or `affine.max` operation, or if the canonicalization patterns
failed to converge.
This transform silently fails if none of the operations were simplified.
Otherwise, it succeeds.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);
let assemblyFormat = [{
$target attr-dict `:` type($target)
}];
}
#endif // Affine_TRANSFORM_OPS

View File

@@ -34,6 +34,8 @@ namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
class AffineMaxOp;
class AffineMinOp;
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
@@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound(
OpBuilder &b, Location loc, AffineMap boundMap,
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
/// This transform tries to simplify the affine min operation `op`, by finding a
/// common lower bound for a set of expressions in the affine map results. It
/// returns whether the transform updated `op`'s affine map.
///
/// In concrete terms, given an operation like:
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
/// This transform tries to simplify the affine max operation `op`, by finding a
/// common upper bound for a set of expressions in the affine map results. It
/// returns whether the transform updated `op`'s affine map.
///
/// In concrete terms, given an operation like:
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
/// all the `affine.min` or `affine.max` operations in `ops`. After
/// simplification, it invokes the `affine.min/max` canonicalization patterns on
/// `ops`.
///
/// This transform returns failure if the greedy pattern rewriter failed to
/// converge during canonicalization, otherwise it returns success. If provided,
/// `modified` is set to `true` if the IR was modified in any way.
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
ArrayRef<Operation *> ops,
bool *modified = nullptr);
} // namespace affine
} // namespace mlir

View File

@@ -135,10 +135,17 @@ public:
/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
Variable(AffineMap map, ArrayRef<Value> mapOperands);
Variable(AffineMap map, ValueRange mapOperands);
MLIRContext *getContext() const { return map.getContext(); }
/// Returns the affine map.
AffineMap getMap() const { return map; }
/// Returns the map operands.
ValueDimList &getOperands() { return mapOperands; }
const ValueDimList &getOperands() const { return mapOperands; }
private:
friend class ValueBoundsConstraintSet;
AffineMap map;
@@ -254,6 +261,12 @@ public:
/// prove the relation or until it ran out of IR.
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
/// relation nor its inverse relation could be proven.
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs);
/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
@@ -327,6 +340,16 @@ protected:
/// constraints.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
/// Return "true" if, based on the current state of the constraint system,
/// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
/// can be proven. Otherwise, it returns `failure` if neither the relation nor
/// its inverse relation could be proven.
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
}
if (boundedOps.contains(target)) {
auto diag = emitDefiniteFailure()
<< "target op result must not be constrainted";
<< "target op result must not be constrained";
diag.attachNote(target->getLoc()) << "target/constrained op";
return diag;
}
@@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects(
modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// SimplifyMinMaxAffineOpsOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
TransformResults &results,
TransformState &state) {
SmallVector<Operation *> targets;
for (Operation *target : state.getPayloadOps(getTarget())) {
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
auto diag = emitDefiniteFailure()
<< "target must be affine.min or affine.max";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
targets.push_back(target);
}
bool modified = false;
if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
&modified))) {
return emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
}
if (!modified) {
return emitSilenceableError()
<< "the transform failed to simplify any of the target operations";
}
return DiagnosedSilenceableFailure::success();
}
void SimplifyMinMaxAffineOpsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

View File

@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
SimplifyAffineMinMax.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine

View File

@@ -0,0 +1,174 @@
//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a transform to simplify mix/max affine operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "affine-min-max"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::affine;
/// Simplifies an affine min/max operation by proving there's a lower or upper
/// bound.
template <typename AffineOp>
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
using Variable = ValueBoundsConstraintSet::Variable;
using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
AffineMap affineMap = affineOp.getMap();
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
SmallVector<Variable> variables = llvm::map_to_vector(
llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
// Get the comparison operation.
ComparisonOperator cmpOp =
isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
// Find disjoint sets bounded by a common value.
llvm::IntEqClasses boundedClasses(variables.size());
DenseMap<unsigned, Variable *> bounds;
for (auto &&[i, v] : llvm::enumerate(variables)) {
unsigned eqClass = boundedClasses.findLeader(i);
// If the class already has a bound continue.
if (bounds.contains(eqClass))
continue;
// Initialize the bound.
Variable *bound = &v;
LLVM_DEBUG({
DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
<< "`\n";
});
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
unsigned jEqClass = boundedClasses.findLeader(j);
// Skip if the class is the same.
if (jEqClass == eqClass)
continue;
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
LLVM_DEBUG({
DBGS() << "- comparing with variable: #" << jEqClass
<< ", with map: " << nv->getMap() << "\n";
});
// Compare the variables.
FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
// The variables cannot be compared.
if (failed(cmpResult)) {
LLVM_DEBUG({
DBGS() << "-- classes: #" << i << ", #" << jEqClass
<< " cannot be merged\n";
});
continue;
}
// Join the equivalent classes and update the bound if necessary.
LLVM_DEBUG({
DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
<< ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
});
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
// In this case we have lhs > rhs if isMin == true, or lhs < rhs if
// isMin == false.
bound = nv;
boundedClasses.join(eqClass, jEqClass);
}
}
bounds[boundedClasses.findLeader(i)] = bound;
}
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
LLVM_DEBUG(
{ DBGS() << "- the affine operation couldn't get simplified\n"; });
return false;
}
// Construct the new affine affineMap.
SmallVector<AffineExpr> results;
results.reserve(bounds.size());
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
results, rewriter.getContext());
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
return true;
}
bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}
bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}
LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
ArrayRef<Operation *> ops,
bool *modified) {
bool changed = false;
for (Operation *op : ops) {
if (auto minOp = dyn_cast<AffineMinOp>(op))
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
else if (auto maxOp = cast<AffineMaxOp>(op))
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
}
RewritePatternSet patterns(rewriter.getContext());
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (modified)
*modified = changed;
// Canonicalize to a fixpoint.
if (failed(applyOpPatternsGreedily(
ops, frozenPatterns,
GreedyRewriteConfig()
.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
&changed))) {
return failure();
}
if (modified)
*modified = changed;
return success();
}

View File

@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
}
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
ArrayRef<Value> mapOperands)
ValueRange mapOperands)
: Variable(map, llvm::map_to_vector(mapOperands,
[](Value v) { return Variable(v); })) {}
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
auto strongCmp = [&](ComparisonOperator cmp,
ComparisonOperator negCmp) -> FailureOr<bool> {
if (comparePos(lhsPos, cmp, rhsPos))
return true;
if (comparePos(lhsPos, negCmp, rhsPos))
return false;
return failure();
};
switch (cmp) {
case ComparisonOperator::LT:
return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
case ComparisonOperator::LE:
return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
case ComparisonOperator::GT:
return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
case ComparisonOperator::GE:
return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
case ComparisonOperator::EQ: {
std::optional<bool> le =
strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
if (!le)
return failure();
if (!*le)
return false;
std::optional<bool> ge =
strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
if (!ge)
return failure();
if (!*ge)
return false;
return true;
}
}
llvm_unreachable("invalid comparison operator");
}
bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs) {
@@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs) {
int64_t lhsPos = -1, rhsPos = -1;
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
// Keep processing as long as lhs/rhs were not processed.
if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
size_t(rhsPos) >= cstr.positionToValueDim.size())
return false;
// Keep processing as long as the strong relation cannot be proven.
FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
return failed(ordered) ? true : false;
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
return cstr.strongComparePos(lhsPos, cmp, rhsPos);
}
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
const Variable &var2) {
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
return true;
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
return false;
return failure();
return strongCompare(var1, ComparisonOperator::EQ, var2);
}
FailureOr<bool>

View File

@@ -0,0 +1,68 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
// CHECK: @min_max_full_simplify
func.func @min_max_full_simplify() -> (index, index) {
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
// CHECK-NOT: affine.min
// CHECK-NOT: affine.max
// CHECK: return %[[V0]], %[[V1]]
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
return %r0, %r1 : index, index
}
// CHECK: @min_only_simplify
func.func @min_only_simplify() -> (index, index) {
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
// CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
// CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
%0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
return %r0, %r1 : index, index
}
// CHECK: @max_only_simplify
func.func @max_only_simplify() -> (index, index) {
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
// CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
return %r0, %r1 : index, index
}
// CHECK: @overlapping_constraints
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
// CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
// CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
return %r0, %r1 : index, index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
transform.yield
}
}

View File

@@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
// This test checks that by using `simplify_min_max_affine_ops` after padding
// and tiling, it's possible to recover static tiled slices.
// CHECK-LABEL: @dyn_pad_tiling
// CHECK: %[[LHS:.*]] = tensor.pad
// CHECK: %[[RHS:.*]] = tensor.pad
// CHECK: scf.for
// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %2 {
transform.apply_patterns.canonicalization
} {apply_cse} : !transform.any_op
%3 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.affine.simplify_min_max_affine_ops %3 : !transform.any_op
transform.apply_patterns to %2 {
transform.apply_patterns.canonicalization
} {apply_cse} : !transform.any_op
transform.yield
}
}

View File

@@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) {
getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
// TestValueWithBoundsOp
//===----------------------------------------------------------------------===//
void TestValueWithBoundsOp::populateBoundsForIndexValue(
Value v, ValueBoundsConstraintSet &cstr) {
cstr.bound(v) >= getMin().getSExtValue();
cstr.bound(v) <= getMax().getSExtValue();
}
//===----------------------------------------------------------------------===//
// ReifyBoundOp
//===----------------------------------------------------------------------===//

View File

@@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
// Include the attribute definitions.
@@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
// Test ValueBoundsOpInterface
//===----------------------------------------------------------------------===//
def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [
DeclareOpInterfaceMethods<ValueBoundsOpInterface, ["populateBoundsForIndexValue"]>
]> {
let description = [{
Creates a value with specified [min, max] range for value bounds analysis.
Example:
```mlir
%0 = test.value_with_bounds { min = 4 : index, max = 5 : index}
```
}];
let arguments = (ins IndexAttr:$min, IndexAttr:$max);
let results = (outs Index:$result);
let assemblyFormat = "attr-dict";
}
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let description = [{
Reify a bound for the given index-typed value or dimension size of a shaped