Files
clang-p2996/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
Fabian Mora c7165587e4 [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)
This commit makes the following changes:

- Expose `map` and `mapOperands` in
`ValueBoundsConstraintSet::Variable`, so that the class can be used by
subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot
access those members.

- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar
to `ValueBoundsConstraintSet::compare`, except that it returns false only
when the inverse comparison can be proven, and `llvm::failure()` if neither
the relation nor its inverse relation could be proven.

- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and
`simplifyAffineMinMaxOps` to simplify those operations using
`ValueBoundsConstraintSet`.

- Add the `SimplifyMinMaxAffineOpsOp` transform op, which uses
`simplifyAffineMinMaxOps`.

- Add the `test.value_with_bounds` op to test values that are unknown but
lie in a known [min, max] range, exposed via `ValueBoundsOpInterface`.

- Add tests verifying the transform.

Example:

```mlir
func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
  %1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
  %2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
  %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
  return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
  %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
  %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
  %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
  %3 = affine.min #map1()[%0, %1]
  %4 = affine.max #map1()[%1, %2]
  return %3, %4 : index, index
}
```

---------

Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
2025-06-23 06:05:20 +02:00

219 lines
8.5 KiB
C++

//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::affine;
using namespace mlir::transform;
//===----------------------------------------------------------------------===//
// SimplifyBoundedAffineOpsOp
//===----------------------------------------------------------------------===//
/// Verify the shape and contents of the bound lists.
///
/// The i-th entry of `lower_bounds` (inclusive) and the i-th entry of
/// `upper_bounds` (exclusive) apply to the i-th handle in `bounded_values`.
/// All three lists must therefore have the same length.
///
/// Every range must also be non-empty (lower bound < upper bound). An empty
/// range would make the constraint set built in `apply` infeasible, and any
/// min/max "simplification" under infeasible constraints is vacuous.
/// Requiring `lb < ub` also guarantees `ub > INT64_MIN`, so the `ub - 1`
/// computed in `apply` cannot overflow.
LogicalResult SimplifyBoundedAffineOpsOp::verify() {
  auto numBounded = getBoundedValues().size();
  auto lbs = getLowerBounds();
  auto ubs = getUpperBounds();
  if (lbs.size() != numBounded)
    return emitOpError() << "incorrect number of lower bounds, expected "
                         << numBounded << " but found " << lbs.size();
  if (ubs.size() != numBounded)
    return emitOpError() << "incorrect number of upper bounds, expected "
                         << numBounded << " but found " << ubs.size();
  // Upper bounds are exclusive, so lb == ub already denotes an empty range.
  for (size_t i = 0; i < numBounded; ++i) {
    if (lbs[i] >= ubs[i])
      return emitOpError() << "expected lower bound (" << lbs[i]
                           << ") to be smaller than upper bound (" << ubs[i]
                           << ") for bounded value #" << i;
  }
  return success();
}
namespace {
/// Rewrite pattern that simplifies an affine.min / affine.max op under a fixed
/// set of constraints.
///
/// If `simplifyConstrainedMinMaxOp` produces a simplified affine value map for
/// the op under `constraints`, the op is replaced by an affine.apply of that
/// map. Otherwise the pattern fails to match and the op is left unchanged.
///
/// The pattern only stores a reference to `constraints`, so the constraint set
/// must outlive every application of the pattern. In this file it is a local of
/// `SimplifyBoundedAffineOpsOp::apply` that outlives the greedy driver run.
template <typename OpTy>
struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
  // The base-class constructors are intentionally NOT inherited. They would
  // leave the `constraints` reference member without an initializer, so the
  // only way to build this pattern is with an explicit constraint set.
  SimplifyAffineMinMaxOp(MLIRContext *ctx,
                         const FlatAffineValueConstraints &constraints,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    FailureOr<AffineValueMap> simplified =
        simplifyConstrainedMinMaxOp(op, constraints);
    if (failed(simplified))
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
                                               simplified->getOperands());
    return success();
  }

  /// Constraints on the bounded values; referenced, not owned.
  const FlatAffineValueConstraints &constraints;
};
} // namespace
/// Simplify the targeted affine.min / affine.max ops using the bounds given for
/// the bounded values.
///
/// Each payload op of the i-th bounded-value handle must have exactly one
/// index-typed result. That result is constrained to the half-open range
/// [lower_bounds[i], upper_bounds[i]).
///
/// Each payload op of the target handle must be an affine.min or affine.max
/// and must not itself be one of the constrained ops. The simplification
/// pattern, together with the affine.min/max canonicalization patterns, is then
/// applied greedily to the targets. Violations and non-convergence are all
/// reported as definite failures.
DiagnosedSilenceableFailure
SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
                                  TransformResults &results,
                                  TransformState &state) {
  // Gather the constrained SSA values and their bounds. A handle may map to
  // several payload ops; every one of them receives that handle's bounds.
  SmallVector<Value> constrainedValues;
  SmallVector<int64_t> valueLbs;
  SmallVector<int64_t> valueUbs;
  DenseSet<Operation *> constrainedOps;
  for (auto [handle, lb, ub] : llvm::zip_equal(
           getBoundedValues(), getLowerBounds(), getUpperBounds())) {
    for (Operation *boundedOp : state.getPayloadOps(handle)) {
      bool isSingleIndexResult = boundedOp->getNumResults() == 1 &&
                                 boundedOp->getResult(0).getType().isIndex();
      if (!isSingleIndexResult) {
        auto diag =
            emitDefiniteFailure()
            << "expected bounded value handle to point to one or multiple "
               "single-result index-typed ops";
        diag.attachNote(boundedOp->getLoc()) << "multiple/non-index result";
        return diag;
      }
      constrainedValues.push_back(boundedOp->getResult(0));
      constrainedOps.insert(boundedOp);
      valueLbs.push_back(lb);
      valueUbs.push_back(ub);
    }
  }

  // Encode each constrained value as a symbol of the constraint set. A value
  // seen more than once reuses its symbol and accumulates all of its bounds.
  FlatAffineValueConstraints constraints;
  for (size_t idx = 0, e = constrainedValues.size(); idx < e; ++idx) {
    Value constrained = constrainedValues[idx];
    unsigned varPos;
    if (!constraints.findVar(constrained, &varPos))
      varPos = constraints.appendSymbolVar(constrained);
    constraints.addBound(presburger::BoundType::LB, varPos, valueLbs[idx]);
    // addBound expects inclusive bounds, whereas the op's upper bounds are
    // exclusive.
    constraints.addBound(presburger::BoundType::UB, varPos, valueUbs[idx] - 1);
  }

  // Validate and collect the ops to simplify.
  SmallVector<Operation *> minMaxTargets;
  for (Operation *candidate : state.getPayloadOps(getTarget())) {
    if (!isa<AffineMinOp, AffineMaxOp>(candidate)) {
      auto diag = emitDefiniteFailure()
                  << "target must be affine.min or affine.max";
      diag.attachNote(candidate->getLoc()) << "target op";
      return diag;
    }
    if (constrainedOps.contains(candidate)) {
      auto diag = emitDefiniteFailure()
                  << "target op result must not be constrained";
      diag.attachNote(candidate->getLoc()) << "target/constrained op";
      return diag;
    }
    minMaxTargets.push_back(candidate);
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  // The canonicalization patterns compose affine.apply producers into the
  // affine.min/max ops that remain.
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
                  SimplifyAffineMinMaxOp<AffineMaxOp>>(ctx, constraints);
  FrozenRewritePatternSet frozen(std::move(patterns));

  GreedyRewriteConfig config;
  config.setListener(
      static_cast<RewriterBase::Listener *>(rewriter.getListener()));
  config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);

  // Rewrite until a fixpoint is reached; failing to reach one is definite.
  if (failed(applyOpPatternsGreedily(minMaxTargets, frozen, config)))
    return emitDefiniteFailure()
           << "affine.min/max simplification did not converge";
  return DiagnosedSilenceableFailure::success();
}
/// Report the op's effects: the target handle is consumed, each bounded-value
/// handle is only read, and the payload IR is modified.
void SimplifyBoundedAffineOpsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  for (OpOperand &boundedHandle : getBoundedValuesMutable()) {
    transform::onlyReadsHandle(boundedHandle, effects);
  }
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// SimplifyMinMaxAffineOpsOp
//===----------------------------------------------------------------------===//
/// Run `simplifyAffineMinMaxOps` on every payload op of the target handle.
///
/// Each payload op must be an affine.min or affine.max; otherwise a definite
/// failure is emitted with a note on the offending op. Non-convergence is also
/// a definite failure. If no target was simplified, the result is a
/// silenceable error.
DiagnosedSilenceableFailure
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &results,
                                 TransformState &state) {
  SmallVector<Operation *> minMaxOps;
  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
    if (isa<AffineMinOp, AffineMaxOp>(payloadOp)) {
      minMaxOps.push_back(payloadOp);
      continue;
    }
    auto diag = emitDefiniteFailure()
                << "target must be affine.min or affine.max";
    diag.attachNote(payloadOp->getLoc()) << "target op";
    return diag;
  }

  bool changed = false;
  if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, minMaxOps,
                                                   &changed)))
    return emitDefiniteFailure()
           << "affine.min/max simplification did not converge";

  if (changed)
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError()
         << "the transform failed to simplify any of the target operations";
}
/// Report the op's effects: the target handle is consumed and the payload IR
/// is modified.
void SimplifyMinMaxAffineOpsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the affine transform ops listed
/// in the generated AffineTransformOps.cpp.inc (GET_OP_LIST section).
class AffineTransformDialectExtension
: public transform::TransformDialectExtension<
AffineTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
using Base::Base;
/// Declare AffineDialect as a dialect whose ops these transforms may create
/// (e.g. the affine.apply ops built by SimplifyAffineMinMaxOp). Then register
/// every transform op from the generated op list.
void init() {
declareGeneratedDialect<AffineDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
>();
}
};
} // namespace
// Generated definitions of the affine transform op classes.
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
/// Make the affine transform ops available by adding
/// AffineTransformDialectExtension to `registry`.
void mlir::affine::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<AffineTransformDialectExtension>();
}