//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::affine; using namespace mlir::transform; //===----------------------------------------------------------------------===// // SimplifyBoundedAffineOpsOp //===----------------------------------------------------------------------===// LogicalResult SimplifyBoundedAffineOpsOp::verify() { if (getLowerBounds().size() != getBoundedValues().size()) return emitOpError() << "incorrect number of lower bounds, expected " << getBoundedValues().size() << " but found " << getLowerBounds().size(); if (getUpperBounds().size() != getBoundedValues().size()) return emitOpError() << "incorrect number of upper bounds, expected " << getBoundedValues().size() << " but found " << getUpperBounds().size(); return success(); } namespace { /// Simplify affine.min / affine.max ops with the given constraints. They are /// either rewritten to affine.apply or left unchanged. template struct SimplifyAffineMinMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; SimplifyAffineMinMaxOp(MLIRContext *ctx, const FlatAffineValueConstraints &constraints, PatternBenefit benefit = 1) : OpRewritePattern(ctx, benefit), constraints(constraints) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { FailureOr simplified = simplifyConstrainedMinMaxOp(op, constraints); if (failed(simplified)) return failure(); rewriter.replaceOpWithNewOp(op, simplified->getAffineMap(), simplified->getOperands()); return success(); } const FlatAffineValueConstraints &constraints; }; } // namespace DiagnosedSilenceableFailure SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, TransformResults &results, TransformState &state) { // Get constraints for bounded values. SmallVector lbs; SmallVector ubs; SmallVector boundedValues; DenseSet boundedOps; for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), getUpperBounds())) { Value handle = std::get<0>(it); for (Operation *op : state.getPayloadOps(handle)) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { auto diag = emitDefiniteFailure() << "expected bounded value handle to point to one or multiple " "single-result index-typed ops"; diag.attachNote(op->getLoc()) << "multiple/non-index result"; return diag; } boundedValues.push_back(op->getResult(0)); boundedOps.insert(op); lbs.push_back(std::get<1>(it)); ubs.push_back(std::get<2>(it)); } } // Build constraint set. FlatAffineValueConstraints cstr; for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) { unsigned pos; if (!cstr.findVar(std::get<0>(it), &pos)) pos = cstr.appendSymbolVar(std::get<0>(it)); cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it)); // Note: addBound bounds are inclusive, but specified UB is exclusive. cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1); } // Transform all targets. SmallVector targets; for (Operation *target : state.getPayloadOps(getTarget())) { if (!isa(target)) { auto diag = emitDefiniteFailure() << "target must be affine.min or affine.max"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (boundedOps.contains(target)) { auto diag = emitDefiniteFailure() << "target op result must not be constrainted"; diag.attachNote(target->getLoc()) << "target/constrained op"; return diag; } targets.push_back(target); } SmallVector transformed; RewritePatternSet patterns(getContext()); // Canonicalization patterns are needed so that affine.apply ops are composed // with the remaining affine.min/max ops. AffineMaxOp::getCanonicalizationPatterns(patterns, getContext()); AffineMinOp::getCanonicalizationPatterns(patterns, getContext()); patterns.insert, SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; config.listener = static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; } return DiagnosedSilenceableFailure::success(); } void SimplifyBoundedAffineOpsOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); for (Value v : getBoundedValues()) onlyReadsHandle(v, effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { class AffineTransformDialectExtension : public transform::TransformDialectExtension< AffineTransformDialectExtension> { public: using Base::Base; void init() { declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc" void mlir::affine::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }