[mlir][Arith] Generalize and improve -int-range-optimizations (#94712)
When the integer range analysis was first develop, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.
Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.
This commit moves the former test pass into -int-range-optimizaitons,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
This commit is contained in:
committed by
GitHub
parent
3e39328b62
commit
472291111d
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
|
||||
let summary = "Do optimizations based on integer range analysis";
|
||||
let description = [{
|
||||
This pass runs integer range analysis and apllies optimizations based on its
|
||||
results. e.g. replace arith.cmpi with const if it can be inferred from
|
||||
args ranges.
|
||||
results. It replaces operations with known-constant results with said constants,
|
||||
rewrites `(0 <= %x < D) mod D` to `%x`.
|
||||
}];
|
||||
// Explicitly depend on "arith" because this pass could create operations in
|
||||
// `arith` out of thin air in some cases.
|
||||
let dependentDialects = [
|
||||
"::mlir::arith::ArithDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
|
||||
|
||||
@@ -8,11 +8,17 @@
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir::arith {
|
||||
@@ -24,88 +30,50 @@ using namespace mlir;
|
||||
using namespace mlir::arith;
|
||||
using namespace mlir::dataflow;
|
||||
|
||||
/// Returns true if 2 integer ranges have intersection.
|
||||
static bool intersects(const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs) {
|
||||
return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
|
||||
(lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
|
||||
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
|
||||
Value value) {
|
||||
auto *maybeInferredRange =
|
||||
solver.lookupState<IntegerValueRangeLattice>(value);
|
||||
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
|
||||
return std::nullopt;
|
||||
const ConstantIntRanges &inferredRange =
|
||||
maybeInferredRange->getValue().getValue();
|
||||
return inferredRange.getConstantValue();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (!intersects(lhs, rhs))
|
||||
return false;
|
||||
/// Patterned after SCCP
|
||||
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
|
||||
PatternRewriter &rewriter,
|
||||
Value value) {
|
||||
if (value.use_empty())
|
||||
return failure();
|
||||
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
|
||||
if (!maybeConstValue.has_value())
|
||||
return failure();
|
||||
|
||||
return failure();
|
||||
}
|
||||
Operation *maybeDefiningOp = value.getDefiningOp();
|
||||
Dialect *valueDialect =
|
||||
maybeDefiningOp ? maybeDefiningOp->getDialect()
|
||||
: value.getParentRegion()->getParentOp()->getDialect();
|
||||
Attribute constAttr =
|
||||
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
|
||||
Operation *constOp = valueDialect->materializeConstant(
|
||||
rewriter, constAttr, value.getType(), value.getLoc());
|
||||
// Fall back to arith.constant if the dialect materializer doesn't know what
|
||||
// to do with an integer constant.
|
||||
if (!constOp)
|
||||
constOp = rewriter.getContext()
|
||||
->getLoadedDialect<ArithDialect>()
|
||||
->materializeConstant(rewriter, constAttr, value.getType(),
|
||||
value.getLoc());
|
||||
if (!constOp)
|
||||
return failure();
|
||||
|
||||
static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (!intersects(lhs, rhs))
|
||||
return true;
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (lhs.smax().slt(rhs.smin()))
|
||||
return true;
|
||||
|
||||
if (lhs.smin().sge(rhs.smax()))
|
||||
return false;
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (lhs.smax().sle(rhs.smin()))
|
||||
return true;
|
||||
|
||||
if (lhs.smin().sgt(rhs.smax()))
|
||||
return false;
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
return handleSlt(std::move(rhs), std::move(lhs));
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
return handleSle(std::move(rhs), std::move(lhs));
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (lhs.umax().ult(rhs.umin()))
|
||||
return true;
|
||||
|
||||
if (lhs.umin().uge(rhs.umax()))
|
||||
return false;
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
if (lhs.umax().ule(rhs.umin()))
|
||||
return true;
|
||||
|
||||
if (lhs.umin().ugt(rhs.umax()))
|
||||
return false;
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
return handleUlt(std::move(rhs), std::move(lhs));
|
||||
}
|
||||
|
||||
static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
|
||||
return handleUle(std::move(rhs), std::move(lhs));
|
||||
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This class listens on IR transformations performed during a pass relying on
|
||||
/// information from a `DataflowSolver`. It erases state associated with the
|
||||
/// erased operation and its results from the `DataFlowSolver` so that Patterns
|
||||
/// do not accidentally query old state information for newly created Ops.
|
||||
class DataFlowListener : public RewriterBase::Listener {
|
||||
public:
|
||||
DataFlowListener(DataFlowSolver &s) : s(s) {}
|
||||
@@ -120,52 +88,95 @@ protected:
|
||||
DataFlowSolver &s;
|
||||
};
|
||||
|
||||
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
|
||||
/// Rewrite any results of `op` that were inferred to be constant integers to
|
||||
/// and replace their uses with that constant. Return success() if all results
|
||||
/// where thus replaced and the operation is erased. Also replace any block
|
||||
/// arguments with their constant values.
|
||||
struct MaterializeKnownConstantValues : public RewritePattern {
|
||||
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
|
||||
: RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
|
||||
solver(s) {}
|
||||
|
||||
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
|
||||
: OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
|
||||
LogicalResult match(Operation *op) const override {
|
||||
if (matchPattern(op, m_Constant()))
|
||||
return failure();
|
||||
|
||||
LogicalResult matchAndRewrite(arith::CmpIOp op,
|
||||
auto needsReplacing = [&](Value v) {
|
||||
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
|
||||
};
|
||||
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
|
||||
if (op->getNumRegions() == 0)
|
||||
return success(hasConstantResults);
|
||||
bool hasConstantRegionArgs = false;
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
for (Block &block : region.getBlocks()) {
|
||||
hasConstantRegionArgs |=
|
||||
llvm::any_of(block.getArguments(), needsReplacing);
|
||||
}
|
||||
}
|
||||
return success(hasConstantResults || hasConstantRegionArgs);
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
bool replacedAll = (op->getNumResults() != 0);
|
||||
for (Value v : op->getResults())
|
||||
replacedAll &=
|
||||
(succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
|
||||
v.use_empty());
|
||||
if (replacedAll && isOpTriviallyDead(op)) {
|
||||
rewriter.eraseOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
for (Block &block : region.getBlocks()) {
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
for (BlockArgument &arg : block.getArguments()) {
|
||||
(void)maybeReplaceWithConstant(solver, rewriter, arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataFlowSolver &solver;
|
||||
};
|
||||
|
||||
template <typename RemOp>
|
||||
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
|
||||
DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
|
||||
: OpRewritePattern<RemOp>(context), solver(s) {}
|
||||
|
||||
LogicalResult matchAndRewrite(RemOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto *lhsResult =
|
||||
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
|
||||
if (!lhsResult || lhsResult->getValue().isUninitialized())
|
||||
Value lhs = op.getOperand(0);
|
||||
Value rhs = op.getOperand(1);
|
||||
auto maybeModulus = getConstantIntValue(rhs);
|
||||
if (!maybeModulus.has_value())
|
||||
return failure();
|
||||
int64_t modulus = *maybeModulus;
|
||||
if (modulus <= 0)
|
||||
return failure();
|
||||
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
|
||||
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
|
||||
return failure();
|
||||
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
|
||||
const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
|
||||
const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
|
||||
// The minima and maxima here are given as closed ranges, we must be
|
||||
// strictly less than the modulus.
|
||||
if (min.isNegative() || min.uge(modulus))
|
||||
return failure();
|
||||
if (max.isNegative() || max.uge(modulus))
|
||||
return failure();
|
||||
if (!min.ule(max))
|
||||
return failure();
|
||||
|
||||
auto *rhsResult =
|
||||
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
|
||||
if (!rhsResult || rhsResult->getValue().isUninitialized())
|
||||
return failure();
|
||||
|
||||
using HandlerFunc =
|
||||
FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
|
||||
std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
|
||||
handlers{};
|
||||
using Pred = arith::CmpIPredicate;
|
||||
handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
|
||||
handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
|
||||
handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
|
||||
handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
|
||||
handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
|
||||
handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
|
||||
handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
|
||||
handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
|
||||
handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
|
||||
handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
|
||||
|
||||
HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
|
||||
if (!handler)
|
||||
return failure();
|
||||
|
||||
ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
|
||||
ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
|
||||
FailureOr<bool> result = handler(lhsValue, rhsValue);
|
||||
|
||||
if (failed(result))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
|
||||
op, static_cast<int64_t>(*result), /*width*/ 1);
|
||||
// With all those conditions out of the way, we know thas this invocation of
|
||||
// a remainder is a noop because the input is strictly within the range
|
||||
// [0, modulus), so get rid of it.
|
||||
rewriter.replaceOp(op, ValueRange{lhs});
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
|
||||
|
||||
void mlir::arith::populateIntRangeOptimizationsPatterns(
|
||||
RewritePatternSet &patterns, DataFlowSolver &solver) {
|
||||
patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
|
||||
patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
|
||||
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
|
||||
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_min_max
|
||||
// CHECK: %[[c3:.*]] = arith.constant 3 : index
|
||||
|
||||
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
|
||||
return %1: i8
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @trivial_rem
|
||||
// CHECK: [[val:%.+]] = test.with_bounds
|
||||
// CHECK: return [[val]]
|
||||
func.func @trivial_rem() -> i8 {
|
||||
%c64 = arith.constant 64 : i8
|
||||
%val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
|
||||
%mod = arith.remsi %val, %c64 : i8
|
||||
return %mod : i8
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @non_const_rhs
|
||||
// CHECK: [[mod:%.+]] = arith.remui
|
||||
// CHECK: return [[mod]]
|
||||
func.func @non_const_rhs() -> i8 {
|
||||
%c64 = arith.constant 64 : i8
|
||||
%val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
|
||||
%rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
|
||||
%mod = arith.remui %val, %rhs : i8
|
||||
return %mod : i8
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @wraps
|
||||
// CHECK: [[mod:%.+]] = arith.remsi
|
||||
// CHECK: return [[mod]]
|
||||
func.func @wraps() -> i8 {
|
||||
%c64 = arith.constant 64 : i8
|
||||
%val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
|
||||
%mod = arith.remsi %val, %c64 : i8
|
||||
return %mod : i8
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @launch_func
|
||||
func.func @launch_func(%arg0 : index) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
|
||||
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
|
||||
|
||||
// Most operations are covered by the `arith` tests, which use the same code
|
||||
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
|
||||
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @constant
|
||||
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
|
||||
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
|
||||
|
||||
// CHECK-LABEL: func @propagate_across_while_loop_false()
|
||||
func.func @propagate_across_while_loop_false() -> index {
|
||||
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
|
||||
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
|
||||
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
|
||||
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
|
||||
smin = 0 : index, smax = 0 : index } : index
|
||||
%1 = scf.while : () -> index {
|
||||
%false = arith.constant false
|
||||
// CHECK: scf.condition(%{{.*}}) %[[C0]]
|
||||
scf.condition(%false) %0 : index
|
||||
} do {
|
||||
^bb0(%i1: index):
|
||||
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
|
||||
|
||||
// CHECK-LABEL: func @propagate_across_while_loop
|
||||
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
|
||||
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
|
||||
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
|
||||
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
|
||||
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
|
||||
smin = 0 : index, smax = 0 : index } : index
|
||||
%1 = scf.while : () -> index {
|
||||
// CHECK: scf.condition(%{{.*}}) %[[C0]]
|
||||
scf.condition(%arg0) %0 : index
|
||||
} do {
|
||||
^bb0(%i1: index):
|
||||
|
||||
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
|
||||
TestConstantFold.cpp
|
||||
TestControlFlowSink.cpp
|
||||
TestInlining.cpp
|
||||
TestIntRangeInference.cpp
|
||||
TestMakeIsolatedFromAbove.cpp
|
||||
${MLIRTestTransformsPDLSrc}
|
||||
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: This pass is needed to test integer range inference until that
|
||||
// functionality has been integrated into SCCP.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::dataflow;
|
||||
|
||||
/// Patterned after SCCP
|
||||
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
|
||||
OperationFolder &folder, Value value) {
|
||||
auto *maybeInferredRange =
|
||||
solver.lookupState<IntegerValueRangeLattice>(value);
|
||||
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
|
||||
return failure();
|
||||
const ConstantIntRanges &inferredRange =
|
||||
maybeInferredRange->getValue().getValue();
|
||||
std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
|
||||
if (!maybeConstValue.has_value())
|
||||
return failure();
|
||||
|
||||
Operation *maybeDefiningOp = value.getDefiningOp();
|
||||
Dialect *valueDialect =
|
||||
maybeDefiningOp ? maybeDefiningOp->getDialect()
|
||||
: value.getParentRegion()->getParentOp()->getDialect();
|
||||
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
|
||||
Value constant = folder.getOrCreateConstant(
|
||||
b.getInsertionBlock(), valueDialect, constAttr, value.getType());
|
||||
if (!constant)
|
||||
return failure();
|
||||
|
||||
value.replaceAllUsesWith(constant);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
|
||||
MutableArrayRef<Region> initialRegions) {
|
||||
SmallVector<Block *> worklist;
|
||||
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
|
||||
for (Region ®ion : regions)
|
||||
for (Block &block : llvm::reverse(region))
|
||||
worklist.push_back(&block);
|
||||
};
|
||||
|
||||
OpBuilder builder(context);
|
||||
OperationFolder folder(context);
|
||||
|
||||
addToWorklist(initialRegions);
|
||||
while (!worklist.empty()) {
|
||||
Block *block = worklist.pop_back_val();
|
||||
|
||||
for (Operation &op : llvm::make_early_inc_range(*block)) {
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
// Replace any result with constants.
|
||||
bool replacedAll = op.getNumResults() != 0;
|
||||
for (Value res : op.getResults())
|
||||
replacedAll &=
|
||||
succeeded(replaceWithConstant(solver, builder, folder, res));
|
||||
|
||||
// If all of the results of the operation were replaced, try to erase
|
||||
// the operation completely.
|
||||
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
|
||||
assert(op.use_empty() && "expected all uses to be replaced");
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add any the regions of this operation to the worklist.
|
||||
addToWorklist(op.getRegions());
|
||||
}
|
||||
|
||||
// Replace any block arguments with constants.
|
||||
builder.setInsertionPointToStart(block);
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
(void)replaceWithConstant(solver, builder, folder, arg);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestIntRangeInference
|
||||
: PassWrapper<TestIntRangeInference, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
|
||||
|
||||
StringRef getArgument() const final { return "test-int-range-inference"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test integer range inference analysis";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
DataFlowSolver solver;
|
||||
solver.load<DeadCodeAnalysis>();
|
||||
solver.load<SparseConstantPropagation>();
|
||||
solver.load<IntegerRangeAnalysis>();
|
||||
if (failed(solver.initializeAndRun(op)))
|
||||
return signalPassFailure();
|
||||
rewrite(solver, op->getContext(), op->getRegions());
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestIntRangeInference() {
|
||||
PassRegistration<TestIntRangeInference>();
|
||||
}
|
||||
} // end namespace test
|
||||
} // end namespace mlir
|
||||
@@ -97,9 +97,11 @@ void registerTestDynamicPipelinePass();
|
||||
void registerTestEmulateNarrowTypePass();
|
||||
void registerTestExpandMathPass();
|
||||
void registerTestFooAnalysisPass();
|
||||
void registerTestComposeSubView();
|
||||
void registerTestMultiBuffering();
|
||||
void registerTestIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsPass();
|
||||
void registerTestInterfaces();
|
||||
void registerTestIntRangeInference();
|
||||
void registerTestIRVisitorsPass();
|
||||
void registerTestLastModifiedPass();
|
||||
void registerTestLinalgDecomposeOps();
|
||||
@@ -226,9 +228,11 @@ void registerTestPasses() {
|
||||
mlir::test::registerTestEmulateNarrowTypePass();
|
||||
mlir::test::registerTestExpandMathPass();
|
||||
mlir::test::registerTestFooAnalysisPass();
|
||||
mlir::test::registerTestComposeSubView();
|
||||
mlir::test::registerTestMultiBuffering();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
mlir::test::registerTestGenericIRVisitorsPass();
|
||||
mlir::test::registerTestInterfaces();
|
||||
mlir::test::registerTestIntRangeInference();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
mlir::test::registerTestLastModifiedPass();
|
||||
mlir::test::registerTestLinalgDecomposeOps();
|
||||
|
||||
Reference in New Issue
Block a user