[FuncSpec] Improve the accuracy of the cost model.

Instead of blindly traversing the use-def chain of constant arguments,
compute known constants along the way. Stop as soon as a user cannot
be replaced by a constant. Keep it light-weight by handling some basic
instruction types.

Differential Revision: https://reviews.llvm.org/D150464
This commit is contained in:
Alexandros Lamprineas
2023-05-12 00:07:49 +01:00
parent 4447b82b89
commit ced90d1ff6
4 changed files with 525 additions and 44 deletions

View File

@@ -52,6 +52,7 @@
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
@@ -69,6 +70,9 @@ using SpecMap = DenseMap<Function *, std::pair<unsigned, unsigned>>;
// Just a shorter abbreviation to improve indentation.
using Cost = InstructionCost;
// Map of known constants found during the specialization bonus estimation.
using ConstMap = DenseMap<Value *, Constant *>;
// Specialization signature, used to uniquely designate a specialization within
// a function.
struct SpecSig {
@@ -115,6 +119,39 @@ struct Spec {
: F(F), Sig(S), Score(Score) {}
};
// Instruction visitor used while estimating the bonus of a specialization.
// Every visit* handler returns the Constant the visited instruction would be
// replaced with, or nullptr when no such constant is found.
class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
const DataLayout &DL;
BlockFrequencyInfo &BFI;
TargetTransformInfo &TTI;
SCCPSolver &Solver;
// Values that have been associated with a constant so far. Entries are only
// ever added, so they persist across successive getUserBonus() calls made on
// the same visitor.
ConstMap KnownConstants;
// Entry of KnownConstants describing the (value, constant) pair whose user is
// currently being visited. It is set by getUserBonus() before dispatching.
ConstMap::iterator LastVisited;
public:
InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
TargetTransformInfo &TTI, SCCPSolver &Solver)
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
// Returns the bonus obtained at \p User (and, transitively, at its users)
// when its operand \p Use is replaced by the constant \p C.
Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
private:
// InstVisitor dispatches to the private visit* handlers below.
friend class InstVisitor<InstCostVisitor, Constant *>;
// Bonus of the basic blocks that become dead when the condition of the
// switch / branch is the last visited constant.
Cost estimateSwitchInst(SwitchInst &I);
Cost estimateBranchInst(BranchInst &I);
// Fallback for instruction kinds without a dedicated handler: no constant.
Constant *visitInstruction(Instruction &I) { return nullptr; }
Constant *visitLoadInst(LoadInst &I);
Constant *visitGetElementPtrInst(GetElementPtrInst &I);
Constant *visitSelectInst(SelectInst &I);
Constant *visitCastInst(CastInst &I);
Constant *visitCmpInst(CmpInst &I);
Constant *visitUnaryOperator(UnaryOperator &I);
Constant *visitBinaryOperator(BinaryOperator &I);
};
class FunctionSpecializer {
/// The IPSCCP Solver.
@@ -151,6 +188,16 @@ public:
bool run();
/// Build a fresh cost visitor for \p F, wired to the block frequency and
/// target information of that function, the module's data layout and the
/// solver owned by this specializer.
InstCostVisitor getInstCostVisitorFor(Function *F) {
  // Query the analyses in the same order as before: frequencies first.
  BlockFrequencyInfo &FuncBFI = (GetBFI)(*F);
  TargetTransformInfo &FuncTTI = (GetTTI)(*F);
  return InstCostVisitor(M.getDataLayout(), FuncBFI, FuncTTI, Solver);
}
/// Compute a bonus for replacing argument \p A with constant \p C.
Cost getSpecializationBonus(Argument *A, Constant *C,
InstCostVisitor &Visitor);
private:
Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call);
@@ -194,9 +241,6 @@ private:
/// Compute and return the cost of specializing function \p F.
Cost getSpecializationCost(Function *F);
/// Compute a bonus for replacing argument \p A with constant \p C.
Cost getSpecializationBonus(Argument *A, Constant *C);
/// Determine if it is possible to specialise the function for constant values
/// of the formal parameter \p A.
bool isArgumentInteresting(Argument *A);

View File

