[BPF] Undo transformation for LICM.cpp:hoistMinMax()

Extended BPFCheckAndAdjustIR pass with sinkMinMax() transformation
that undoes LICM hoistMinMax pass.

The undo transformation converts the following patterns:

    x < min(a, b) -> x < a && x < b
    x > min(a, b) -> x > a || x > b
    x < max(a, b) -> x < a || x < b
    x > max(a, b) -> x > a && x > b

Where 'a' or 'b' is a constant.
Also supports `sext min(...) ...` and `zext min(...) ...`.

Differential Revision: https://reviews.llvm.org/D147990
This commit is contained in:
Eduard Zingerman
2023-04-11 06:31:27 +03:00
parent 99b3b8e5b9
commit 09feee559a
2 changed files with 461 additions and 0 deletions

View File

@@ -18,8 +18,10 @@
#include "BPF.h"
#include "BPFCORE.h"
#include "BPFTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -41,12 +43,14 @@ class BPFCheckAndAdjustIR final : public ModulePass {
public:
static char ID;
BPFCheckAndAdjustIR() : ModulePass(ID) {}
virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
private:
void checkIR(Module &M);
bool adjustIR(Module &M);
bool removePassThroughBuiltin(Module &M);
bool removeCompareBuiltin(Module &M);
bool sinkMinMax(Module &M);
};
} // End anonymous namespace
@@ -161,9 +165,208 @@ bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
return Changed;
}
struct MinMaxSinkInfo {
ICmpInst *ICmp;
Value *Other;
ICmpInst::Predicate Predicate;
CallInst *MinMax;
ZExtInst *ZExt;
SExtInst *SExt;
MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
: ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
ZExt(nullptr), SExt(nullptr) {}
};
static bool sinkMinMaxInBB(BasicBlock &BB,
const std::function<bool(Instruction *)> &Filter) {
// Check if V is:
// (fn %a %b) or (ext (fn %a %b))
// Where:
// ext := sext | zext
// fn := smin | umin | smax | umax
auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
V = ZExt->getOperand(0);
Info.ZExt = ZExt;
} else if (auto *SExt = dyn_cast<SExtInst>(V)) {
V = SExt->getOperand(0);
Info.SExt = SExt;
}
auto *Call = dyn_cast<CallInst>(V);
if (!Call)
return false;
auto *Called = dyn_cast<Function>(Call->getCalledOperand());
if (!Called)
return false;
switch (Called->getIntrinsicID()) {
case Intrinsic::smin:
case Intrinsic::umin:
case Intrinsic::smax:
case Intrinsic::umax:
break;
default:
return false;
}
if (!Filter(Call))
return false;
Info.MinMax = Call;
return true;
};
auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
MinMaxSinkInfo &Info) {
if (Info.SExt) {
if (Info.SExt->getType() == V->getType())
return V;
return Builder.CreateSExt(V, Info.SExt->getType());
}
if (Info.ZExt) {
if (Info.ZExt->getType() == V->getType())
return V;
return Builder.CreateZExt(V, Info.ZExt->getType());
}
return V;
};
bool Changed = false;
SmallVector<MinMaxSinkInfo, 2> SinkList;
// Check BB for instructions like:
// insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
//
// Where:
// fn := min | max | (sext (min ...)) | (sext (max ...))
//
// Put such instructions to SinkList.
for (Instruction &I : BB) {
ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
if (!ICmp)
continue;
if (!ICmp->isRelational())
continue;
MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
if (!(FirstMinMax ^ SecondMinMax))
continue;
SinkList.push_back(FirstMinMax ? First : Second);
}
// Iterate SinkList and replace each (icmp ...) with corresponding
// `x < a && x < b` or similar expression.
for (auto &Info : SinkList) {
ICmpInst *ICmp = Info.ICmp;
CallInst *MinMax = Info.MinMax;
Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
ICmpInst::Predicate P = Info.Predicate;
if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
IID != Intrinsic::smax)
continue;
IRBuilder<> Builder(ICmp);
Value *X = Info.Other;
Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
assert(IsMin ^ IsMax);
assert(IsLess ^ IsGreater);
Value *Replacement;
Value *LHS = Builder.CreateICmp(P, X, A);
Value *RHS = Builder.CreateICmp(P, X, B);
if ((IsLess && IsMin) || (IsGreater && IsMax))
// x < min(a, b) -> x < a && x < b
// x > max(a, b) -> x > a && x > b
Replacement = Builder.CreateLogicalAnd(LHS, RHS);
else
// x > min(a, b) -> x > a || x > b
// x < max(a, b) -> x < a || x < b
Replacement = Builder.CreateLogicalOr(LHS, RHS);
ICmp->replaceAllUsesWith(Replacement);
Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
for (Instruction *I : ToRemove)
if (I && I->use_empty()) {
I->dropAllReferences();
I->removeFromParent();
}
Changed = true;
}
return Changed;
}
// Do the following transformation:
//
// x < min(a, b) -> x < a && x < b
// x > min(a, b) -> x > a || x > b
// x < max(a, b) -> x < a || x < b
// x > max(a, b) -> x > a && x > b
//
// Such patterns are introduced by LICM.cpp:hoistMinMax()
// transformation and might lead to BPF verification failures for
// older kernels.
//
// To minimize "collateral" changes only do it for icmp + min/max
// calls when icmp is inside a loop and min/max is outside of that
// loop.
//
// Verification failure happens when:
// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
// - verifier can recognize RHS as a constant scalar in some context;
// - verifier can't recognize RHS1 as a constant scalar in the same
// context;
//
// The "constant scalar" is not a compile time constant, but a register
// that holds a scalar value known to verifier at some point in time
// during abstract interpretation.
//
// See also:
// https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
bool Changed = false;
for (Function &F : M) {
if (F.isDeclaration())
continue;
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
for (Loop *L : LI)
for (BasicBlock *BB : L->blocks()) {
// Filter out instructions coming from the same loop
Loop *BBLoop = LI.getLoopFor(BB);
auto OtherLoopFilter = [&](Instruction *I) {
return LI.getLoopFor(I->getParent()) != BBLoop;
};
Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
}
}
return Changed;
}
void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<LoopInfoWrapperPass>();
}
bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
bool Changed = removePassThroughBuiltin(M);
Changed = removeCompareBuiltin(M) || Changed;
Changed = sinkMinMax(M) || Changed;
return Changed;
}

