Skip to content

Commit 73c3b73

Browse files
authored
[ctx_prof] Add support for ICP (llvm#105469)
An overload of `llvm::promoteCallWithIfThenElse` that updates the contextual profile. High-level, this is very simple: after creating the `if... then (direct call) else (indirect call)` structure, we instrument the new callsites and BBs (the instrumentation will help with tracking for other IPO transformations, and, ultimately, to match counter values before flattening to `MD_prof`). In more detail: - move the callsite instrumentation of the indirect call to the `else` BB, before the indirect call - create a new callsite instrumentation for the direct call - create instrumentation for both the `then` and `else` BBs - we could instrument just one (MST-style) but we're not running the binary with this instrumentation, and at most this would save some space (less counters tracked). For simplicity instrumenting both at this point - update each context belonging to the caller by updating the counters, and moving the indirect callee to the new, direct callsite ID Issue llvm#89287
1 parent 016e1eb commit 73c3b73

File tree

8 files changed

+344
-33
lines changed

8 files changed

+344
-33
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

+14-3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ class PGOContextualProfile {
7373
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
7474
}
7575

76+
using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
77+
using Visitor = function_ref<void(PGOCtxProfContext &)>;
78+
79+
void update(Visitor, const Function *F = nullptr);
80+
void visit(ConstVisitor, const Function *F = nullptr) const;
81+
7682
const CtxProfFlatProfile flatten() const;
7783

7884
bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
105111

106112
class CtxProfAnalysisPrinterPass
107113
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
108-
raw_ostream &OS;
109-
110114
public:
111-
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
115+
enum class PrintMode { Everything, JSON };
116+
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
117+
PrintMode Mode = PrintMode::Everything)
118+
: OS(OS), Mode(Mode) {}
112119

113120
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
114121
static bool isRequired() { return true; }
122+
123+
private:
124+
raw_ostream &OS;
125+
const PrintMode Mode;
115126
};
116127

117128
/// Assign a GUID to functions as metadata. GUID calculation takes linkage into

llvm/include/llvm/IR/IntrinsicInst.h

+2
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
15351535
ConstantInt *getNumCounters() const;
15361536
// The index of the counter that this instruction acts on.
15371537
ConstantInt *getIndex() const;
1538+
void setIndex(uint32_t Idx);
15381539
};
15391540

15401541
/// This represents the llvm.instrprof.cover intrinsic.
@@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
15851586
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
15861587
}
15871588
Value *getCallee() const;
1589+
void setCallee(Value *Callee);
15881590
};
15891591

15901592
/// This represents the llvm.instrprof.timestamp intrinsic.

llvm/include/llvm/ProfileData/PGOCtxProfReader.h

+22
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,25 @@ class PGOCtxProfContext final {
5757

5858
GlobalValue::GUID guid() const { return GUID; }
5959
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
60+
SmallVectorImpl<uint64_t> &counters() { return Counters; }
61+
62+
uint64_t getEntrycount() const {
63+
assert(!Counters.empty() &&
64+
"Functions are expected to have at their entry BB instrumented, so "
65+
"there should always be at least 1 counter.");
66+
return Counters[0];
67+
}
68+
6069
const CallsiteMapTy &callsites() const { return Callsites; }
6170
CallsiteMapTy &callsites() { return Callsites; }
6271

72+
void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
73+
auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
74+
Iter->second.emplace(Other.guid(), std::move(Other));
75+
}
76+
77+
void resizeCounters(uint32_t Size) { Counters.resize(Size); }
78+
6379
bool hasCallsite(uint32_t I) const {
6480
return Callsites.find(I) != Callsites.end();
6581
}
@@ -68,6 +84,12 @@ class PGOCtxProfContext final {
6884
assert(hasCallsite(I) && "Callsite not found");
6985
return Callsites.find(I)->second;
7086
}
87+
88+
CallTargetMapTy &callsite(uint32_t I) {
89+
assert(hasCallsite(I) && "Callsite not found");
90+
return Callsites.find(I)->second;
91+
}
92+
7193
void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
7294
};
7395

llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1515
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1616

17+
#include "llvm/Analysis/CtxProfAnalysis.h"
1718
namespace llvm {
1819
template <typename T> class ArrayRef;
1920
class Constant;
@@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
5657
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
5758
MDNode *BranchWeights = nullptr);
5859

