From daa2a587cc01c5656deecda7f768fed0afc1e515 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Mon, 23 Jun 2025 08:07:31 -0700 Subject: [PATCH] [TRE] Adjust function entry count when using instrumented profiles (#143987) The entry count of a function needs to be updated after a callsite is elided by TRE: before elision, the entry count accounted for the recursive call at that callsite. After TRE, we need to remove that callsite's contribution. This patch enables this for instrumented profiling cases because, there, we know the function entry count captured entries before TRE. We cannot currently address this for sample-based (because we don't know whether this function was TRE-ed in the binary that donated samples) --- llvm/include/llvm/Passes/PassBuilder.h | 2 + .../Scalar/TailRecursionElimination.h | 7 +- llvm/lib/Passes/PassBuilderPipelines.cpp | 14 +- .../Scalar/TailRecursionElimination.cpp | 68 +++++++++- .../TailCallElim/entry-count-adjustment.ll | 120 ++++++++++++++++++ 5 files changed, 200 insertions(+), 11 deletions(-) create mode 100644 llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h index f13b5c678a89..9cdb7ca7dbc9 100644 --- a/llvm/include/llvm/Passes/PassBuilder.h +++ b/llvm/include/llvm/Passes/PassBuilder.h @@ -773,6 +773,8 @@ private: IntrusiveRefCntPtr FS); void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level); + bool isInstrumentedPGOUse() const; + // Extension Point callbacks SmallVector, 2> PeepholeEPCallbacks; diff --git a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h index 57b1ed9bf4fe..22a70cd66865 100644 --- a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h +++ b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h @@ -58,7 +58,12 @@ namespace llvm { class Function; -struct TailCallElimPass : PassInfoMixin { +class TailCallElimPass : public PassInfoMixin { + const bool UpdateFunctionEntryCount; + +public: + TailCallElimPass(bool UpdateFunctionEntryCount = true) + : UpdateFunctionEntryCount(UpdateFunctionEntryCount) {} PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; } diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index b0cdd1b94e56..c83d2dc1f151 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -625,7 +625,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level, !Level.isOptimizingForSize()) FPM.addPass(PGOMemOPSizeOpt()); - FPM.addPass(TailCallElimPass()); + FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/ + isInstrumentedPGOUse())); FPM.addPass( SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true))); @@ -1578,7 +1579,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level, OptimizePM.addPass(DivRemPairsPass()); // Try to annotate calls that were created during optimization. - OptimizePM.addPass(TailCallElimPass()); + OptimizePM.addPass( + TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse())); // LoopSink (and other loop passes since the last simplifyCFG) might have // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. @@ -2066,7 +2068,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, // LTO provides additional opportunities for tailcall elimination due to // link-time inlining, and visibility of nocapture attribute. - FPM.addPass(TailCallElimPass()); + FPM.addPass( + TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse())); // Run a few AA driver optimizations here and now to cleanup the code. MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM), @@ -2347,3 +2350,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() { return AA; } + +bool PassBuilder::isInstrumentedPGOUse() const { + return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) || + !UseCtxProfile.empty(); +} \ No newline at end of file diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index e7d989a43840..7828571123bc 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -53,6 +53,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -75,10 +76,12 @@ #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include using namespace llvm; #define DEBUG_TYPE "tailcallelim" @@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed"); STATISTIC(NumRetDuped, "Number of return duplicated"); STATISTIC(NumAccumAdded, "Number of accumulators introduced"); +static cl::opt ForceDisableBFI( + "tre-disable-entrycount-recompute", cl::init(false), cl::Hidden, + cl::desc("Force disabling recomputing of function entry count, on " + "successful tail recursion elimination.")); + /// Scan the specified function for alloca instructions. /// If it contains any dynamic allocas, returns false. static bool canTRE(Function &F) { @@ -399,6 +407,9 @@ class TailRecursionEliminator { AliasAnalysis *AA; OptimizationRemarkEmitter *ORE; DomTreeUpdater &DTU; + BlockFrequencyInfo *const BFI; + const uint64_t OrigEntryBBFreq; + const uint64_t OrigEntryCount; // The below are shared state we want to have available when eliminating any // calls in the function. There values should be populated by @@ -428,8 +439,19 @@ class TailRecursionEliminator { TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU) - : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {} + DomTreeUpdater &DTU, BlockFrequencyInfo *BFI) + : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI), + OrigEntryBBFreq( + BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U), + OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0) { + if (BFI) { + // The assert is meant as API documentation for the caller. + assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) && + "If a BFI was provided, the function should have both an entry " + "count that is non-zero and an entry basic block with a non-zero " + "frequency."); + } + } CallInst *findTRECandidate(BasicBlock *BB); @@ -450,7 +472,7 @@ class TailRecursionEliminator { public: static bool eliminate(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU); + DomTreeUpdater &DTU, BlockFrequencyInfo *BFI); }; } // namespace @@ -735,6 +757,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { CI->eraseFromParent(); // Remove call. DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}}); ++NumEliminated; + if (OrigEntryBBFreq) { + assert(F.getEntryCount().has_value()); + // This pass is not expected to remove BBs, only add an entry BB. For that + // reason, and because the BB here isn't the new entry BB, the BFI lookup is + // expected to succeed. + assert(&F.getEntryBlock() != BB); + auto RelativeBBFreq = + static_cast(BFI->getBlockFreq(BB).getFrequency()) / + static_cast(OrigEntryBBFreq); + auto ToSubtract = + static_cast(std::round(RelativeBBFreq * OrigEntryCount)); + auto OldEntryCount = F.getEntryCount()->getCount(); + if (OldEntryCount <= ToSubtract) { + LLVM_DEBUG( + errs() << "[TRE] The entrycount attributable to the recursive call, " + << ToSubtract + << ", should be strictly lower than the function entry count, " + << OldEntryCount << "\n"); + } else { + F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType()); + } + } return true; } @@ -861,7 +905,8 @@ bool TailRecursionEliminator::eliminate(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU) { + DomTreeUpdater &DTU, + BlockFrequencyInfo *BFI) { if (F.getFnAttribute("disable-tail-calls").getValueAsBool()) return false; @@ -877,7 +922,7 @@ bool TailRecursionEliminator::eliminate(Function &F, return MadeChange; // Change any tail recursive calls to loops. - TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU); + TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI); for (BasicBlock &BB : F) MadeChange |= TRE.processBlock(BB); @@ -919,7 +964,8 @@ struct TailCallElim : public FunctionPass { return TailRecursionEliminator::eliminate( F, &getAnalysis().getTTI(F), &getAnalysis().getAAResults(), - &getAnalysis().getORE(), DTU); + &getAnalysis().getORE(), DTU, + /*BFI=*/nullptr); } }; } @@ -942,6 +988,13 @@ PreservedAnalyses TailCallElimPass::run(Function &F, TargetTransformInfo &TTI = AM.getResult(F); AliasAnalysis &AA = AM.getResult(F); + // This must come first. It needs the 2 analyses, meaning, if it came after + // the lines asking for the cached result, should they be nullptr (which, in + // the case of the PDT, is likely), updates to the trees would be missed. + auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount && + F.getEntryCount().has_value() && F.getEntryCount()->getCount()) + ? &AM.getResult(F) + : nullptr; auto &ORE = AM.getResult(F); auto *DT = AM.getCachedResult(F); auto *PDT = AM.getCachedResult(F); @@ -949,7 +1002,8 @@ PreservedAnalyses TailCallElimPass::run(Function &F, // UpdateStrategy based on some test results. It is feasible to switch the // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU); + bool Changed = + TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI); if (!Changed) return PreservedAnalyses::all(); diff --git a/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll new file mode 100644 index 000000000000..6001e6040a74 --- /dev/null +++ b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll @@ -0,0 +1,120 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals +; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s --check-prefixes=CHECK,ENABLED +; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefixes=CHECK,DISABLED + +; Test that tail call elimination correctly adjusts function entry counts +; when eliminating tail recursive calls. + +; Basic test: eliminate a tail call and adjust entry count +define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 { +; CHECK-LABEL: @test_basic_entry_count_adjustment( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[TAILRECURSE:%.*]] +; CHECK: tailrecurse: +; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ] +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]] +; CHECK: if.then: +; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1 +; CHECK-NEXT: br label [[TAILRECURSE]] +; CHECK: if.else: +; CHECK-NEXT: ret i32 0 +; +entry: + %cmp = icmp sgt i32 %n, 0 + br i1 %cmp, label %if.then, label %if.else, !prof !1 + +if.then: ; preds = %entry + %sub = sub i32 %n, 1 + %call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub) + ret i32 %call + +if.else: ; preds = %entry + ret i32 0 +} + +; Test multiple tail calls in different blocks with different frequencies +define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 { +; CHECK-LABEL: @test_multiple_blocks_entry_count( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[TAILRECURSE:%.*]] +; CHECK: tailrecurse: +; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ] +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]], !prof [[PROF3:![0-9]+]] +; CHECK: check.flag: +; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1 +; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF4:![0-9]+]] +; CHECK: block1: +; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1 +; CHECK-NEXT: br label [[TAILRECURSE]] +; CHECK: block2: +; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2 +; CHECK-NEXT: br label [[TAILRECURSE]] +; CHECK: base.case: +; CHECK-NEXT: ret i32 1 +; +entry: + %cmp = icmp sgt i32 %n, 0 + br i1 %cmp, label %check.flag, label %base.case, !prof !3 +check.flag: + %cmp.flag = icmp eq i32 %flag, 1 + br i1 %cmp.flag, label %block1, label %block2, !prof !4 +block1: ; preds = %check.flag + %sub1 = sub i32 %n, 1 + %call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag) + ret i32 %call1 +block2: ; preds = %check.flag + %sub2 = sub i32 %n, 2 + %call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag) + ret i32 %call2 +base.case: ; preds = %entry + ret i32 1 +} + +define i32 @test_no_entry_count(i32 %n) { +; CHECK-LABEL: @test_no_entry_count( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[TAILRECURSE:%.*]] +; CHECK: tailrecurse: +; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ] +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]] +; CHECK: if.then: +; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1 +; CHECK-NEXT: br label [[TAILRECURSE]] +; CHECK: if.else: +; CHECK-NEXT: ret i32 0 +; +entry: + %cmp = icmp sgt i32 %n, 0 + br i1 %cmp, label %if.then, label %if.else + +if.then: ; preds = %entry + %sub = sub i32 %n, 1 + %call = tail call i32 @test_no_entry_count(i32 %sub) + ret i32 %call + +if.else: ; preds = %entry + ret i32 0 +} + +; Function entry count metadata +!0 = !{!"function_entry_count", i64 1000} +!1 = !{!"branch_weights", i32 800, i32 200} +!2 = !{!"function_entry_count", i64 2000} +!3 = !{!"branch_weights", i32 3, i32 1} +!4 = !{!"branch_weights", i32 100, i32 900} +;. +; ENABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200} +; ENABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200} +; ENABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 500} +; ENABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1} +; ENABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900} +;. +; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000} +; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200} +; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000} +; DISABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1} +; DISABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900} +;.