[mlir][Arith] Generalize and improve -int-range-optimizations (#94712)

When the integer range analysis was first develop, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
This commit is contained in:
Krzysztof Drewniak
2024-06-10 07:56:33 -07:00
committed by GitHub
parent 3e39328b62
commit 472291111d
10 changed files with 184 additions and 257 deletions

View File

@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
let summary = "Do optimizations based on integer range analysis";
let description = [{
This pass runs integer range analysis and apllies optimizations based on its
results. e.g. replace arith.cmpi with const if it can be inferred from
args ranges.
results. It replaces operations with known-constant results with said constants,
rewrites `(0 <= %x < D) mod D` to `%x`.
}];
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
"::mlir::arith::ArithDialect"
];
}
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {

View File

@@ -8,11 +8,17 @@
#include <utility>
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::arith {
@@ -24,88 +30,50 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
/// Returns true if 2 integer ranges have intersection.
static bool intersects(const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs) {
return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
(lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return std::nullopt;
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
return inferredRange.getConstantValue();
}
static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return false;
/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
PatternRewriter &rewriter,
Value value) {
if (value.use_empty())
return failure();
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
if (!maybeConstValue.has_value())
return failure();
return failure();
}
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr =
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
Operation *constOp = valueDialect->materializeConstant(
rewriter, constAttr, value.getType(), value.getLoc());
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constOp)
constOp = rewriter.getContext()
->getLoadedDialect<ArithDialect>()
->materializeConstant(rewriter, constAttr, value.getType(),
value.getLoc());
if (!constOp)
return failure();
static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return true;
return failure();
}
static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().slt(rhs.smin()))
return true;
if (lhs.smin().sge(rhs.smax()))
return false;
return failure();
}
static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().sle(rhs.smin()))
return true;
if (lhs.smin().sgt(rhs.smax()))
return false;
return failure();
}
static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSlt(std::move(rhs), std::move(lhs));
}
static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSle(std::move(rhs), std::move(lhs));
}
static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ult(rhs.umin()))
return true;
if (lhs.umin().uge(rhs.umax()))
return false;
return failure();
}
static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ule(rhs.umin()))
return true;
if (lhs.umin().ugt(rhs.umax()))
return false;
return failure();
}
static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUlt(std::move(rhs), std::move(lhs));
}
static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUle(std::move(rhs), std::move(lhs));
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
return success();
}
namespace {
/// This class listens on IR transformations performed during a pass relying on
/// information from a `DataflowSolver`. It erases state associated with the
/// erased operation and its results from the `DataFlowSolver` so that Patterns
/// do not accidentally query old state information for newly created Ops.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}
@@ -120,52 +88,95 @@ protected:
DataFlowSolver &s;
};
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
/// Rewrite any results of `op` that were inferred to be constant integers to
/// and replace their uses with that constant. Return success() if all results
/// where thus replaced and the operation is erased. Also replace any block
/// arguments with their constant values.
struct MaterializeKnownConstantValues : public RewritePattern {
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
: RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
solver(s) {}
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
LogicalResult match(Operation *op) const override {
if (matchPattern(op, m_Constant()))
return failure();
LogicalResult matchAndRewrite(arith::CmpIOp op,
auto needsReplacing = [&](Value v) {
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
};
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
if (op->getNumRegions() == 0)
return success(hasConstantResults);
bool hasConstantRegionArgs = false;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
hasConstantRegionArgs |=
llvm::any_of(block.getArguments(), needsReplacing);
}
}
return success(hasConstantResults || hasConstantRegionArgs);
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
bool replacedAll = (op->getNumResults() != 0);
for (Value v : op->getResults())
replacedAll &=
(succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
v.use_empty());
if (replacedAll && isOpTriviallyDead(op)) {
rewriter.eraseOp(op);
return;
}
PatternRewriter::InsertionGuard guard(rewriter);
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
rewriter.setInsertionPointToStart(&block);
for (BlockArgument &arg : block.getArguments()) {
(void)maybeReplaceWithConstant(solver, rewriter, arg);
}
}
}
}
private:
DataFlowSolver &solver;
};
template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<RemOp>(context), solver(s) {}
LogicalResult matchAndRewrite(RemOp op,
PatternRewriter &rewriter) const override {
auto *lhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
if (!lhsResult || lhsResult->getValue().isUninitialized())
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
auto maybeModulus = getConstantIntValue(rhs);
if (!maybeModulus.has_value())
return failure();
int64_t modulus = *maybeModulus;
if (modulus <= 0)
return failure();
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
// The minima and maxima here are given as closed ranges, we must be
// strictly less than the modulus.
if (min.isNegative() || min.uge(modulus))
return failure();
if (max.isNegative() || max.uge(modulus))
return failure();
if (!min.ule(max))
return failure();
auto *rhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
if (!rhsResult || rhsResult->getValue().isUninitialized())
return failure();
using HandlerFunc =
FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
handlers{};
using Pred = arith::CmpIPredicate;
handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
if (!handler)
return failure();
ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
FailureOr<bool> result = handler(lhsValue, rhsValue);
if (failed(result))
return failure();
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
op, static_cast<int64_t>(*result), /*width*/ 1);
// With all those conditions out of the way, we know thas this invocation of
// a remainder is a noop because the input is strictly within the range
// [0, modulus), so get rid of it.
rewriter.replaceOp(op, ValueRange{lhs});
return success();
}
@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
void mlir::arith::populateIntRangeOptimizationsPatterns(
RewritePatternSet &patterns, DataFlowSolver &solver) {
patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @add_min_max
// CHECK: %[[c3:.*]] = arith.constant 3 : index

View File

@@ -96,3 +96,39 @@ func.func @test() -> i8 {
return %1: i8
}
// -----
// CHECK-LABEL: func @trivial_rem
// CHECK: [[val:%.+]] = test.with_bounds
// CHECK: return [[val]]
func.func @trivial_rem() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
// -----
// CHECK-LABEL: func @non_const_rhs
// CHECK: [[mod:%.+]] = arith.remui
// CHECK: return [[mod]]
func.func @non_const_rhs() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
%rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
%mod = arith.remui %val, %rhs : i8
return %mod : i8
}
// -----
// CHECK-LABEL: func @wraps
// CHECK: [[mod:%.+]] = arith.remsi
// CHECK: return [[mod]]
func.func @wraps() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @launch_func
func.func @launch_func(%arg0 : index) {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// Most operations are covered by the `arith` tests, which use the same code
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
%false = arith.constant false
// CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%false) %0 : index
} do {
^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
// CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%arg0) %0 : index
} do {
^bb0(%i1: index):

View File

@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
${MLIRTestTransformsPDLSrc}

View File

@@ -1,125 +0,0 @@
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// TODO: This pass is needed to test integer range inference until that
// functionality has been integrated into SCCP.
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/FoldUtils.h"
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
/// Patterned after SCCP
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
OperationFolder &folder, Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
if (!maybeConstValue.has_value())
return failure();
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
Value constant = folder.getOrCreateConstant(
b.getInsertionBlock(), valueDialect, constAttr, value.getType());
if (!constant)
return failure();
value.replaceAllUsesWith(constant);
return success();
}
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
for (Region &region : regions)
for (Block &block : llvm::reverse(region))
worklist.push_back(&block);
};
OpBuilder builder(context);
OperationFolder folder(context);
addToWorklist(initialRegions);
while (!worklist.empty()) {
Block *block = worklist.pop_back_val();
for (Operation &op : llvm::make_early_inc_range(*block)) {
builder.setInsertionPoint(&op);
// Replace any result with constants.
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
succeeded(replaceWithConstant(solver, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
assert(op.use_empty() && "expected all uses to be replaced");
op.erase();
continue;
}
// Add any the regions of this operation to the worklist.
addToWorklist(op.getRegions());
}
// Replace any block arguments with constants.
builder.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
(void)replaceWithConstant(solver, builder, folder, arg);
}
}
namespace {
struct TestIntRangeInference
: PassWrapper<TestIntRangeInference, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
StringRef getArgument() const final { return "test-int-range-inference"; }
StringRef getDescription() const final {
return "Test integer range inference analysis";
}
void runOnOperation() override {
Operation *op = getOperation();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
rewrite(solver, op->getContext(), op->getRegions());
}
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestIntRangeInference() {
PassRegistration<TestIntRangeInference>();
}
} // end namespace test
} // end namespace mlir

View File

@@ -97,9 +97,11 @@ void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestInterfaces();
void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
@@ -226,9 +228,11 @@ void registerTestPasses() {
mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgDecomposeOps();