[FuncSpec] Improve the accuracy of the cost model.
Instead of blindly traversing the use-def chain of constant arguments, compute known constants along the way. Stop as soon as a user cannot be replaced by a constant. Keep it light-weight by handling some basic instruction types. Differential Revision: https://reviews.llvm.org/D150464
This commit is contained in:
@@ -52,6 +52,7 @@
|
||||
#include "llvm/Analysis/CodeMetrics.h"
|
||||
#include "llvm/Analysis/InlineCost.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/IR/InstVisitor.h"
|
||||
#include "llvm/Transforms/Scalar/SCCP.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "llvm/Transforms/Utils/SCCPSolver.h"
|
||||
@@ -69,6 +70,9 @@ using SpecMap = DenseMap<Function *, std::pair<unsigned, unsigned>>;
|
||||
// Just a shorter abbreviation to improve indentation.
|
||||
using Cost = InstructionCost;
|
||||
|
||||
// Map of known constants found during the specialization bonus estimation.
|
||||
using ConstMap = DenseMap<Value *, Constant *>;
|
||||
|
||||
// Specialization signature, used to uniquely designate a specialization within
|
||||
// a function.
|
||||
struct SpecSig {
|
||||
@@ -115,6 +119,39 @@ struct Spec {
|
||||
: F(F), Sig(S), Score(Score) {}
|
||||
};
|
||||
|
||||
/// Instruction visitor used while estimating the bonus of a specialization.
/// Every visit method returns the constant the visited instruction folds to
/// given the constants discovered so far, or nullptr if it does not fold.
class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
  const DataLayout &DL;
  BlockFrequencyInfo &BFI;
  TargetTransformInfo &TTI;
  SCCPSolver &Solver;

  /// Values discovered to be constant so far, mapped to their constant.
  ConstMap KnownConstants;

  /// Entry of KnownConstants for the (use, constant) pair currently being
  /// propagated. It is assigned by getUserBonus before any visit method or
  /// estimate helper reads it.
  ConstMap::iterator LastVisited;

public:
  InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
                  TargetTransformInfo &TTI, SCCPSolver &Solver)
      : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}

  /// Compute the bonus obtained at \p User when its operand \p Use is replaced
  /// by the constant \p C.
  Cost getUserBonus(Instruction *User, Value *Use, Constant *C);