60+
CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
61+
PGOContextualProfile &CtxProf);
62+
5963
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
6064
/// promote a virtual call is that \p VPtr is the same as any of \p
6165
/// AddressPoints.

llvm/lib/Analysis/CtxProfAnalysis.cpp

+50-29
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
173173
return PreservedAnalyses::all();
174174
}
175175

176-
OS << "Function Info:\n";
177-
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
178-
OS << Guid << " : " << FuncInfo.Name
179-
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
180-
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
176+
if (Mode == PrintMode::Everything) {
177+
OS << "Function Info:\n";
178+
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
179+
OS << Guid << " : " << FuncInfo.Name
180+
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
181+
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
182+
}
181183

182184
const auto JSONed = ::llvm::json::toJSON(C.profiles());
183185

184-
OS << "\nCurrent Profile:\n";
186+
if (Mode == PrintMode::Everything)
187+
OS << "\nCurrent Profile:\n";
185188
OS << formatv("{0:2}", JSONed);
189+
if (Mode == PrintMode::JSON)
190+
return PreservedAnalyses::all();
191+
186192
OS << "\n";
187193
OS << "\nFlat Profile:\n";
188194
auto Flat = C.flatten();
@@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
209215
return nullptr;
210216
}
211217

212-
static void
213-
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
214-
function_ref<void(const PGOCtxProfContext &)> Visitor) {
215-
std::function<void(const PGOCtxProfContext &)> Traverser =
216-
[&](const auto &Ctx) {
217-
Visitor(Ctx);
218-
for (const auto &[_, SubCtxSet] : Ctx.callsites())
219-
for (const auto &[__, Subctx] : SubCtxSet)
220-
Traverser(Subctx);
221-
};
222-
for (const auto &[_, P] : Profiles)
218+
template <class ProfilesTy, class ProfTy>
219+
static void preorderVisit(ProfilesTy &Profiles,
220+
function_ref<void(ProfTy &)> Visitor,
221+
GlobalValue::GUID Match = 0) {
222+
std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
223+
if (!Match || Ctx.guid() == Match)
224+
Visitor(Ctx);
225+
for (auto &[_, SubCtxSet] : Ctx.callsites())
226+
for (auto &[__, Subctx] : SubCtxSet)
227+
Traverser(Subctx);
228+
};
229+
for (auto &[_, P] : Profiles)
223230
Traverser(P);
224231
}
225232

233+
void PGOContextualProfile::update(Visitor V, const Function *F) {
234+
GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
235+
preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
236+
*Profiles, V, G);
237+
}
238+
239+
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
240+
GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
241+
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
242+
const PGOCtxProfContext>(*Profiles, V, G);
243+
}
244+
226245
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
227246
assert(Profiles.has_value());
228247
CtxProfFlatProfile Flat;
229-
preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
230-
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
231-
if (Ins) {
232-
llvm::append_range(It->second, Ctx.counters());
233-
return;
234-
}
235-
assert(It->second.size() == Ctx.counters().size() &&
236-
"All contexts corresponding to a function should have the exact "
237-
"same number of counters.");
238-
for (size_t I = 0, E = It->second.size(); I < E; ++I)
239-
It->second[I] += Ctx.counters()[I];
240-
});
248+
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
249+
const PGOCtxProfContext>(
250+
*Profiles, [&](const PGOCtxProfContext &Ctx) {
251+
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
252+
if (Ins) {
253+
llvm::append_range(It->second, Ctx.counters());
254+
return;
255+
}
256+
assert(It->second.size() == Ctx.counters().size() &&
257+
"All contexts corresponding to a function should have the exact "
258+
"same number of counters.");
259+
for (size_t I = 0, E = It->second.size(); I < E; ++I)
260+
It->second[I] += Ctx.counters()[I];
261+
});
241262
return Flat;
242263
}

llvm/lib/IR/IntrinsicInst.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
285285
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
286286
}
287287