@@ -48,11 +48,14 @@
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantFold.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -94,6 +97,210 @@ static cl::opt<bool> SpecializeLiteralConstant(
"Enable specialization of functions that take a literal constant as an "
"argument"));
// Sums up the cost of the instructions found in the basic blocks of
// \p WorkList, which holds blocks of a specialization that become dead once
// the known constants are substituted. Each instruction cost (size and
// latency) is scaled by the frequency of its block relative to the entry
// block. Blocks whose relative frequency rounds down to zero contribute
// nothing and are not expanded further. Successors of a processed block are
// queued too, provided they are executable and the processed block is their
// unique predecessor. The list is drained by this function.
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
                                ConstMap &KnownConstants, SCCPSolver &Solver,
                                BlockFrequencyInfo &BFI,
                                TargetTransformInfo &TTI) {
  Cost Bonus = 0;

  while (!WorkList.empty()) {
    BasicBlock *Dead = WorkList.pop_back_val();

    // Integer division: anything colder than the entry block weighs zero.
    uint64_t Freq =
        BFI.getBlockFreq(Dead).getFrequency() / BFI.getEntryFreq();
    if (Freq == 0)
      continue;

    for (Instruction &Inst : *Dead) {
      // SSA copies are disregarded, and instructions already known to be
      // constant have been accounted for elsewhere.
      auto *Intr = dyn_cast<IntrinsicInst>(&Inst);
      bool IsSSACopy = Intr && Intr->getIntrinsicID() == Intrinsic::ssa_copy;
      if (IsSSACopy || KnownConstants.contains(&Inst))
        continue;

      Bonus += Freq * TTI.getInstructionCost(
                          &Inst, TargetTransformInfo::TCK_SizeAndLatency);

      LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
                        << " after user " << Inst << "\n");
    }

    // A successor dies along with this block only if it is executable and
    // cannot be entered from anywhere else.
    for (BasicBlock *Next : successors(Dead)) {
      if (!Solver.isBlockExecutable(Next))
        continue;
      if (Next->getUniquePredecessor() != Dead)
        continue;
      WorkList.push_back(Next);
    }
  }

  return Bonus;
}
// Returns the constant recorded for \p V in \p KnownConstants, or nullptr if
// the map has no entry for it.
static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
  // lookup() yields a default-constructed (null) pointer for a missing key.
  return KnownConstants.lookup(V);
}
// Computes the bonus of replacing operand \p Use of \p User with constant
// \p C. A switch or branch is handed to the dead-block estimators. Any other
// user is visited; if it folds to a constant, its frequency-weighted cost is
// credited and the walk continues through its own users located in
// executable blocks. Returns 0 when the user cannot be replaced by a
// constant, when it was already accounted for, or when its block's frequency
// relative to the entry block rounds down to zero.
Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
  // Cache the iterator before visiting: the handlers read the (Use, C) pair
  // through LastVisited.
  LastVisited = KnownConstants.insert({Use, C}).first;

  // A value's user list names the same instruction once per use, and an
  // instruction can also be reached along several use-def paths. Once it is
  // known to be constant it has been paid for; do not count it again.
  if (KnownConstants.contains(User))
    return 0;

  if (auto *I = dyn_cast<SwitchInst>(User))
    return estimateSwitchInst(*I);

  if (auto *I = dyn_cast<BranchInst>(User))
    return estimateBranchInst(*I);

  C = visit(*User);
  if (!C)
    return 0;

  KnownConstants.insert({User, C});

  uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
                    BFI.getEntryFreq();
  if (!Weight)
    return 0;

  Cost Bonus = Weight *
      TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);

  LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
                    << " for user " << *User << "\n");

  // The user is now a constant itself: keep propagating through its users.
  for (auto *U : User->users())
    if (auto *UI = dyn_cast<Instruction>(U))
      if (Solver.isBlockExecutable(UI->getParent()))
        Bonus += getUserBonus(UI, User, C);

  return Bonus;
}
// Estimates the bonus of the basic blocks that become dead when the condition
// of switch \p I is the last visited constant. Returns 0 when the last
// visited value is not the condition, or when the constant is not an integer
// literal and therefore selects no particular case.
Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
  if (I.getCondition() != LastVisited->first)
    return 0;

  // The known constant is not necessarily a ConstantInt (it may be undef,
  // poison or a constant expression). Without a concrete case value no
  // destination can be proven dead, so claim no bonus instead of asserting.
  auto *C = dyn_cast<ConstantInt>(LastVisited->second);
  if (!C)
    return 0;

  BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor();

  // Initialize the worklist with the dead basic blocks. These are the
  // destination labels which are different from the one corresponding
  // to \p C. They should be executable and have a unique predecessor.
  SmallVector<BasicBlock *> WorkList;
  for (const auto &Case : I.cases()) {
    BasicBlock *BB = Case.getCaseSuccessor();
    if (BB == Succ || !Solver.isBlockExecutable(BB) ||
        BB->getUniquePredecessor() != I.getParent())
      continue;

    // Several cases may share one destination; such a block still has a
    // unique predecessor, so queue it only once to avoid double counting.
    bool AlreadyQueued = false;
    for (BasicBlock *Queued : WorkList) {
      if (Queued == BB) {
        AlreadyQueued = true;
        break;
      }
    }
    if (!AlreadyQueued)
      WorkList.push_back(BB);
  }

  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}