private:
  // InstVisitor dispatches to the private visit methods below.
  friend class InstVisitor<InstCostVisitor, Constant *>;

  // Terminators: estimate the cost of the blocks that become dead.
  Cost estimateSwitchInst(SwitchInst &I);
  Cost estimateBranchInst(BranchInst &I);

  // Fallback for every instruction kind without a dedicated handler.
  Constant *visitInstruction(Instruction &I) { return nullptr; }

  // Handlers for the instruction kinds that are folded.
  Constant *visitLoadInst(LoadInst &I);
  Constant *visitGetElementPtrInst(GetElementPtrInst &I);
  Constant *visitSelectInst(SelectInst &I);
  Constant *visitCastInst(CastInst &I);
  Constant *visitCmpInst(CmpInst &I);
  Constant *visitUnaryOperator(UnaryOperator &I);
  Constant *visitBinaryOperator(BinaryOperator &I);
};
|
||||
|
||||
class FunctionSpecializer {
|
||||
|
||||
/// The IPSCCP Solver.
|
||||
@@ -151,6 +188,16 @@ public:
|
||||
|
||||
bool run();
|
||||
|
||||
/// Create an InstCostVisitor for \p F, wired to the module's data layout, the
/// block frequency and target transform info of \p F, and the solver.
InstCostVisitor getInstCostVisitorFor(Function *F) {
  BlockFrequencyInfo &FuncBFI = (GetBFI)(*F);
  TargetTransformInfo &FuncTTI = (GetTTI)(*F);
  return InstCostVisitor(M.getDataLayout(), FuncBFI, FuncTTI, Solver);
}
|
||||
|
||||
/// Compute a bonus for replacing argument \p A with constant \p C.
|
||||
Cost getSpecializationBonus(Argument *A, Constant *C,
|
||||
InstCostVisitor &Visitor);
|
||||
|
||||
private:
|
||||
Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call);
|
||||
|
||||
@@ -194,9 +241,6 @@ private:
|
||||
/// Compute and return the cost of specializing function \p F.
|
||||
Cost getSpecializationCost(Function *F);
|
||||
|
||||
/// Compute a bonus for replacing argument \p A with constant \p C.
|
||||
Cost getSpecializationBonus(Argument *A, Constant *C);
|
||||
|
||||
/// Determine if it is possible to specialise the function for constant values
|
||||
/// of the formal parameter \p A.
|
||||
bool isArgumentInteresting(Argument *A);
|
||||
|
||||
@@ -48,11 +48,14 @@
|
||||
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
|
||||
#include "llvm/ADT/Statistic.h"
|
||||
#include "llvm/Analysis/CodeMetrics.h"
|
||||
#include "llvm/Analysis/ConstantFolding.h"
|
||||
#include "llvm/Analysis/InlineCost.h"
|
||||
#include "llvm/Analysis/InstructionSimplify.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/Analysis/ValueLattice.h"
|
||||
#include "llvm/Analysis/ValueLatticeUtils.h"
|
||||
#include "llvm/Analysis/ValueTracking.h"
|
||||
#include "llvm/IR/ConstantFold.h"
|
||||
#include "llvm/IR/IntrinsicInst.h"
|
||||
#include "llvm/Transforms/Scalar/SCCP.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
@@ -94,6 +97,210 @@ static cl::opt<bool> SpecializeLiteralConstant(
|
||||
"Enable specialization of functions that take a literal constant as an "
|
||||
"argument"));
|
||||
|
||||
// Estimates the instruction cost of all the basic blocks in \p WorkList.
|
||||
// The successors of such blocks are added to the list as long as they are
|
||||
// executable and they have a unique predecessor. \p WorkList represents
|
||||
// the basic blocks of a specialization which become dead once we replace
|
||||
// instructions that are known to be constants. The aim here is to estimate
|
||||
// the combination of size and latency savings in comparison to the non
|
||||
// specialized version of the function.
|
||||
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
|
||||
ConstMap &KnownConstants, SCCPSolver &Solver,
|
||||
BlockFrequencyInfo &BFI,
|
||||
TargetTransformInfo &TTI) {
|
||||
Cost Bonus = 0;
|
||||
|
||||
// Accumulate the instruction cost of each basic block weighted by frequency.
|
||||
while (!WorkList.empty()) {
|
||||
BasicBlock *BB = WorkList.pop_back_val();
|
||||
|
||||
uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
|
||||
BFI.getEntryFreq();
|
||||
if (!Weight)
|
||||
continue;
|
||||
|
||||
for (Instruction &I : *BB) {
|
||||
// Disregard SSA copies.
|
||||
if (auto *II = dyn_cast<IntrinsicInst>(&I))
|
||||
if (II->getIntrinsicID() == Intrinsic::ssa_copy)
|
||||
continue;
|
||||
// If it's a known constant we have already accounted for it.
|
||||
if (KnownConstants.contains(&I))
|
||||
continue;
|
||||
|
||||
Bonus += Weight *
|
||||
TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
|
||||
<< " after user " << I << "\n");
|
||||
}
|
||||
|
||||
// Keep adding dead successors to the list as long as they are
|
||||
// executable and they have a unique predecessor.
|
||||
for (BasicBlock *SuccBB : successors(BB))
|
||||
if (Solver.isBlockExecutable(SuccBB) &&
|
||||
SuccBB->getUniquePredecessor() == BB)
|
||||
WorkList.push_back(SuccBB);
|
||||
}
|
||||
return Bonus;
|
||||
}
|
||||
|
||||
// Returns the constant recorded for \p V in \p KnownConstants, or nullptr if
// there is none.
static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
  // lookup() yields a default-constructed (null) pointer for a missing key.
  return KnownConstants.lookup(V);
}
|
||||
|
||||
// Computes the bonus of replacing \p Use with the constant \p C at \p User.
// Switches and branches are scored by the blocks they make dead. Any other
// user contributes only if it folds to a constant, in which case its own
// users in executable blocks are visited recursively.
Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
  // Record the pair and keep its iterator; the handlers read it.
  LastVisited = KnownConstants.insert({Use, C}).first;

  if (auto *Switch = dyn_cast<SwitchInst>(User))
    return estimateSwitchInst(*Switch);

  if (auto *Branch = dyn_cast<BranchInst>(User))
    return estimateBranchInst(*Branch);

  // Stop as soon as the user cannot be replaced by a constant.
  Constant *Folded = visit(*User);
  if (Folded == nullptr)
    return 0;

  KnownConstants.insert({User, Folded});

  const uint64_t UserFreq = BFI.getBlockFreq(User->getParent()).getFrequency();
  const uint64_t Weight = UserFreq / BFI.getEntryFreq();
  if (Weight == 0)
    return 0;

  Cost Bonus = Weight *
      TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);

  LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
                    << " for user " << *User << "\n");

  // Propagate the folded constant to the users of this instruction.
  for (auto *U : User->users()) {
    auto *NextUser = dyn_cast<Instruction>(U);
    if (NextUser && Solver.isBlockExecutable(NextUser->getParent()))
      Bonus += getUserBonus(NextUser, User, Folded);
  }

  return Bonus;
}
|
||||
|
||||
// Estimates the bonus of a switch whose condition is the last visited value.
// Returns the weighted cost of the case destinations that become dead, or 0
// if the condition is not the value being replaced or if the known constant
// does not select a case.
Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
  if (I.getCondition() != LastVisited->first)
    return 0;

  // The known constant is not guaranteed to be a ConstantInt (it may be a
  // constant expression or undef, for example). A bare cast<> would assert
  // in that case, so give no bonus when the live case cannot be determined.
  auto *C = dyn_cast<ConstantInt>(LastVisited->second);
  if (!C)
    return 0;

  BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor();

  // Initialize the worklist with the dead basic blocks. These are the
  // destination labels which are different from the one corresponding
  // to \p C. They should be executable and have a unique predecessor.
  SmallVector<BasicBlock *> WorkList;
  for (const auto &Case : I.cases()) {
    BasicBlock *BB = Case.getCaseSuccessor();
    if (BB == Succ || !Solver.isBlockExecutable(BB) ||
        BB->getUniquePredecessor() != I.getParent())
      continue;
    WorkList.push_back(BB);
  }

  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}