288+
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
289+
assert(isa<InstrProfCntrInstBase>(this));
290+
setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx));
291+
}
292+
288293
Value *InstrProfIncrementInst::getStep() const {
289294
if (InstrProfIncrementInstStep::classof(this)) {
290295
return const_cast<Value *>(getArgOperand(4));
@@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
300305
return nullptr;
301306
}
302307

308+
void InstrProfCallsite::setCallee(Value *Callee) {
309+
assert(isa<InstrProfCallsite>(this));
310+
setArgOperand(4, Callee);
311+
}
312+
303313
std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
304314
unsigned NumOperands = arg_size();
305315
Metadata *MD = nullptr;

llvm/lib/Transforms/Utils/CallPromotionUtils.cpp

+85-1
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
15-
#include "llvm/ADT/STLExtras.h"
15+
#include "llvm/Analysis/CtxProfAnalysis.h"
1616
#include "llvm/Analysis/Loads.h"
1717
#include "llvm/Analysis/TypeMetadataUtils.h"
1818
#include "llvm/IR/AttributeMask.h"
1919
#include "llvm/IR/Constant.h"
2020
#include "llvm/IR/IRBuilder.h"
2121
#include "llvm/IR/Instructions.h"
22+
#include "llvm/IR/IntrinsicInst.h"
2223
#include "llvm/IR/Module.h"
24+
#include "llvm/ProfileData/PGOCtxProfReader.h"
2325
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2426

2527
using namespace llvm;
@@ -572,6 +574,88 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
572574
return promoteCall(NewInst, Callee);
573575
}
574576

577+
CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
578+
PGOContextualProfile &CtxProf) {
579+
assert(CB.isIndirectCall());
580+
if (!CtxProf.isFunctionKnown(Callee))
581+
return nullptr;
582+
auto &Caller = *CB.getFunction();
583+
auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
584+
if (!CSInstr)
585+
return nullptr;
586+
const uint64_t CSIndex = CSInstr->getIndex()->getZExtValue();
587+
588+
CallBase &DirectCall = promoteCall(
589+
versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
590+
CSInstr->moveBefore(&CB);
591+
const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
592+
auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
593+
NewCSInstr->setIndex(NewCSID);
594+
NewCSInstr->setCallee(&Callee);
595+
NewCSInstr->insertBefore(&DirectCall);
596+
auto &DirectBB = *DirectCall.getParent();
597+
auto &IndirectBB = *CB.getParent();
598+
599+
assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
600+
"The ICP direct BB is new, it shouldn't have instrumentation");
601+
assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
602+
"The ICP indirect BB is new, it shouldn't have instrumentation");
603+
604+
// Allocate counters for the new basic blocks.
605+
const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
606+
const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
607+
auto *EntryBBIns =
608+
CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
609+
auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
610+
DirectBBIns->setIndex(DirectID);
611+
DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());
612+
613+
auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
614+
IndirectBBIns->setIndex(IndirectID);
615+
IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());
616+
617+
const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
618+
const uint32_t NewCountersSize = IndirectID + 1;
619+
620+
auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
621+
assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
622+
assert(NewCountersSize - 2 == Ctx.counters().size());
623+
// All the ctx-es belonging to a function must have the same size counters.
624+
Ctx.resizeCounters(NewCountersSize);
625+
626+
// Maybe in this context, the indirect callsite wasn't observed at all
627+
if (!Ctx.hasCallsite(CSIndex))
628+
return;
629+
auto &CSData = Ctx.callsite(CSIndex);
630+
auto It = CSData.find(CalleeGUID);
631+
632+
// Maybe we did notice the indirect callsite, but to other targets.
633+
if (It == CSData.end())
634+
return;
635+
636+
assert(CalleeGUID == It->second.guid());
637+
638+
uint32_t DirectCount = It->second.getEntrycount();
639+
uint32_t TotalCount = 0;
640+
for (const auto &[_, V] : CSData)
641+
TotalCount += V.getEntrycount();
642+
assert(TotalCount >= DirectCount);
643+
uint32_t IndirectCount = TotalCount - DirectCount;
644+
// The ICP's effect is as-if the direct BB would have been taken DirectCount
645+
// times, and the indirect BB, IndirectCount times
646+
Ctx.counters()[DirectID] = DirectCount;
647+
Ctx.counters()[IndirectID] = IndirectCount;
648+
649+
// This particular indirect target needs to be moved to this caller under
650+
// the newly-allocated callsite index.
651+
assert(Ctx.callsites().count(NewCSID) == 0);
652+
Ctx.ingestContext(NewCSID, std::move(It->second));
653+
CSData.erase(CalleeGUID);
654+
};
655+
CtxProf.update(ProfileUpdater, &Caller);
656+
return &DirectCall;
657+
}
658+
575659
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
576660
Function *Callee,
577661
ArrayRef<Constant *> AddressPoints,

0 commit comments

Comments
 (0)