// Estimates the bonus of the basic blocks that become dead when the condition
// of branch \p I is the last visited constant. Returns 0 when the last
// visited value is not the condition of the branch.
Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
  if (I.getCondition() != LastVisited->first)
    return 0;

  // Successor 1 is treated as dead when the constant is one, successor 0
  // otherwise.
  const unsigned DeadIdx = LastVisited->second->isOneValue() ? 1 : 0;
  BasicBlock *DeadSucc = I.getSuccessor(DeadIdx);

  // The dead successor only seeds the worklist if it is executable and
  // cannot be entered from any block other than the branch's parent.
  SmallVector<BasicBlock *> WorkList;
  if (Solver.isBlockExecutable(DeadSucc)) {
    if (DeadSucc->getUniquePredecessor() == I.getParent())
      WorkList.push_back(DeadSucc);
  }

  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}
// Folds load \p I through the last visited constant, which is used as the
// address. A null pointer constant is never folded.
Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
  Constant *Address = LastVisited->second;
  return isa<ConstantPointerNull>(Address)
             ? nullptr
             : ConstantFoldLoadFromConstPtr(Address, I.getType(), DL);
}
// Folds getelementptr \p I when every operand is either a constant already
// or has a known constant; otherwise returns nullptr.
Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
  SmallVector<Value *, 8> Folded;
  Folded.reserve(I.getNumOperands());

  for (Use &Op : I.operands()) {
    Value *Operand = Op.get();
    auto *AsConst = dyn_cast<Constant>(Operand);
    if (!AsConst)
      AsConst = findConstantFor(Operand, KnownConstants);
    // A single unknown operand is enough to give up.
    if (!AsConst)
      return nullptr;
    Folded.push_back(AsConst);
  }

  // Operand 0 is the base pointer, the remaining ones are the indices.
  auto *Base = cast<Constant>(Folded[0]);
  ArrayRef<Value *> Indices = ArrayRef<Value *>(Folded).drop_front();
  return ConstantFoldGetElementPtr(I.getSourceElementType(), Base,
                                   I.isInBounds(), std::nullopt, Indices);
}
// Resolves select \p I when its condition is the last visited constant. The
// result is the constant behind the chosen operand, or nullptr if that
// operand is not known to be constant.
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
  if (I.getCondition() != LastVisited->first)
    return nullptr;

  // A zero condition picks the false operand, anything else the true one.
  Value *Chosen = I.getTrueValue();
  if (LastVisited->second->isZeroValue())
    Chosen = I.getFalseValue();

  if (auto *Literal = dyn_cast<Constant>(Chosen))
    return Literal;
  return findConstantFor(Chosen, KnownConstants);
}
// Folds cast \p I applied to the last visited constant.
Constant *InstCostVisitor::visitCastInst(CastInst &I) {
  Constant *Source = LastVisited->second;
  return ConstantFoldCastOperand(I.getOpcode(), Source, I.getType(), DL);
}
// Folds comparison \p I given that one of its operands is the last visited
// constant. The remaining operand must be a constant or have a known
// constant; otherwise nullptr is returned.
Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
  // Work out on which side of the comparison the last visited value sits.
  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
  Value *Remaining = I.getOperand(KnownIsRHS ? 0 : 1);

  auto *RemainingC = dyn_cast<Constant>(Remaining);
  if (!RemainingC)
    RemainingC = findConstantFor(Remaining, KnownConstants);
  if (!RemainingC)
    return nullptr;

  // Put both constants back in their original operand positions.
  Constant *Known = LastVisited->second;
  Constant *LHS = KnownIsRHS ? RemainingC : Known;
  Constant *RHS = KnownIsRHS ? Known : RemainingC;
  return ConstantFoldCompareInstOperands(I.getPredicate(), LHS, RHS, DL);
}
// Folds unary operator \p I applied to the last visited constant.
Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
  Constant *Operand = LastVisited->second;
  return ConstantFoldUnaryOpOperand(I.getOpcode(), Operand, DL);
}
// Simplifies binary operator \p I given that one of its operands is the last
// visited constant. The remaining operand must be a constant or have a known
// constant. Returns nullptr when it is unknown or when the simplified value
// is not a constant.
Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
  // Work out on which side of the operator the last visited value sits.
  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
  Value *Remaining = I.getOperand(KnownIsRHS ? 0 : 1);

  auto *RemainingC = dyn_cast<Constant>(Remaining);
  if (!RemainingC)
    RemainingC = findConstantFor(Remaining, KnownConstants);
  if (!RemainingC)
    return nullptr;

  // Put both constants back in their original operand positions.
  Constant *Known = LastVisited->second;
  Constant *LHS = KnownIsRHS ? RemainingC : Known;
  Constant *RHS = KnownIsRHS ? Known : RemainingC;
  return dyn_cast_or_null<Constant>(
      simplifyBinOp(I.getOpcode(), LHS, RHS, SimplifyQuery(DL)));
}
Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
CallInst *Call) {
Value *StoreValue = nullptr;
@@ -412,10 +619,6 @@ CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) {
CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
for (BasicBlock &BB : *F)
Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
<< F->getName() << " is " << Metrics.NumInsts
<< " instructions\n");
}
return Metrics;
}
@@ -496,8 +699,9 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
} else {
// Calculate the specialisation gain.
Cost Score = 0 - SpecCost;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
Score += getSpecializationBonus(A.Formal, A.Actual);
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
// Discard unprofitable specialisations.
if (!ForceSpecialization && Score <= 0)
@@ -584,49 +788,23 @@ Cost FunctionSpecializer::getSpecializationCost(Function *F) {
// Otherwise, set the specialization cost to be the cost of all the
// instructions in the function.
return Metrics.NumInsts * InlineConstants::getInstrCost();
}
static Cost getUserBonus(User *U, TargetTransformInfo &TTI,
BlockFrequencyInfo &BFI) {
auto *I = dyn_cast_or_null<Instruction>(U);
// If not an instruction we do not know how to evaluate.
// Keep minimum possible cost for now so that it doesnt affect
// specialization.
if (!I)
return 0;
uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
BFI.getEntryFreq();
if (!Weight)
return 0;
Cost Bonus = Weight *
TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
// Traverse recursively if there are more uses.
// TODO: Any other instructions to be added here?
if (I->mayReadFromMemory() || I->isCast())
for (auto *User : I->users())
Bonus += getUserBonus(User, TTI, BFI);
return Bonus;
return Metrics.NumInsts;
}
/// Compute a bonus for replacing argument \p A with constant \p C.
Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) {
Function *F = A->getParent();
auto &TTI = (GetTTI)(*F);
auto &BFI = (GetBFI)(*F);
Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
InstCostVisitor &Visitor) {
LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
<< C->getNameOrAsOperand() << "\n");
Cost TotalCost = 0;
for (auto *U : A->users()) {
TotalCost += getUserBonus(U, TTI, BFI);
LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
}
for (auto *U : A->users())
if (auto *UI = dyn_cast<Instruction>(U))
if (Solver.isBlockExecutable(UI->getParent()))
TotalCost += Visitor.getUserBonus(UI, A, C);
LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
<< TotalCost << " for argument " << *A << "\n");
// The below heuristic is only concerned with exposing inlining
// opportunities via indirect call promotion. If the argument is not a

