[BOLT][NFCI] Simplify DataAggregator using traces (#143289)

Consistently apply traces as defined in #127125 for branch profile
aggregation. This combines branches and fall-through records into one.

With large input binaries/profiles, the speedup in aggregation time
(`-time-aggr`, wall time):
- perf.data, pre-BOLT input: 154.5528s -> 144.0767s
- pre-aggregated data, pre-BOLT input: 15.1026s -> 9.0711s
- pre-aggregated data, BOLTed input: 15.4871s -> 10.0077s

Test Plan: NFC
This commit is contained in:
Amir Ayupov
2025-06-16 23:54:40 -07:00
committed by GitHub
parent 41b9d28327
commit 7e6c1bd3ed
2 changed files with 104 additions and 132 deletions

View File

@@ -99,24 +99,28 @@ private:
uint64_t Addr;
};
/// Container for the unit of branch data.
/// Backwards compatible with legacy use for branches and fall-throughs:
/// - if \p Branch is FT_ONLY or FT_EXTERNAL_ORIGIN, the trace only
/// contains fall-through data,
/// - if \p To is BR_ONLY, the trace only contains branch data.
struct Trace {
static constexpr const uint64_t EXTERNAL = 0ULL;
static constexpr const uint64_t BR_ONLY = -1ULL;
static constexpr const uint64_t FT_ONLY = -1ULL;
static constexpr const uint64_t FT_EXTERNAL_ORIGIN = -2ULL;
uint64_t Branch;
uint64_t From;
uint64_t To;
Trace(uint64_t From, uint64_t To) : From(From), To(To) {}
bool operator==(const Trace &Other) const {
return From == Other.From && To == Other.To;
}
auto tie() const { return std::tie(Branch, From, To); }
bool operator==(const Trace &Other) const { return tie() == Other.tie(); }
bool operator<(const Trace &Other) const { return tie() < Other.tie(); }
};
friend raw_ostream &operator<<(raw_ostream &OS, const Trace &);
struct TraceHash {
size_t operator()(const Trace &L) const {
return std::hash<uint64_t>()(L.From << 32 | L.To);
}
};
struct FTInfo {
uint64_t InternCount{0};
uint64_t ExternCount{0};
size_t operator()(const Trace &L) const { return hash_combine(L.tie()); }
};
struct TakenBranchInfo {
@@ -126,8 +130,8 @@ private:
/// Intermediate storage for profile data. We save the results of parsing
/// and use them later for processing and assigning profile.
std::unordered_map<Trace, TakenBranchInfo, TraceHash> BranchLBRs;
std::unordered_map<Trace, FTInfo, TraceHash> FallthroughLBRs;
std::unordered_map<Trace, TakenBranchInfo, TraceHash> TraceMap;
std::vector<std::pair<Trace, TakenBranchInfo>> Traces;
std::unordered_map<uint64_t, uint64_t> BasicSamples;
std::vector<PerfMemSample> MemSamples;
@@ -200,8 +204,8 @@ private:
/// Return a vector of offsets corresponding to a trace in a function
/// if the trace is valid, std::nullopt otherwise.
std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
getFallthroughsInTrace(BinaryFunction &BF, const LBREntry &First,
const LBREntry &Second, uint64_t Count = 1) const;
getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
uint64_t Count) const;
/// Record external entry into the function \p BF.
///
@@ -265,8 +269,7 @@ private:
bool doBranch(uint64_t From, uint64_t To, uint64_t Count, uint64_t Mispreds);
/// Register a trace between two LBR entries supplied in execution order.
bool doTrace(const LBREntry &First, const LBREntry &Second,
uint64_t Count = 1);
bool doTrace(const Trace &Trace, uint64_t Count);
/// Parser helpers
/// Return false if we exhausted our parser buffer and finished parsing
@@ -516,6 +519,21 @@ inline raw_ostream &operator<<(raw_ostream &OS,
OS << formatv("{0:x} -> {1:x}/{2}", L.From, L.To, L.Mispred ? 'M' : 'P');
return OS;
}
/// Print trace \p T to \p OS as hexadecimal addresses in the form
/// "Branch -> From ... To".
/// - The "Branch -> " prefix is omitted when \p T.Branch is FT_ONLY or
///   FT_EXTERNAL_ORIGIN.
/// - The " ... To" suffix is omitted when \p T.To is BR_ONLY.
inline raw_ostream &operator<<(raw_ostream &OS,
                               const DataAggregator::Trace &T) {
  // Magic values in the Branch field mean there is no branch source to print.
  const bool HasBranchSource =
      T.Branch != DataAggregator::Trace::FT_ONLY &&
      T.Branch != DataAggregator::Trace::FT_EXTERNAL_ORIGIN;
  if (HasBranchSource)
    OS << Twine::utohexstr(T.Branch) << " -> ";

  // The From address is always printed.
  OS << Twine::utohexstr(T.From);

  // A magic value in the To field means there is no end address to print.
  const bool HasEndAddress = T.To != DataAggregator::Trace::BR_ONLY;
  if (HasEndAddress)
    OS << " ... " << Twine::utohexstr(T.To);

  return OS;
}
} // namespace bolt
} // namespace llvm

