//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace mlir::arith using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; static std::optional getMaybeConstantValue(DataFlowSolver &solver, Value value) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) return std::nullopt; const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); return inferredRange.getConstantValue(); } /// Patterned after SCCP static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value) { if (value.use_empty()) return failure(); std::optional maybeConstValue = getMaybeConstantValue(solver, value); if (!maybeConstValue.has_value()) return failure(); Operation *maybeDefiningOp = value.getDefiningOp(); Dialect *valueDialect = maybeDefiningOp ? maybeDefiningOp->getDialect() : value.getParentRegion()->getParentOp()->getDialect(); Attribute constAttr = rewriter.getIntegerAttr(value.getType(), *maybeConstValue); Operation *constOp = valueDialect->materializeConstant( rewriter, constAttr, value.getType(), value.getLoc()); // Fall back to arith.constant if the dialect materializer doesn't know what // to do with an integer constant. if (!constOp) constOp = rewriter.getContext() ->getLoadedDialect() ->materializeConstant(rewriter, constAttr, value.getType(), value.getLoc()); if (!constOp) return failure(); rewriter.replaceAllUsesWith(value, constOp->getResult(0)); return success(); } namespace { class DataFlowListener : public RewriterBase::Listener { public: DataFlowListener(DataFlowSolver &s) : s(s) {} protected: void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); for (Value res : op->getResults()) s.eraseState(res); } DataFlowSolver &s; }; /// Rewrite any results of `op` that were inferred to be constant integers to /// and replace their uses with that constant. Return success() if all results /// where thus replaced and the operation is erased. Also replace any block /// arguments with their constant values. struct MaterializeKnownConstantValues : public RewritePattern { MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context), solver(s) {} LogicalResult match(Operation *op) const override { if (matchPattern(op, m_Constant())) return failure(); auto needsReplacing = [&](Value v) { return getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); }; bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); if (op->getNumRegions() == 0) return success(hasConstantResults); bool hasConstantRegionArgs = false; for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { hasConstantRegionArgs |= llvm::any_of(block.getArguments(), needsReplacing); } } return success(hasConstantResults || hasConstantRegionArgs); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { bool replacedAll = (op->getNumResults() != 0); for (Value v : op->getResults()) replacedAll &= (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) || v.use_empty()); if (replacedAll && isOpTriviallyDead(op)) { rewriter.eraseOp(op); return; } PatternRewriter::InsertionGuard guard(rewriter); for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { rewriter.setInsertionPointToStart(&block); for (BlockArgument &arg : block.getArguments()) { (void)maybeReplaceWithConstant(solver, rewriter, arg); } } } } private: DataFlowSolver &solver; }; template struct DeleteTrivialRem : public OpRewritePattern { DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) : OpRewritePattern(context), solver(s) {} LogicalResult matchAndRewrite(RemOp op, PatternRewriter &rewriter) const override { Value lhs = op.getOperand(0); Value rhs = op.getOperand(1); auto maybeModulus = getConstantIntValue(rhs); if (!maybeModulus.has_value()) return failure(); int64_t modulus = *maybeModulus; if (modulus <= 0) return failure(); auto *maybeLhsRange = solver.lookupState(lhs); if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) return failure(); const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); const APInt &min = isa(op) ? lhsRange.umin() : lhsRange.smin(); const APInt &max = isa(op) ? lhsRange.umax() : lhsRange.smax(); // The minima and maxima here are given as closed ranges, we must be // strictly less than the modulus. if (min.isNegative() || min.uge(modulus)) return failure(); if (max.isNegative() || max.uge(modulus)) return failure(); if (!min.ule(max)) return failure(); // With all those conditions out of the way, we know thas this invocation of // a remainder is a noop because the input is strictly within the range // [0, modulus), so get rid of it. rewriter.replaceOp(op, ValueRange{lhs}); return success(); } private: DataFlowSolver &solver; }; struct IntRangeOptimizationsPass : public arith::impl::ArithIntRangeOptsBase { void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); DataFlowSolver solver; solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); DataFlowListener listener(solver); RewritePatternSet patterns(ctx); populateIntRangeOptimizationsPatterns(patterns, solver); GreedyRewriteConfig config; config.listener = &listener; if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) signalPassFailure(); } }; } // namespace void mlir::arith::populateIntRangeOptimizationsPatterns( RewritePatternSet &patterns, DataFlowSolver &solver) { patterns.add, DeleteTrivialRem>(patterns.getContext(), solver); } std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { return std::make_unique(); }