View File

@@ -12,6 +12,7 @@ add_llvm_unittest(IPOTests
LowerTypeTests.cpp
WholeProgramDevirt.cpp
AttributorTest.cpp
FunctionSpecializationTest.cpp
)
set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")

View File

@@ -0,0 +1,258 @@
//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
namespace llvm {
// Test fixture for the function specialization cost model. It owns the
// context, the analysis manager, the parsed module and the solver that the
// FunctionSpecializer under test refers to.
class FunctionSpecializationTest : public testing::Test {
protected:
LLVMContext Ctx;
FunctionAnalysisManager FAM;
std::unique_ptr<Module> M;
std::unique_ptr<SCCPSolver> Solver;
// Registers the function analyses that are queried through FAM below.
FunctionSpecializationTest() {
FAM.registerPass([&] { return TargetLibraryAnalysis(); });
FAM.registerPass([&] { return TargetIRAnalysis(); });
FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
FAM.registerPass([&] { return LoopAnalysis(); });
FAM.registerPass([&] { return AssumptionAnalysis(); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
}
// Parses \p ModuleString as textual IR into M and returns a reference to it.
// A parse failure is reported through EXPECT_TRUE; the module is
// dereferenced regardless.
Module &parseModule(const char *ModuleString) {
SMDiagnostic Err;
M = parseAssemblyString(ModuleString, Err, Ctx);
EXPECT_TRUE(M);
return *M;
}
// Creates a new solver for the module, runs it on \p F and returns a
// FunctionSpecializer that refers to that solver and to the analysis
// getters defined here. Must be called after parseModule().
FunctionSpecializer getSpecializerFor(Function *F) {
auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
return FAM.getResult<TargetLibraryAnalysis>(F);
};
auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
return FAM.getResult<TargetIRAnalysis>(F);
};
auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
return FAM.getResult<BlockFrequencyAnalysis>(F);
};
auto GetAC = [this](Function &F) -> AssumptionCache & {
return FAM.getResult<AssumptionAnalysis>(F);
};
auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
return { std::make_unique<PredicateInfo>(F, DT,
FAM.getResult<AssumptionAnalysis>(F)),
&DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
};
Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
Solver->addAnalysis(*F, GetAnalysis(*F));
// Seed the solver: the entry block is executable and nothing is known
// about the arguments.
Solver->markBlockExecutable(&F->front());
for (Argument &Arg : F->args())
Solver->markOverdefined(&Arg);
Solver->solveWhileResolvedUndefsIn(*M);
return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
GetAC);
}
// Reference cost of \p I: the frequency of its block divided by the entry
// frequency, multiplied by its size-and-latency cost.
Cost getInstCost(Instruction &I) {
auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
}
};
} // namespace llvm
using namespace llvm;
// Checks the bonus computed for each argument of a function whose control
// flow is driven by a switch. The same visitor is used for all three queries,
// so constants discovered by an earlier query are still known in later ones;
// the order of the queries therefore matters.
TEST_F(FunctionSpecializationTest, SwitchInst) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i32 %i) {
entry:
switch i32 %i, label %default
[ i32 1, label %case1
i32 2, label %case2 ]
case1:
%0 = mul i32 %a, 2
%1 = sub i32 6, 5
br label %bb1
case2:
%2 = and i32 %b, 3
%3 = sdiv i32 8, 2
br label %bb2
bb1:
%4 = add i32 %0, %b
br label %default
bb2:
%5 = or i32 %2, %a
br label %default
default:
ret void
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
// Skip the entry block; the remaining blocks are picked in textual order.
auto FuncIter = F->begin();
BasicBlock &Case1 = *++FuncIter;
BasicBlock &Case2 = *++FuncIter;
BasicBlock &BB1 = *++FuncIter;
BasicBlock &BB2 = *++FuncIter;
Instruction &Mul = Case1.front();
Instruction &And = Case2.front();
Instruction &Sdiv = *++Case2.begin();
Instruction &BrBB2 = Case2.back();
Instruction &Add = BB1.front();
Instruction &Or = BB2.front();
Instruction &BrDefault = BB2.back();
// %a = 1: only the mul is expected to count; its user (the add) and the or
// each have another operand that is not known yet.
Cost Ref = getInstCost(Mul);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
// %b = 1: and + or + add. The or and the add are expected to count because
// %a and the mul were recorded as constants by the previous query.
Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
EXPECT_EQ(Bonus, Ref);
// %i = 1: sdiv + br + br, i.e. what is left of case2 and bb2 once the and
// and the or, already known to be constant, are left out.
Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
EXPECT_EQ(Bonus, Ref);
}
// Checks the bonus computed for each argument of a function whose control
// flow is driven by a conditional branch. The same visitor is used for all
// three queries, so constants discovered by an earlier query are still known
// in later ones; the order of the queries therefore matters.
TEST_F(FunctionSpecializationTest, BranchInst) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i1 %cond) {
entry:
br i1 %cond, label %bb0, label %bb2
bb0:
%0 = mul i32 %a, 2
%1 = sub i32 6, 5
br label %bb1
bb1:
%2 = add i32 %0, %b
%3 = sdiv i32 8, 2
br label %bb2
bb2:
ret void
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
Constant *False = ConstantInt::getFalse(M.getContext());
// Skip the entry block; bb0 and bb1 are picked in textual order.
auto FuncIter = F->begin();
BasicBlock &BB0 = *++FuncIter;
BasicBlock &BB1 = *++FuncIter;
Instruction &Mul = BB0.front();
Instruction &Sub = *++BB0.begin();
Instruction &BrBB1 = BB0.back();
Instruction &Add = BB1.front();
Instruction &Sdiv = *++BB1.begin();
Instruction &BrBB2 = BB1.back();
// %a = 1: only the mul is expected to count; its user (the add) has %b as
// its other operand, which is not known yet.
Cost Ref = getInstCost(Mul);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
// %b = 1: the add is expected to count because the mul was recorded as a
// constant by the previous query.
Ref = getInstCost(Add);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
EXPECT_EQ(Bonus, Ref);
// %cond = false: sub + br + sdiv + br, i.e. what is left of bb0 and bb1 once
// the mul and the add, already known to be constant, are left out.
Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
getInstCost(BrBB2);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
EXPECT_EQ(Bonus, Ref);
}
// Checks the bonus computed for a straight-line chain of compare, cast,
// select, getelementptr and load. The same visitor is used for all three
// queries, so constants discovered by an earlier query are still known in
// later ones; the order of the queries therefore matters.
TEST_F(FunctionSpecializationTest, Misc) {
const char *ModuleString = R"(
@g = constant [2 x i32] zeroinitializer, align 4
define i32 @foo(i8 %a, i1 %cond, ptr %b) {
%cmp = icmp eq i8 %a, 10
%ext = zext i1 %cmp to i32
%sel = select i1 %cond, i32 %ext, i32 1
%gep = getelementptr i32, ptr %b, i32 %sel
%ld = load i32, ptr %gep
ret i32 %ld
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
GlobalVariable *GV = M.getGlobalVariable("g");
Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
Constant *True = ConstantInt::getTrue(M.getContext());
// The instructions of the only block are picked in textual order.
auto BlockIter = F->front().begin();
Instruction &Icmp = *BlockIter++;
Instruction &Zext = *BlockIter++;
Instruction &Select = *BlockIter++;
Instruction &Gep = *BlockIter++;
Instruction &Load = *BlockIter++;
// %a = 1: icmp + zext. The select is not expected to count, as the value
// reaching it is not its condition.
Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
// %cond = true: the select is expected to count because its true operand,
// the zext, was recorded as a constant by the previous query. The gep is
// not, since %b is not known yet.
Ref = getInstCost(Select);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
EXPECT_EQ(Bonus, Ref);
// %b = @g: gep + load. The index of the gep, the select, was recorded as a
// constant by the previous query.
Ref = getInstCost(Gep) + getInstCost(Load);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
EXPECT_EQ(Bonus, Ref);
}