View File

@@ -0,0 +1,258 @@
; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s | FileCheck %s
; Test plan:
; @test1: x < umin(i64 a, i64 b)
; @test2: x < umax(i64 a, i64 b)
; @test3: x >= umin(i64 a, i64 b)
; @test4: x >= umax(i64 a, i64 b)
; @test5: umin(i64 a, i64 b) >= x
; @test6: x < smin(i64 a, i64 b)
; @test7: x < umin(i32 a, i32 b)
; @test8: x < zext i64 umin(i32 a, i32 b)
; @test9: x < sext i64 umin(i32 a, i32 b)
; @test10: check that umin belonging to the same loop is not touched
; @test11: check that nested loops are processed
define i32 @test1(i64 %a, i64 %b, i64 %x) {
entry:
%min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp ult i64 %x, %min
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test1
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = icmp ult i64 %x, %a
; CHECK-NEXT: %1 = icmp ult i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test2(i64 %a, i64 %b, i64 %x) {
entry:
%max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp ult i64 %x, %max
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test2
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = icmp ult i64 %x, %a
; CHECK-NEXT: %1 = icmp ult i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test3(i64 %a, i64 %b, i64 %x) {
entry:
%min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp uge i64 %x, %min
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test3
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = icmp uge i64 %x, %a
; CHECK-NEXT: %1 = icmp uge i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test4(i64 %a, i64 %b, i64 %x) {
entry:
%max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp uge i64 %x, %max
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test4
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = icmp uge i64 %x, %a
; CHECK-NEXT: %1 = icmp uge i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test5(i64 %a, i64 %b, i64 %x) {
entry:
%min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp uge i64 %min, %x
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test5
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK: %0 = icmp ule i64 %x, %a
; CHECK-NEXT: %1 = icmp ule i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test6(i64 %a, i64 %b, i64 %x) {
entry:
%min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b)
br label %loop
loop:
%cmp = icmp slt i64 %x, %min
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test6
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK: %0 = icmp slt i64 %x, %a
; CHECK-NEXT: %1 = icmp slt i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test7(i32 %a, i32 %b, i32 %x) {
entry:
%min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
br label %loop
loop:
%cmp = icmp ult i32 %x, %min
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test7
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK: %0 = icmp ult i32 %x, %a
; CHECK-NEXT: %1 = icmp ult i32 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %loop, label %ret
define i32 @test8(i32 %a, i32 %b, i64 %x) {
entry:
%min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
br label %loop
loop:
%ext = zext i32 %min to i64
%cmp = icmp ult i64 %x, %ext
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test8
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = zext i32 %a to i64
; CHECK-NEXT: %1 = zext i32 %b to i64
; CHECK-NEXT: %2 = icmp ult i64 %x, %0
; CHECK-NEXT: %3 = icmp ult i64 %x, %1
; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
; CHECK-NEXT: br i1 %4, label %loop, label %ret
define i32 @test9(i32 %a, i32 %b, i64 %x) {
entry:
%min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
br label %loop
loop:
%ext = sext i32 %min to i64
%cmp = icmp ult i64 %x, %ext
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test9
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %0 = sext i32 %a to i64
; CHECK-NEXT: %1 = sext i32 %b to i64
; CHECK-NEXT: %2 = icmp ult i64 %x, %0
; CHECK-NEXT: %3 = icmp ult i64 %x, %1
; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
; CHECK-NEXT: br i1 %4, label %loop, label %ret
; umin within the loop body is unchanged
define i32 @test10(i64 %a, i64 %b, i64 %x) {
entry:
br label %loop
loop:
%min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
%cmp = icmp ult i64 %x, %min
br i1 %cmp, label %loop, label %ret
ret: ret i32 0
}
; CHECK: @test10
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
; CHECK-NEXT: br i1 %cmp, label %loop, label %ret
; umin from outer loop body is processed
define i32 @test11(i64 %a, i64 %b, i64 %x) {
entry:
br label %loop
loop:
%min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
br label %nested.loop
nested.loop:
%cmp = icmp ult i64 %x, %min
br i1 %cmp, label %nested.loop, label %loop
ret: ret i32 0
}
; CHECK: @test11
; CHECK-NEXT: entry:
; CHECK-NEXT: br label %loop
; CHECK-EMPTY:
; CHECK-NEXT: loop:
; CHECK-NEXT: br label %nested.loop
; CHECK-EMPTY:
; CHECK-NEXT: nested.loop:
; CHECK-NEXT: %0 = icmp ult i64 %x, %a
; CHECK-NEXT: %1 = icmp ult i64 %x, %b
; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
; CHECK-NEXT: br i1 %2, label %nested.loop, label %loop
declare i64 @llvm.umin.i64(i64, i64)
declare i64 @llvm.smin.i64(i64, i64)
declare i64 @llvm.umax.i64(i64, i64)
declare i64 @llvm.smax.i64(i64, i64)
declare i32 @llvm.umin.i32(i32, i32)
declare i32 @llvm.smin.i32(i32, i32)
declare i32 @llvm.umax.i32(i32, i32)
declare i32 @llvm.smax.i32(i32, i32)