View File

@@ -523,6 +523,10 @@ Error DataAggregator::preprocessProfile(BinaryContext &BC) {
deleteTempFiles();
heatmap:
// Sort parsed traces for faster processing.
if (!opts::BasicAggregation)
llvm::sort(Traces, llvm::less_first());
if (!opts::HeatmapMode)
return Error::success();
@@ -598,8 +602,7 @@ void DataAggregator::processProfile(BinaryContext &BC) {
llvm::stable_sort(MemEvents.second.Data);
// Release intermediate storage.
clear(BranchLBRs);
clear(FallthroughLBRs);
clear(Traces);
clear(BasicSamples);
clear(MemSamples);
}
@@ -780,37 +783,19 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
return doInterBranch(FromFunc, ToFunc, From, To, Count, Mispreds);
}
bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
uint64_t Count) {
BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(First.To);
BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(Second.From);
bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
const uint64_t From = Trace.From, To = Trace.To;
BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(From);
BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(To);
NumTraces += Count;
if (!FromFunc || !ToFunc) {
LLVM_DEBUG({
dbgs() << "Out of range trace starting in ";
if (FromFunc)
dbgs() << formatv("{0} @ {1:x}", *FromFunc,
First.To - FromFunc->getAddress());
else
dbgs() << Twine::utohexstr(First.To);
dbgs() << " and ending in ";
if (ToFunc)
dbgs() << formatv("{0} @ {1:x}", *ToFunc,
Second.From - ToFunc->getAddress());
else
dbgs() << Twine::utohexstr(Second.From);
dbgs() << '\n';
});
LLVM_DEBUG(dbgs() << "Out of range trace " << Trace << '\n');
NumLongRangeTraces += Count;
return false;
}
if (FromFunc != ToFunc) {
LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
NumInvalidTraces += Count;
LLVM_DEBUG({
dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
<< formatv(" @ {0:x}", First.To - FromFunc->getAddress())
<< " and ending in " << ToFunc->getPrintName()
<< formatv(" @ {0:x}\n", Second.From - ToFunc->getAddress());
});
return false;
}
@@ -818,28 +803,21 @@ bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
BinaryFunction *ParentFunc = getBATParentFunction(*FromFunc);
if (!ParentFunc)
ParentFunc = FromFunc;
ParentFunc->SampleCountInBytes += Count * (Second.From - First.To);
ParentFunc->SampleCountInBytes += Count * (To - From);
const uint64_t FuncAddress = FromFunc->getAddress();
std::optional<BoltAddressTranslation::FallthroughListTy> FTs =
BAT && BAT->isBATFunction(FuncAddress)
? BAT->getFallthroughsInTrace(FuncAddress, First.To, Second.From)
: getFallthroughsInTrace(*FromFunc, First, Second, Count);
? BAT->getFallthroughsInTrace(FuncAddress, From, To)
: getFallthroughsInTrace(*FromFunc, Trace, Count);
if (!FTs) {
LLVM_DEBUG(
dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
<< " @ " << Twine::utohexstr(First.To - FromFunc->getAddress())
<< " and ending in " << ToFunc->getPrintName() << " @ "
<< ToFunc->getPrintName() << " @ "
<< Twine::utohexstr(Second.From - ToFunc->getAddress()) << '\n');
LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
NumInvalidTraces += Count;
return false;
}
LLVM_DEBUG(dbgs() << "Processing " << FTs->size() << " fallthroughs for "
<< FromFunc->getPrintName() << ":"
<< Twine::utohexstr(First.To) << " to "
<< Twine::utohexstr(Second.From) << ".\n");
<< FromFunc->getPrintName() << ":" << Trace << '\n');
for (auto [From, To] : *FTs) {
if (BAT) {
From = BAT->translate(FromFunc->getAddress(), From, /*IsBranchSrc=*/true);
@@ -852,17 +830,15 @@ bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
}
std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
const LBREntry &FirstLBR,
const LBREntry &SecondLBR,
DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
uint64_t Count) const {
SmallVector<std::pair<uint64_t, uint64_t>, 16> Branches;
BinaryContext &BC = BF.getBinaryContext();
// Offsets of the trace within this function.
const uint64_t From = FirstLBR.To - BF.getAddress();
const uint64_t To = SecondLBR.From - BF.getAddress();
const uint64_t From = Trace.From - BF.getAddress();
const uint64_t To = Trace.To - BF.getAddress();
if (From > To)
return std::nullopt;
@@ -889,8 +865,9 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
// Adjust FromBB if the first LBR is a return from the last instruction in
// the previous block (that instruction should be a call).
if (From == FromBB->getOffset() && !BF.containsAddress(FirstLBR.From) &&
!FromBB->isEntryPoint() && !FromBB->isLandingPad()) {
if (Trace.Branch != Trace::FT_ONLY && !BF.containsAddress(Trace.Branch) &&
From == FromBB->getOffset() && !FromBB->isEntryPoint() &&
!FromBB->isLandingPad()) {
const BinaryBasicBlock *PrevBB =
BF.getLayout().getBlock(FromBB->getIndex() - 1);
if (PrevBB->getSuccessor(FromBB->getLabel())) {
@@ -898,10 +875,9 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
if (Instr && BC.MIB->isCall(*Instr))
FromBB = PrevBB;
else
LLVM_DEBUG(dbgs() << "invalid incoming LBR (no call): " << FirstLBR
<< '\n');
LLVM_DEBUG(dbgs() << "invalid trace (no call): " << Trace << '\n');
} else {
LLVM_DEBUG(dbgs() << "invalid incoming LBR: " << FirstLBR << '\n');
LLVM_DEBUG(dbgs() << "invalid trace: " << Trace << '\n');
}
}
@@ -920,9 +896,7 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
// Check for bad LBRs.
if (!BB->getSuccessor(NextBB->getLabel())) {
LLVM_DEBUG(dbgs() << "no fall-through for the trace:\n"
<< " " << FirstLBR << '\n'
<< " " << SecondLBR << '\n');
LLVM_DEBUG(dbgs() << "no fall-through for the trace: " << Trace << '\n');
return std::nullopt;
}
@@ -1227,14 +1201,15 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
FT_EXTERNAL_ORIGIN // f
} Type = INVALID;
// The number of fields to parse, set based on Type.
/// The number of fields to parse, set based on \p Type.
int AddrNum = 0;
int CounterNum = 0;
// Storage for parsed fields.
/// Storage for parsed fields.
StringRef EventName;
std::optional<Location> Addr[3];
int64_t Counters[2] = {0};
/// Parse strings: record type and optionally an event name.
while (Type == INVALID || Type == EVENT_NAME) {
while (checkAndConsumeFS()) {
}
@@ -1268,6 +1243,7 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
CounterNum = SSI(Str).Case("B", 2).Case("E", 0).Default(1);
}
/// Parse locations depending on entry type, recording them in \p Addr array.
for (int I = 0; I < AddrNum; ++I) {
while (checkAndConsumeFS()) {
}
@@ -1277,6 +1253,7 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
Addr[I] = AddrOrErr.get();
}
/// Parse counters depending on entry type.
for (int I = 0; I < CounterNum; ++I) {
while (checkAndConsumeFS()) {
}
@@ -1287,11 +1264,13 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
Counters[I] = CountOrErr.get();
}
/// Expect end of line here.
if (!checkAndConsumeNewLine()) {
reportError("expected end of line");
return make_error_code(llvm::errc::io_error);
}
/// Record event name into \p EventNames and return.
if (Type == EVENT_NAME) {
EventNames.insert(EventName);
return std::error_code();
@@ -1305,6 +1284,7 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
int64_t Count = Counters[0];
int64_t Mispreds = Counters[1];
/// Record basic IP sample into \p BasicSamples and return.
if (Type == SAMPLE) {
BasicSamples[FromOffset] += Count;
NumTotalSamples += Count;
@@ -1316,29 +1296,25 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
if (ToFunc)
ToFunc->setHasProfileAvailable();
Trace Trace(FromOffset, ToOffset);
// Taken trace
if (Type == TRACE || Type == BRANCH) {
TakenBranchInfo &Info = BranchLBRs[Trace];
Info.TakenCount += Count;
Info.MispredCount += Mispreds;
/// For legacy fall-through types, adjust locations to match Trace container.
if (Type == FT || Type == FT_EXTERNAL_ORIGIN) {
Addr[2] = Location(Addr[1]->Offset); // Trace To
Addr[1] = Location(Addr[0]->Offset); // Trace From
// Put a magic value into Trace Branch to differentiate from a full trace.
Addr[0] = Location(Type == FT ? Trace::FT_ONLY : Trace::FT_EXTERNAL_ORIGIN);
}
NumTotalSamples += Count;
/// For legacy branch type, mark Trace To to differentiate from a full trace.
if (Type == BRANCH) {
Addr[2] = Location(Trace::BR_ONLY);
}
// Construct fallthrough part of the trace
if (Type == TRACE) {
const uint64_t TraceFtEndOffset = Addr[2]->Offset;
Trace.From = ToOffset;
Trace.To = TraceFtEndOffset;
Type = FromFunc == ToFunc ? FT : FT_EXTERNAL_ORIGIN;
}
// Add fallthrough trace
if (Type != BRANCH) {
FTInfo &Info = FallthroughLBRs[Trace];
(Type == FT ? Info.InternCount : Info.ExternCount) += Count;
NumTraces += Count;
}
/// Record a trace.
Trace T{Addr[0]->Offset, Addr[1]->Offset, Addr[2]->Offset};
TakenBranchInfo TI{(uint64_t)Count, (uint64_t)Mispreds};
Traces.emplace_back(T, TI);
NumTotalSamples += Count;
return std::error_code();
}
@@ -1350,7 +1326,7 @@ bool DataAggregator::ignoreKernelInterrupt(LBREntry &LBR) const {
std::error_code DataAggregator::printLBRHeatMap() {
outs() << "PERF2BOLT: parse branch events...\n";
NamedRegionTimer T("parseBranch", "Parsing branch events", TimerGroupName,
NamedRegionTimer T("buildHeatmap", "Building heatmap", TimerGroupName,
TimerGroupDesc, opts::TimeAggregator);
if (BC->IsLinuxKernel) {
@@ -1386,12 +1362,9 @@ std::error_code DataAggregator::printLBRHeatMap() {
// Register basic samples and perf LBR addresses not covered by fallthroughs.
for (const auto &[PC, Hits] : BasicSamples)
HM.registerAddress(PC, Hits);
for (const auto &LBR : FallthroughLBRs) {
const Trace &Trace = LBR.first;
const FTInfo &Info = LBR.second;
HM.registerAddressRange(Trace.From, Trace.To,
Info.InternCount + Info.ExternCount);
}
for (const auto &[Trace, Info] : Traces)
if (Trace.To != Trace::BR_ONLY)
HM.registerAddressRange(Trace.From, Trace.To, Info.TakenCount);
if (HM.getNumInvalidRanges())
outs() << "HEATMAP: invalid traces: " << HM.getNumInvalidRanges() << '\n';
@@ -1437,22 +1410,10 @@ void DataAggregator::parseLBRSample(const PerfBranchSample &Sample,
// chronological order)
if (NeedsSkylakeFix && NumEntry <= 2)
continue;
if (NextLBR) {
// Record fall-through trace.
const uint64_t TraceFrom = LBR.To;
const uint64_t TraceTo = NextLBR->From;
const BinaryFunction *TraceBF =
getBinaryFunctionContainingAddress(TraceFrom);
FTInfo &Info = FallthroughLBRs[Trace(TraceFrom, TraceTo)];
if (TraceBF && TraceBF->containsAddress(LBR.From))
++Info.InternCount;
else
++Info.ExternCount;
++NumTraces;
}
uint64_t TraceTo = NextLBR ? NextLBR->From : Trace::BR_ONLY;
NextLBR = &LBR;
TakenBranchInfo &Info = BranchLBRs[Trace(LBR.From, LBR.To)];
TakenBranchInfo &Info = TraceMap[Trace{LBR.From, LBR.To, TraceTo}];
++Info.TakenCount;
Info.MispredCount += LBR.Mispred;
}
@@ -1563,10 +1524,14 @@ std::error_code DataAggregator::parseBranchEvents() {
parseLBRSample(Sample, NeedsSkylakeFix);
}
for (const Trace &Trace : llvm::make_first_range(BranchLBRs))
for (const uint64_t Addr : {Trace.From, Trace.To})
Traces.reserve(TraceMap.size());
for (const auto &[Trace, Info] : TraceMap) {
Traces.emplace_back(Trace, Info);
for (const uint64_t Addr : {Trace.Branch, Trace.From})
if (BinaryFunction *BF = getBinaryFunctionContainingAddress(Addr))
BF->setHasProfileAvailable();
}
clear(TraceMap);
outs() << "PERF2BOLT: read " << NumSamples << " samples and " << NumEntries
<< " LBR entries\n";
@@ -1591,23 +1556,12 @@ void DataAggregator::processBranchEvents() {
NamedRegionTimer T("processBranch", "Processing branch events",
TimerGroupName, TimerGroupDesc, opts::TimeAggregator);
for (const auto &AggrLBR : FallthroughLBRs) {
const Trace &Loc = AggrLBR.first;
const FTInfo &Info = AggrLBR.second;
LBREntry First{Loc.From, Loc.From, false};
LBREntry Second{Loc.To, Loc.To, false};
if (Info.InternCount)
doTrace(First, Second, Info.InternCount);
if (Info.ExternCount) {
First.From = 0;
doTrace(First, Second, Info.ExternCount);
}
}
for (const auto &AggrLBR : BranchLBRs) {
const Trace &Loc = AggrLBR.first;
const TakenBranchInfo &Info = AggrLBR.second;
doBranch(Loc.From, Loc.To, Info.TakenCount, Info.MispredCount);
for (const auto &[Trace, Info] : Traces) {
if (Trace.Branch != Trace::FT_ONLY &&
Trace.Branch != Trace::FT_EXTERNAL_ORIGIN)
doBranch(Trace.Branch, Trace.From, Info.TakenCount, Info.MispredCount);
if (Trace.To != Trace::BR_ONLY)
doTrace(Trace, Info.TakenCount);
}
printBranchSamplesDiagnostics();
}