[mlir][Arith] Pass to switch signed ops for equivalent unsigned ones

If all the arguments to and results of an operation are known to be
non-negative when interpreted as signed (which also implies that all
computations producing those values did not experience signed
overflow), we can replace that operation with an equivalent one that
operates on unsigned values.
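
For example (a minimal sketch, not taken from this patch's tests; %x, %c16,
and %c3 are placeholder values, and the bounds on %v would be established by
integer range analysis):

  %v = arith.remui %x, %c16 : i32   // %v is provably in [0, 15]
  %q = arith.divsi %v, %c3 : i32    // can be rewritten to arith.divui %v, %c3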

Such a replacement, when it is possible, can provide useful hints to
backends, such as by allowing LLVM to replace remainder with bitwise
operations in more cases.

Depends on D124022

Depends on D124023

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124024
Author: Krzysztof Drewniak
Date: 2022-04-14 22:51:23 +00:00
parent 657e954939
commit b0b0043209

5 changed files with 255 additions and 0 deletions


@@ -26,6 +26,10 @@ void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
/// Create a pass to legalize Arithmetic ops for LLVM lowering.
std::unique_ptr<Pass> createArithmeticExpandOpsPass();
/// Create a pass to replace signed ops with unsigned ones where they are proven
/// equivalent.
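/// A typical use (sketch): add it to a pipeline with
/// passManager.addPass(arith::createArithmeticUnsignedWhenEquivalentPass());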
std::unique_ptr<Pass> createArithmeticUnsignedWhenEquivalentPass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//


@@ -33,4 +33,20 @@ def ArithmeticExpandOps : Pass<"arith-expand"> {
let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
}
def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
let summary = "Replace signed ops with unsigned ones where they are proven equivalent";
let description = [{
Replace signed ops with their unsigned equivalents when integer range analysis
determines that their arguments and results are all guaranteed to be
non-negative when interpreted as signed integers. When this occurs,
we know that the semantics of the signed and unsigned operations are the same,
since they share the same behavior when their operands and results are in the
range [0, signed_max(type)].
The affected ops include division (signed, ceiling, and floor), remainder,
min, max, sign extension, and signed integer comparisons.
}];
let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
}
#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES


@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ExpandOps.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
@@ -10,9 +11,11 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
MLIRArithmeticTransformsIncGen
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRPass


@@ -0,0 +1,144 @@
//===- UnsignedWhenEquivalent.cpp - Replace signed ops with unsigned -----===//
//
// Pass to replace signed operations with their unsigned equivalents when all
// their arguments and results are statically known to be non-negative.
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::arith;
using OpList = llvm::SmallVector<Operation *>;
/// Returns true when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
Optional<ConstantIntRanges> result = analysis.getResult(v);
if (!result.hasValue())
return false;
const ConstantIntRanges &range = result.getValue();
return (range.smin().isNonNegative());
}
/// Identify all operations in a block that have unsigned equivalents and whose
/// operands and results are statically non-negative.
template <typename... Ts>
static void getConvertableOps(Operation *root, OpList &toRewrite,
IntRangeAnalysis &analysis) {
auto nonNegativePred = [&analysis](Value v) -> bool {
return staticallyNonNegative(analysis, v);
};
root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
if (isa<Ts...>(orig) &&
llvm::all_of(orig->getOperands(), nonNegativePred) &&
llvm::all_of(orig->getResults(), nonNegativePred)) {
toRewrite.push_back(orig);
}
});
}
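/// Return the unsigned variant of a signed comparison predicate; predicates
/// that do not have a signed form are returned unchanged.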
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
switch (pred) {
case CmpIPredicate::sle:
return CmpIPredicate::ule;
case CmpIPredicate::slt:
return CmpIPredicate::ult;
case CmpIPredicate::sge:
return CmpIPredicate::uge;
case CmpIPredicate::sgt:
return CmpIPredicate::ugt;
default:
return pred;
}
}
/// Find all cmpi ops that can be replaced by their unsigned equivalents.
static void getConvertableCmpi(Operation *root, OpList &toRewrite,
IntRangeAnalysis &analysis) {
auto nonNegativePred = [&analysis](Value v) -> bool {
return staticallyNonNegative(analysis, v);
};
root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
CmpIPredicate pred = orig.getPredicate();
if (toUnsignedPred(pred) != pred &&
// i1 will spuriously and trivially show up as potentially negative,
// so don't check the results
llvm::all_of(orig->getOperands(), nonNegativePred)) {
toRewrite.push_back(orig.getOperation());
}
});
}
/// Return ops to be replaced in the order they should be rewritten.
static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
OpList ret;
getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
MaxSIOp, ExtSIOp>(root, ret, analysis);
// Since these are in-place changes, they don't need to be in topological
// order like the others.
getConvertableCmpi(root, ret, analysis);
return ret;
}
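/// If `op` is an instance of the signed op T, build the unsigned op U in its
/// place from the same operands, result types, and attributes, redirect all
/// uses to the new op, and erase the original.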
template <typename T, typename U>
static void rewriteOp(Operation *op, OpBuilder &b) {
if (isa<T>(op)) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);
Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
op->getOperands(), op->getAttrs());
op->replaceAllUsesWith(newOp->getResults());
op->erase();
}
}
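/// If `op` is a cmpi, switch its predicate to the unsigned equivalent in
/// place; no new operation is created.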
static void rewriteCmpI(Operation *op, OpBuilder &b) {
if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
}
}
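/// Apply the rewrites collected by getMatching(); each replacement op is
/// created at the position of the op it replaces.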
static void rewrite(Operation *root, const OpList &toReplace) {
OpBuilder b(root->getContext());
b.setInsertionPoint(root);
for (Operation *op : toReplace) {
rewriteOp<DivSIOp, DivUIOp>(op, b);
rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
rewriteOp<RemSIOp, RemUIOp>(op, b);
rewriteOp<MinSIOp, MinUIOp>(op, b);
rewriteOp<MaxSIOp, MaxUIOp>(op, b);
rewriteOp<ExtSIOp, ExtUIOp>(op, b);
rewriteCmpI(op, b);
}
}
namespace {
struct ArithmeticUnsignedWhenEquivalentPass
: public ArithmeticUnsignedWhenEquivalentBase<
ArithmeticUnsignedWhenEquivalentPass> {
/// Implementation structure: first find all equivalent ops and collect them,
/// then perform all the rewrites in a second pass over the target op. This
/// ensures that analysis results are not invalidated during rewriting.
void runOnOperation() override {
Operation *op = getOperation();
IntRangeAnalysis analysis(op);
rewrite(op, getMatching(op, analysis));
}
};
} // end anonymous namespace
std::unique_ptr<Pass>
mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
}


@@ -0,0 +1,88 @@
// RUN: mlir-opt -arith-unsigned-when-equivalent %s | FileCheck %s
// CHECK-LABEL: func @not_with_maybe_overflow
// CHECK: arith.divsi
// CHECK: arith.ceildivsi
// CHECK: arith.floordivsi
// CHECK: arith.remsi
// CHECK: arith.minsi
// CHECK: arith.maxsi
// CHECK: arith.extsi
// CHECK: arith.cmpi sle
// CHECK: arith.cmpi slt
// CHECK: arith.cmpi sge
// CHECK: arith.cmpi sgt
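// Note: %0 below may equal the signed maximum, so %1 = %0 + 1 can wrap around
// to the smallest signed value; no op involving %1 can be proven non-negative
// and everything stays signed.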
func.func @not_with_maybe_overflow(%arg0 : i32) {
%ci32_smax = arith.constant 0x7fffffff : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
%0 = arith.minui %arg0, %ci32_smax : i32
%1 = arith.addi %0, %c1 : i32
%2 = arith.divsi %1, %c4 : i32
%3 = arith.ceildivsi %1, %c4 : i32
%4 = arith.floordivsi %1, %c4 : i32
%5 = arith.remsi %1, %c4 : i32
%6 = arith.minsi %1, %c4 : i32
%7 = arith.maxsi %1, %c4 : i32
%8 = arith.extsi %1 : i32 to i64
%9 = arith.cmpi sle, %1, %c4 : i32
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
func.return
}
// CHECK-LABEL: func @yes_with_no_overflow
// CHECK: arith.divui
// CHECK: arith.ceildivui
// CHECK: arith.divui
// CHECK: arith.remui
// CHECK: arith.minui
// CHECK: arith.maxui
// CHECK: arith.extui
// CHECK: arith.cmpi ule
// CHECK: arith.cmpi ult
// CHECK: arith.cmpi uge
// CHECK: arith.cmpi ugt
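// Note: %0 below is at most signed-max - 1, so %1 = %0 + 1 cannot overflow and
// is provably non-negative; every signed op on %1 is rewritten to its unsigned
// equivalent.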
func.func @yes_with_no_overflow(%arg0 : i32) {
%ci32_almost_smax = arith.constant 0x7ffffffe : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
%0 = arith.minui %arg0, %ci32_almost_smax : i32
%1 = arith.addi %0, %c1 : i32
%2 = arith.divsi %1, %c4 : i32
%3 = arith.ceildivsi %1, %c4 : i32
%4 = arith.floordivsi %1, %c4 : i32
%5 = arith.remsi %1, %c4 : i32
%6 = arith.minsi %1, %c4 : i32
%7 = arith.maxsi %1, %c4 : i32
%8 = arith.extsi %1 : i32 to i64
%9 = arith.cmpi sle, %1, %c4 : i32
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
func.return
}
// CHECK-LABEL: func @preserves_structure
// CHECK: scf.for %[[arg1:.*]] =
// CHECK: %[[v:.*]] = arith.remui %[[arg1]]
// CHECK: %[[w:.*]] = arith.addi %[[v]], %[[v]]
// CHECK: %[[test:.*]] = arith.cmpi ule, %[[w]]
// CHECK: scf.if %[[test]]
// CHECK: memref.store %[[w]]
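// Checks that rewrites happen at the positions of the original ops, so the
// surrounding scf.for / scf.if structure is preserved.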
func.func @preserves_structure(%arg0 : memref<8xindex>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
scf.for %arg1 = %c0 to %c8 step %c1 {
%v = arith.remsi %arg1, %c4 : index
%w = arith.addi %v, %v : index
%test = arith.cmpi sle, %w, %c4 : index
scf.if %test {
memref.store %w, %arg0[%arg1] : memref<8xindex>
}
}
func.return
}