When the integer range analysis was first developed, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such folding to SCCP, but that hasn't
happened.
Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.
This commit moves the former test pass into -int-range-optimizations,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are no-ops are replaced by
their left operands.
222 lines
7.8 KiB
C++
222 lines
7.8 KiB
C++
//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <utility>
|
|
|
|
#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
|
|
|
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir::arith {
|
|
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
|
|
} // namespace mlir::arith
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arith;
|
|
using namespace mlir::dataflow;
|
|
|
|
/// Returns the single integer that `solver` inferred `value` must hold, or
/// std::nullopt when there is no lattice state for `value`, the state is
/// still uninitialized, or the inferred range covers more than one integer.
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
                                                  Value value) {
  const auto *lattice = solver.lookupState<IntegerValueRangeLattice>(value);
  if (lattice == nullptr)
    return std::nullopt;

  const auto &rangeState = lattice->getValue();
  if (rangeState.isUninitialized())
    return std::nullopt;

  // Only a degenerate range [c, c] yields a value here.
  return rangeState.getValue().getConstantValue();
}
|
|
|
|
/// Patterned after SCCP
|
|
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
|
|
PatternRewriter &rewriter,
|
|
Value value) {
|
|
if (value.use_empty())
|
|
return failure();
|
|
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
|
|
if (!maybeConstValue.has_value())
|
|
return failure();
|
|
|
|
Operation *maybeDefiningOp = value.getDefiningOp();
|
|
Dialect *valueDialect =
|
|
maybeDefiningOp ? maybeDefiningOp->getDialect()
|
|
: value.getParentRegion()->getParentOp()->getDialect();
|
|
Attribute constAttr =
|
|
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
|
|
Operation *constOp = valueDialect->materializeConstant(
|
|
rewriter, constAttr, value.getType(), value.getLoc());
|
|
// Fall back to arith.constant if the dialect materializer doesn't know what
|
|
// to do with an integer constant.
|
|
if (!constOp)
|
|
constOp = rewriter.getContext()
|
|
->getLoadedDialect<ArithDialect>()
|
|
->materializeConstant(rewriter, constAttr, value.getType(),
|
|
value.getLoc());
|
|
if (!constOp)
|
|
return failure();
|
|
|
|
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
class DataFlowListener : public RewriterBase::Listener {
|
|
public:
|
|
DataFlowListener(DataFlowSolver &s) : s(s) {}
|
|
|
|
protected:
|
|
void notifyOperationErased(Operation *op) override {
|
|
s.eraseState(op);
|
|
for (Value res : op->getResults())
|
|
s.eraseState(res);
|
|
}
|
|
|
|
DataFlowSolver &s;
|
|
};
|
|
|
|
/// Rewrite any results of `op` that were inferred to be constant integers to
|
|
/// and replace their uses with that constant. Return success() if all results
|
|
/// where thus replaced and the operation is erased. Also replace any block
|
|
/// arguments with their constant values.
|
|
struct MaterializeKnownConstantValues : public RewritePattern {
|
|
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
|
|
: RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
|
|
solver(s) {}
|
|
|
|
LogicalResult match(Operation *op) const override {
|
|
if (matchPattern(op, m_Constant()))
|
|
return failure();
|
|
|
|
auto needsReplacing = [&](Value v) {
|
|
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
|
|
};
|
|
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
|
|
if (op->getNumRegions() == 0)
|
|
return success(hasConstantResults);
|
|
bool hasConstantRegionArgs = false;
|
|
for (Region ®ion : op->getRegions()) {
|
|
for (Block &block : region.getBlocks()) {
|
|
hasConstantRegionArgs |=
|
|
llvm::any_of(block.getArguments(), needsReplacing);
|
|
}
|
|
}
|
|
return success(hasConstantResults || hasConstantRegionArgs);
|
|
}
|
|
|
|
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
|
bool replacedAll = (op->getNumResults() != 0);
|
|
for (Value v : op->getResults())
|
|
replacedAll &=
|
|
(succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
|
|
v.use_empty());
|
|
if (replacedAll && isOpTriviallyDead(op)) {
|
|
rewriter.eraseOp(op);
|
|
return;
|
|
}
|
|
|
|
PatternRewriter::InsertionGuard guard(rewriter);
|
|
for (Region ®ion : op->getRegions()) {
|
|
for (Block &block : region.getBlocks()) {
|
|
rewriter.setInsertionPointToStart(&block);
|
|
for (BlockArgument &arg : block.getArguments()) {
|
|
(void)maybeReplaceWithConstant(solver, rewriter, arg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
DataFlowSolver &solver;
|
|
};
|
|
|
|
template <typename RemOp>
|
|
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
|
|
DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
|
|
: OpRewritePattern<RemOp>(context), solver(s) {}
|
|
|
|
LogicalResult matchAndRewrite(RemOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Value lhs = op.getOperand(0);
|
|
Value rhs = op.getOperand(1);
|
|
auto maybeModulus = getConstantIntValue(rhs);
|
|
if (!maybeModulus.has_value())
|
|
return failure();
|
|
int64_t modulus = *maybeModulus;
|
|
if (modulus <= 0)
|
|
return failure();
|
|
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
|
|
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
|
|
return failure();
|
|
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
|
|
const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
|
|
const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
|
|
// The minima and maxima here are given as closed ranges, we must be
|
|
// strictly less than the modulus.
|
|
if (min.isNegative() || min.uge(modulus))
|
|
return failure();
|
|
if (max.isNegative() || max.uge(modulus))
|
|
return failure();
|
|
if (!min.ule(max))
|
|
return failure();
|
|
|
|
// With all those conditions out of the way, we know thas this invocation of
|
|
// a remainder is a noop because the input is strictly within the range
|
|
// [0, modulus), so get rid of it.
|
|
rewriter.replaceOp(op, ValueRange{lhs});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
DataFlowSolver &solver;
|
|
};
|
|
|
|
struct IntRangeOptimizationsPass
|
|
: public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
MLIRContext *ctx = op->getContext();
|
|
DataFlowSolver solver;
|
|
solver.load<DeadCodeAnalysis>();
|
|
solver.load<IntegerRangeAnalysis>();
|
|
if (failed(solver.initializeAndRun(op)))
|
|
return signalPassFailure();
|
|
|
|
DataFlowListener listener(solver);
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
populateIntRangeOptimizationsPatterns(patterns, solver);
|
|
|
|
GreedyRewriteConfig config;
|
|
config.listener = &listener;
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds the integer-range-based patterns to `patterns`. Every pattern holds a
/// reference to `solver`, which must outlive the pattern set.
void mlir::arith::populateIntRangeOptimizationsPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver) {
  MLIRContext *context = patterns.getContext();
  // Replace values the solver proved constant with materialized constants.
  patterns.add<MaterializeKnownConstantValues>(context, solver);
  // Drop signed and unsigned remainders that provably do nothing.
  patterns.add<DeleteTrivialRem<RemSIOp>>(context, solver);
  patterns.add<DeleteTrivialRem<RemUIOp>>(context, solver);
}
|
|
|
|
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
|
|
return std::make_unique<IntRangeOptimizationsPass>();
|
|
}
|