Files
clang-p2996/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
Mogball ab701975e7 [mlir] Swap integer range inference to the new framework
Integer range inference has been swapped to the new framework. The integer value range lattices automatically update the corresponding constant value whenever they change.

Depends on D127173

Reviewed By: krzysz00, rriddle

Differential Revision: https://reviews.llvm.org/D128866
2022-07-07 20:28:13 -07:00

153 lines
5.8 KiB
C++

//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
// unsigned ones when all their arguments and results are statically
// non-negative ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
/// Succeeds when a value is statically non-negative: the integer range
/// analysis has computed a range for `v` and, treating the value as signed,
/// the lower bound of that range is non-negative.
///
/// Fails when no range is known for `v`. This covers two cases: no lattice
/// state was ever created for it, or a state exists but was never initialized
/// (e.g. the analysis bailed out before assigning it a value).
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
  // A lattice can be created without ever being given a value. Reading the
  // value of such a lattice is not valid (Lattice::getValue expects a known
  // element), so treat it exactly like a missing state: nothing is known.
  if (!result || result->isUninitialized())
    return failure();
  const ConstantIntRanges &range = result->getValue().getValue();
  return success(range.smin().isNonNegative());
}
/// Succeeds if `op` can be rewritten to its unsigned counterpart without
/// changing its semantics. That holds when every operand and every result of
/// `op` is statically known to be non-negative when interpreted as signed.
/// Operands are checked before results; checking stops at the first value
/// not known to be non-negative.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
                                           Operation *op) {
  auto isKnownNonNegative = [&](Value value) -> bool {
    return succeeded(staticallyNonNegative(solver, value));
  };
  if (!llvm::all_of(op->getOperands(), isKnownNonNegative))
    return failure();
  if (!llvm::all_of(op->getResults(), isKnownNonNegative))
    return failure();
  return success();
}
/// Succeeds when `op` can have its predicate swapped for the unsigned
/// equivalent. Two conditions must hold:
///   * the predicate is one of the signed orderings (sle, slt, sge, sgt);
///   * every operand is statically known to be non-negative.
/// Any other predicate makes this fail immediately, without querying the
/// solver.
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
  switch (op.getPredicate()) {
  case CmpIPredicate::sle:
  case CmpIPredicate::slt:
  case CmpIPredicate::sge:
  case CmpIPredicate::sgt:
    break;
  default:
    return failure();
  }

  // Signed ordering: only convertible if no operand can be negative.
  for (Value operand : op.getOperands())
    if (failed(staticallyNonNegative(solver, operand)))
      return failure();
  return success();
}
/// Maps a signed ordering predicate to the unsigned predicate with the same
/// ordering: slt -> ult, sle -> ule, sgt -> ugt, sge -> uge. Any other
/// predicate is returned unchanged.
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
  CmpIPredicate unsignedPred = pred;
  switch (pred) {
  case CmpIPredicate::slt:
    unsignedPred = CmpIPredicate::ult;
    break;
  case CmpIPredicate::sle:
    unsignedPred = CmpIPredicate::ule;
    break;
  case CmpIPredicate::sgt:
    unsignedPred = CmpIPredicate::ugt;
    break;
  case CmpIPredicate::sge:
    unsignedPred = CmpIPredicate::uge;
    break;
  default:
    break;
  }
  return unsignedPred;
}
namespace {
/// Rewrites a `Signed` operation into its `Unsigned` counterpart, keeping the
/// result types, the (remapped) operands and the attributes. Whether the
/// rewrite is semantically safe is decided by the conversion target's
/// legality callbacks, not by this pattern.
template <typename Signed, typename Unsigned>
struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
  using OpConversionPattern<Signed>::OpConversionPattern;

  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
                                ConversionPatternRewriter &rw) const override {
    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
                                    adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

/// Rewrites an `arith.cmpi` with a signed ordering predicate into one that
/// uses the unsigned equivalent predicate. Like the other conversion patterns
/// here, operands come from the adaptor so that remapped values are used.
struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
                                ConversionPatternRewriter &rw) const override {
    rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
                                  adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};

struct ArithmeticUnsignedWhenEquivalentPass
    : public ArithmeticUnsignedWhenEquivalentBase<
          ArithmeticUnsignedWhenEquivalentPass> {
  /// Implementation structure:
  ///   1. Run dead-code and integer-range analysis to a fixpoint over the
  ///      whole target op.
  ///   2. Perform a partial dialect conversion whose dynamic legality
  ///      callbacks only query the finished analysis.
  /// The solver is never re-run during rewriting.
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    // Both analyses are loaded together. The sparse integer range analysis
    // is expected to rely on the liveness information from dead-code
    // analysis.
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    ConversionTarget target(*ctx);
    target.addLegalDialect<ArithmeticDialect>();
    // A signed op is illegal (i.e. must be rewritten) exactly when all of its
    // operands and results are known to be non-negative. Only ops that have a
    // rewrite pattern below may be listed here: an op declared illegal with
    // no pattern to legalize it makes the partial conversion fail.
    // `CeilDivUIOp` is already unsigned and must therefore not be listed.
    target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
                                 MinSIOp, MaxSIOp, ExtSIOp>(
        [&solver](Operation *op) -> Optional<bool> {
          return failed(staticallyNonNegative(solver, op));
        });
    target.addDynamicallyLegalOp<CmpIOp>(
        [&solver](CmpIOp op) -> Optional<bool> {
          return failed(isCmpIConvertable(solver, op));
        });

    RewritePatternSet patterns(ctx);
    patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
                 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
                 // For non-negative operands, floor division equals
                 // truncating division, so plain divui suffices.
                 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
                 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
                 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
                 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
                 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
        ctx);
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace
/// Creates the pass that rewrites signed arith operations into their unsigned
/// equivalents wherever the integer range analysis shows that all involved
/// values are non-negative.
std::unique_ptr<Pass>
mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
  std::unique_ptr<Pass> pass =
      std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
  return pass;
}