|
||||
|
||||
// Estimates the bonus of a branch whose condition is the last visited value.
// Returns the weighted cost of the successor that becomes dead, or 0 if the
// condition is not the value being replaced.
Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
  if (I.getCondition() != LastVisited->first)
    return 0;

  // Pick successor 1 when the known condition is one, successor 0 otherwise;
  // that is the successor treated as dead.
  const unsigned DeadIdx = LastVisited->second->isOneValue() ? 1 : 0;
  BasicBlock *DeadSucc = I.getSuccessor(DeadIdx);

  // Only seed the worklist if the dead successor is executable and is
  // reachable solely from this branch.
  SmallVector<BasicBlock *> WorkList;
  const bool IsCandidate = Solver.isBlockExecutable(DeadSucc) &&
                           DeadSucc->getUniquePredecessor() == I.getParent();
  if (IsCandidate)
    WorkList.push_back(DeadSucc);

  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}
|
||||
|
||||
// Folds a load from the last visited constant pointer. A null pointer is
// never folded.
Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
  Constant *Ptr = LastVisited->second;
  if (isa<ConstantPointerNull>(Ptr))
    return nullptr;
  return ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL);
}
|
||||
|
||||
// Folds a GEP when every operand is either a constant already or has a known
// constant recorded for it. Returns nullptr as soon as one operand is not.
Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
  SmallVector<Value *, 8> FoldedOps;
  FoldedOps.reserve(I.getNumOperands());

  for (Value *Op : I.operands()) {
    Constant *Known = dyn_cast<Constant>(Op);
    if (!Known)
      Known = findConstantFor(Op, KnownConstants);
    if (!Known)
      return nullptr;
    FoldedOps.push_back(Known);
  }

  // Operand 0 is the base pointer, the remaining operands are the indices.
  auto *BasePtr = cast<Constant>(FoldedOps.front());
  ArrayRef<Value *> Indices = ArrayRef<Value *>(FoldedOps).drop_front();
  return ConstantFoldGetElementPtr(I.getSourceElementType(), BasePtr,
                                   I.isInBounds(), std::nullopt, Indices);
}
|
||||
|
||||
// Folds a select whose condition is the last visited value: the chosen arm is
// returned if it is a constant or has a known constant, nullptr otherwise.
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
  if (I.getCondition() != LastVisited->first)
    return nullptr;

  // A zero condition selects the false arm, anything else the true arm.
  Value *Chosen = I.getTrueValue();
  if (LastVisited->second->isZeroValue())
    Chosen = I.getFalseValue();

  if (auto *C = dyn_cast<Constant>(Chosen))
    return C;
  return findConstantFor(Chosen, KnownConstants);
}
|
||||
|
||||
// Folds a cast of the last visited constant to the type of \p I.
Constant *InstCostVisitor::visitCastInst(CastInst &I) {
  Constant *Operand = LastVisited->second;
  return ConstantFoldCastOperand(I.getOpcode(), Operand, I.getType(), DL);
}
|
||||
|
||||
// Folds a comparison where one operand is the last visited value and the
// other is a constant or has a known constant. Returns nullptr otherwise.
Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
  // The last visited value sits on the right-hand side if it is operand 1.
  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
  Value *OtherOp = I.getOperand(KnownIsRHS ? 0 : 1);

  auto *Other = dyn_cast<Constant>(OtherOp);
  if (!Other)
    Other = findConstantFor(OtherOp, KnownConstants);
  if (!Other)
    return nullptr;

  // Keep the operands in their original order when folding.
  Constant *Known = LastVisited->second;
  Constant *LHS = KnownIsRHS ? Other : Known;
  Constant *RHS = KnownIsRHS ? Known : Other;
  return ConstantFoldCompareInstOperands(I.getPredicate(), LHS, RHS, DL);
}
|
||||
|
||||
// Folds a unary operator applied to the last visited constant.
Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
  Constant *Operand = LastVisited->second;
  return ConstantFoldUnaryOpOperand(I.getOpcode(), Operand, DL);
}
|
||||
|
||||
// Simplifies a binary operator where one operand is the last visited value
// and the other is a constant or has a known constant. Returns the result if
// it is a constant, nullptr otherwise.
Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
  // The last visited value sits on the right-hand side if it is operand 1.
  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
  Value *OtherOp = I.getOperand(KnownIsRHS ? 0 : 1);

  auto *Other = dyn_cast<Constant>(OtherOp);
  if (!Other)
    Other = findConstantFor(OtherOp, KnownConstants);
  if (!Other)
    return nullptr;

  // Keep the operands in their original order when simplifying.
  Constant *Known = LastVisited->second;
  Constant *LHS = KnownIsRHS ? Other : Known;
  Constant *RHS = KnownIsRHS ? Known : Other;
  return dyn_cast_or_null<Constant>(
      simplifyBinOp(I.getOpcode(), LHS, RHS, SimplifyQuery(DL)));
}
|
||||
|
||||
Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
|
||||
CallInst *Call) {
|
||||
Value *StoreValue = nullptr;
|
||||
@@ -412,10 +619,6 @@ CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) {
|
||||
CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
|
||||
for (BasicBlock &BB : *F)
|
||||
Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
|
||||
<< F->getName() << " is " << Metrics.NumInsts
|
||||
<< " instructions\n");
|
||||
}
|
||||
return Metrics;
|
||||
}
|
||||
@@ -496,8 +699,9 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
|
||||
} else {
|
||||
// Calculate the specialisation gain.
|
||||
Cost Score = 0 - SpecCost;
|
||||
InstCostVisitor Visitor = getInstCostVisitorFor(F);
|
||||
for (ArgInfo &A : S.Args)
|
||||
Score += getSpecializationBonus(A.Formal, A.Actual);
|
||||
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
|
||||
|
||||
// Discard unprofitable specialisations.
|
||||
if (!ForceSpecialization && Score <= 0)
|
||||
@@ -584,49 +788,23 @@ Cost FunctionSpecializer::getSpecializationCost(Function *F) {
|
||||
|
||||
// Otherwise, set the specialization cost to be the cost of all the
|
||||
// instructions in the function.
|
||||
return Metrics.NumInsts * InlineConstants::getInstrCost();
|
||||
}
|
||||
|
||||
static Cost getUserBonus(User *U, TargetTransformInfo &TTI,
|
||||
BlockFrequencyInfo &BFI) {
|
||||
auto *I = dyn_cast_or_null<Instruction>(U);
|
||||
// If not an instruction we do not know how to evaluate.
|
||||
// Keep minimum possible cost for now so that it doesnt affect
|
||||
// specialization.
|
||||
if (!I)
|
||||
return 0;
|
||||
|
||||
uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
|
||||
BFI.getEntryFreq();
|
||||
if (!Weight)
|
||||
return 0;
|
||||
|
||||
Cost Bonus = Weight *
|
||||
TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
|
||||
|
||||
// Traverse recursively if there are more uses.
|
||||
// TODO: Any other instructions to be added here?
|
||||
if (I->mayReadFromMemory() || I->isCast())
|
||||
for (auto *User : I->users())
|
||||
Bonus += getUserBonus(User, TTI, BFI);
|
||||
|
||||
return Bonus;
|
||||
return Metrics.NumInsts;
|
||||
}
|
||||
|
||||
/// Compute a bonus for replacing argument \p A with constant \p C.
|
||||
Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) {
|
||||
Function *F = A->getParent();
|
||||
auto &TTI = (GetTTI)(*F);
|
||||
auto &BFI = (GetBFI)(*F);
|
||||
Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
|
||||
InstCostVisitor &Visitor) {
|
||||
LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
|
||||
<< C->getNameOrAsOperand() << "\n");
|
||||
|
||||
Cost TotalCost = 0;
|
||||
for (auto *U : A->users()) {
|
||||
TotalCost += getUserBonus(U, TTI, BFI);
|
||||
LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
|
||||
TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
|
||||
}
|
||||
for (auto *U : A->users())
|
||||
if (auto *UI = dyn_cast<Instruction>(U))
|
||||
if (Solver.isBlockExecutable(UI->getParent()))
|
||||
TotalCost += Visitor.getUserBonus(UI, A, C);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
|
||||
<< TotalCost << " for argument " << *A << "\n");
|
||||
|
||||
// The below heuristic is only concerned with exposing inlining
|
||||
// opportunities via indirect call promotion. If the argument is not a
|
||||
|
||||
@@ -12,6 +12,7 @@ add_llvm_unittest(IPOTests
|
||||
LowerTypeTests.cpp
|
||||
WholeProgramDevirt.cpp
|
||||
AttributorTest.cpp
|
||||
FunctionSpecializationTest.cpp
|
||||
)
|
||||
|
||||
set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
|
||||
|
||||
258
llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Normal file
258
llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Normal file
@@ -0,0 +1,258 @@
|
||||
//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/Analysis/AssumptionCache.h"
|
||||
#include "llvm/Analysis/BlockFrequencyInfo.h"
|
||||
#include "llvm/Analysis/BranchProbabilityInfo.h"
|
||||
#include "llvm/Analysis/LoopInfo.h"
|
||||
#include "llvm/Analysis/PostDominators.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/AsmParser/Parser.h"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
|
||||
#include "llvm/Transforms/Utils/SCCPSolver.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <memory>
|
||||
|
||||
namespace llvm {
|
||||
|
||||
class FunctionSpecializationTest : public testing::Test {
|
||||
protected:
|
||||
LLVMContext Ctx;
|
||||
FunctionAnalysisManager FAM;
|
||||
std::unique_ptr<Module> M;
|
||||
std::unique_ptr<SCCPSolver> Solver;
|
||||
|
||||
FunctionSpecializationTest() {
|
||||
FAM.registerPass([&] { return TargetLibraryAnalysis(); });
|
||||
FAM.registerPass([&] { return TargetIRAnalysis(); });
|
||||
FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
|
||||
FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
|
||||
FAM.registerPass([&] { return LoopAnalysis(); });
|
||||
FAM.registerPass([&] { return AssumptionAnalysis(); });
|
||||
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
|
||||
FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
|
||||
FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
|
||||
}
|
||||
|
||||
Module &parseModule(const char *ModuleString) {
|
||||
SMDiagnostic Err;
|
||||
M = parseAssemblyString(ModuleString, Err, Ctx);
|
||||
EXPECT_TRUE(M);
|
||||
return *M;
|
||||
}
|
||||
|
||||
FunctionSpecializer getSpecializerFor(Function *F) {
|
||||
auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
|
||||
return FAM.getResult<TargetLibraryAnalysis>(F);
|
||||
};
|
||||
auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
|
||||
return FAM.getResult<TargetIRAnalysis>(F);
|
||||
};
|
||||
auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
|
||||
return FAM.getResult<BlockFrequencyAnalysis>(F);
|
||||
};
|
||||
auto GetAC = [this](Function &F) -> AssumptionCache & {
|
||||
return FAM.getResult<AssumptionAnalysis>(F);
|
||||
};
|
||||
auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
|
||||
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
|
||||
return { std::make_unique<PredicateInfo>(F, DT,
|
||||
FAM.getResult<AssumptionAnalysis>(F)),
|
||||
&DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
|
||||
};
|
||||
|
||||
Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
|
||||
|
||||
Solver->addAnalysis(*F, GetAnalysis(*F));
|
||||
Solver->markBlockExecutable(&F->front());
|
||||
for (Argument &Arg : F->args())
|
||||
Solver->markOverdefined(&Arg);
|
||||
Solver->solveWhileResolvedUndefsIn(*M);
|
||||
|
||||
return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
|
||||
GetAC);
|
||||
}
|
||||
|
||||
Cost getInstCost(Instruction &I) {
|
||||
auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
|
||||
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
|
||||
|
||||
return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
|
||||
TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Checks the bonus computed for arguments feeding arithmetic in the case
// blocks of a switch, and for the switch condition itself. The same visitor
// is used for all three queries, so constants found for earlier arguments
// are still known when later arguments are analysed.
TEST_F(FunctionSpecializationTest, SwitchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i32 %i) {
    entry:
      switch i32 %i, label %default
      [ i32 1, label %case1
        i32 2, label %case2 ]
    case1:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br label %bb1
    case2:
      %2 = and i32 %b, 3
      %3 = sdiv i32 8, 2
      br label %bb2
    bb1:
      %4 = add i32 %0, %b
      br label %default
    bb2:
      %5 = or i32 %2, %a
      br label %default
    default:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Skip the entry block and take the next four blocks in layout order.
  auto BlockIt = F->begin();
  BasicBlock &Case1 = *++BlockIt;
  BasicBlock &Case2 = *++BlockIt;
  BasicBlock &BB1 = *++BlockIt;
  BasicBlock &BB2 = *++BlockIt;

  Instruction &Mul = Case1.front();
  Instruction &And = Case2.front();
  Instruction &Sdiv = *++Case2.begin();
  Instruction &BrBB2 = Case2.back();
  Instruction &Add = BB1.front();
  Instruction &Or = BB2.front();
  Instruction &BrDefault = BB2.back();

  // mul
  Cost Ref = getInstCost(Mul);
  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // and + or + add
  Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
  Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // sdiv + br + br
  Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
  Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
  EXPECT_EQ(Bonus, Ref);
}
|
||||
|
||||
// Checks the bonus computed for arguments feeding arithmetic behind a
// conditional branch, and for the branch condition itself. The same visitor
// is used for all three queries, so constants found for earlier arguments
// are still known when later arguments are analysed.
TEST_F(FunctionSpecializationTest, BranchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i1 %cond) {
    entry:
      br i1 %cond, label %bb0, label %bb2
    bb0:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br label %bb1
    bb1:
      %2 = add i32 %0, %b
      %3 = sdiv i32 8, 2
      br label %bb2
    bb2:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
  Constant *False = ConstantInt::getFalse(M.getContext());

  // Skip the entry block and take the next two blocks in layout order.
  auto BlockIt = F->begin();
  BasicBlock &BB0 = *++BlockIt;
  BasicBlock &BB1 = *++BlockIt;

  Instruction &Mul = BB0.front();
  Instruction &Sub = *++BB0.begin();
  Instruction &BrBB1 = BB0.back();
  Instruction &Add = BB1.front();
  Instruction &Sdiv = *++BB1.begin();
  Instruction &BrBB2 = BB1.back();

  // mul
  Cost Ref = getInstCost(Mul);
  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // add
  Ref = getInstCost(Add);
  Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // sub + br + sdiv + br
  Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
        getInstCost(BrBB2);
  Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
  EXPECT_EQ(Bonus, Ref);
}
|
||||
|
||||
// Checks the bonus computed through a chain of icmp, zext, select, gep and
// load. The same visitor is used for all three queries, so constants found
// for earlier arguments are still known when later arguments are analysed.
TEST_F(FunctionSpecializationTest, Misc) {
  const char *ModuleString = R"(
    @g = constant [2 x i32] zeroinitializer, align 4

    define i32 @foo(i8 %a, i1 %cond, ptr %b) {
      %cmp = icmp eq i8 %a, 10
      %ext = zext i1 %cmp to i32
      %sel = select i1 %cond, i32 %ext, i32 1
      %gep = getelementptr i32, ptr %b, i32 %sel
      %ld = load i32, ptr %gep
      ret i32 %ld
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  GlobalVariable *GV = M.getGlobalVariable("g");
  Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
  Constant *True = ConstantInt::getTrue(M.getContext());

  // Walk the first five instructions of the single block in order.
  auto InstIt = F->front().begin();
  Instruction &Icmp = *InstIt++;
  Instruction &Zext = *InstIt++;
  Instruction &Select = *InstIt++;
  Instruction &Gep = *InstIt++;
  Instruction &Load = *InstIt++;

  // icmp + zext
  Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // select
  Ref = getInstCost(Select);
  Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
  EXPECT_EQ(Bonus, Ref);

  // gep + load
  Ref = getInstCost(Gep) + getInstCost(Load);
  Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
  EXPECT_EQ(Bonus, Ref);
}
|
||||
Reference in New Issue
Block a user