This commit makes the following changes:
- Expose `map` and `mapOperands` in
`ValueBoundsConstraintSet::Variable`, so that the class can be used by
subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot
access those members.
- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar
to `ValueBoundsConstraintSet::compare` except that it returns false when
the inverse comparison holds, and `llvm::failure()` if neither the
relation nor its inverse relation could be proven.
- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and
`simplifyAffineMinMaxOps` to simplify those operations using
`ValueBoundsConstraintSet`.
- Add the `SimplifyMinMaxAffineOpsOp` transform op that uses
`simplifyAffineMinMaxOps`.
- Add the `test.value_with_bounds` op to test unknown values with a
min/max range using `ValueBoundsOpInterface`.
- Add tests verifying the transform.
Example:
```mlir
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
%1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
%2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
%3 = affine.min #map1()[%0, %1]
%4 = affine.max #map1()[%1, %2]
return %3, %4 : index, index
}
```
---------
Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
219 lines
8.5 KiB
C++
219 lines
8.5 KiB
C++
//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
|
|
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
|
|
#include "mlir/Dialect/Affine/Analysis/Utils.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
|
#include "mlir/Dialect/Affine/LoopUtils.h"
|
|
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::affine;
|
|
using namespace mlir::transform;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SimplifyBoundedAffineOpsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that the number of lower bounds and the number of upper bounds
/// each match the number of bounded-value handles. Emits an op error naming
/// the expected and the actual count otherwise.
LogicalResult SimplifyBoundedAffineOpsOp::verify() {
  const size_t numBoundedValues = getBoundedValues().size();

  const size_t numLowerBounds = getLowerBounds().size();
  if (numLowerBounds != numBoundedValues) {
    return emitOpError() << "incorrect number of lower bounds, expected "
                         << numBoundedValues << " but found "
                         << numLowerBounds;
  }

  const size_t numUpperBounds = getUpperBounds().size();
  if (numUpperBounds != numBoundedValues) {
    return emitOpError() << "incorrect number of upper bounds, expected "
                         << numBoundedValues << " but found "
                         << numUpperBounds;
  }

  return success();
}
|
|
|
|
namespace {
|
|
/// Simplify affine.min / affine.max ops with the given constraints. They are
|
|
/// either rewritten to affine.apply or left unchanged.
|
|
template <typename OpTy>
|
|
struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
SimplifyAffineMinMaxOp(MLIRContext *ctx,
|
|
const FlatAffineValueConstraints &constraints,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
PatternRewriter &rewriter) const override {
|
|
FailureOr<AffineValueMap> simplified =
|
|
simplifyConstrainedMinMaxOp(op, constraints);
|
|
if (failed(simplified))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
|
|
simplified->getOperands());
|
|
return success();
|
|
}
|
|
|
|
const FlatAffineValueConstraints &constraints;
|
|
};
|
|
} // namespace
|
|
|
|
/// Simplifies the targeted affine.min / affine.max payload ops, using the
/// lower/upper bounds attached to the bounded-value handles as constraints.
///
/// Emits a definite failure when:
///  * a bounded-value handle maps to an op that does not have exactly one
///    index-typed result,
///  * a target is neither affine.min nor affine.max,
///  * a target is itself one of the bounded ops, or
///  * the greedy pattern application reports failure.
DiagnosedSilenceableFailure
SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
                                  TransformResults &results,
                                  TransformState &state) {
  // Step 1: resolve every bounded-value handle to its payload values and
  // record the matching lower/upper bound for each payload value.
  SmallVector<Value> boundedValues;
  SmallVector<int64_t> lbs;
  SmallVector<int64_t> ubs;
  DenseSet<Operation *> boundedOps;
  for (auto [handle, lowerBound, upperBound] : llvm::zip_equal(
           getBoundedValues(), getLowerBounds(), getUpperBounds())) {
    for (Operation *payloadOp : state.getPayloadOps(handle)) {
      const bool hasSingleIndexResult =
          payloadOp->getNumResults() == 1 &&
          payloadOp->getResult(0).getType().isIndex();
      if (!hasSingleIndexResult) {
        auto diag =
            emitDefiniteFailure()
            << "expected bounded value handle to point to one or multiple "
               "single-result index-typed ops";
        diag.attachNote(payloadOp->getLoc()) << "multiple/non-index result";
        return diag;
      }
      boundedOps.insert(payloadOp);
      boundedValues.push_back(payloadOp->getResult(0));
      lbs.push_back(lowerBound);
      ubs.push_back(upperBound);
    }
  }

  // Step 2: turn the collected (value, lb, ub) triples into a constraint set.
  FlatAffineValueConstraints cstr;
  for (auto [value, lowerBound, upperBound] :
       llvm::zip(boundedValues, lbs, ubs)) {
    // Reuse the variable if the value is already known to the constraint set;
    // otherwise register it as a new symbol.
    unsigned pos;
    if (!cstr.findVar(value, &pos))
      pos = cstr.appendSymbolVar(value);
    cstr.addBound(presburger::BoundType::LB, pos, lowerBound);
    // Note: addBound bounds are inclusive, but specified UB is exclusive.
    cstr.addBound(presburger::BoundType::UB, pos, upperBound - 1);
  }

  // Step 3: collect and validate the ops to transform.
  SmallVector<Operation *> targets;
  for (Operation *candidate : state.getPayloadOps(getTarget())) {
    if (!isa<AffineMinOp, AffineMaxOp>(candidate)) {
      auto diag = emitDefiniteFailure()
                  << "target must be affine.min or affine.max";
      diag.attachNote(candidate->getLoc()) << "target op";
      return diag;
    }
    if (boundedOps.contains(candidate)) {
      auto diag = emitDefiniteFailure()
                  << "target op result must not be constrained";
      diag.attachNote(candidate->getLoc()) << "target/constrained op";
      return diag;
    }
    targets.push_back(candidate);
  }

  // Step 4: assemble the pattern set.
  RewritePatternSet patterns(getContext());
  // Canonicalization patterns are needed so that affine.apply ops are composed
  // with the remaining affine.min/max ops.
  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
                  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  // Step 5: apply the simplification pattern to a fixpoint, forwarding the
  // transform rewriter's listener to the greedy driver.
  GreedyRewriteConfig config;
  config.setListener(
      static_cast<RewriterBase::Listener *>(rewriter.getListener()));
  config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
  if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
    auto diag = emitDefiniteFailure()
                << "affine.min/max simplification did not converge";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the op's effects: the target handle is consumed, every
/// bounded-value handle is only read, and the payload is modified.
void SimplifyBoundedAffineOpsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  for (OpOperand &boundedValueOperand : getBoundedValuesMutable()) {
    onlyReadsHandle(boundedValueOperand, effects);
  }
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SimplifyMinMaxAffineOpsOp
|
|
//===----------------------------------------------------------------------===//
|
|
/// Runs `simplifyAffineMinMaxOps` on the targeted affine.min / affine.max
/// payload ops.
///
/// Emits a definite failure if a target is neither affine.min nor affine.max,
/// or if `simplifyAffineMinMaxOps` reports failure. Emits a silenceable error
/// if the simplification ran but did not modify anything.
DiagnosedSilenceableFailure
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &results,
                                 TransformState &state) {
  // Collect the payload ops, rejecting anything that is not a min/max op.
  SmallVector<Operation *> targets;
  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
    if (isa<AffineMinOp, AffineMaxOp>(payloadOp)) {
      targets.push_back(payloadOp);
      continue;
    }
    auto diag = emitDefiniteFailure()
                << "target must be affine.min or affine.max";
    diag.attachNote(payloadOp->getLoc()) << "target op";
    return diag;
  }

  // `modified` is written through the pointer passed to the simplification.
  bool modified = false;
  LogicalResult status =
      mlir::affine::simplifyAffineMinMaxOps(rewriter, targets, &modified);
  if (failed(status)) {
    return emitDefiniteFailure()
           << "affine.min/max simplification did not converge";
  }
  if (modified)
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError()
         << "the transform failed to simplify any of the target operations";
}
|
|
|
|
/// Declares the op's effects: the target handle is consumed and the payload
/// is modified.
void SimplifyMinMaxAffineOpsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Register the handle consumption first, then the payload modification.
  consumesHandle(getTargetMutable(), effects);
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension for the Affine transform ops. Its `init` hook
/// declares AffineDialect as a generated dialect and registers the transform
/// ops listed in the generated AffineTransformOps.cpp.inc file.
class AffineTransformDialectExtension
    : public transform::TransformDialectExtension<
          AffineTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)

  using Base::Base;

  /// Declares the generated dialect and registers the transform ops.
  void init() {
    declareGeneratedDialect<AffineDialect>();

    // The op list is supplied by the generated .inc file, which expands to
    // the template arguments when GET_OP_LIST is defined.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
|
|
|
|
void mlir::affine::registerTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<AffineTransformDialectExtension>();
|
|
}
|