65
65
#include "llvm/Support/GraphWriter.h"
66
66
#include "llvm/Support/MathExtras.h"
67
67
#include "llvm/Support/raw_ostream.h"
68
+ #include "llvm/Transforms/Utils/BasicBlockUtils.h"
69
+ #include "llvm/Transforms/Utils/CallPromotionUtils.h"
68
70
#include "llvm/Transforms/Utils/Local.h"
69
71
#include "llvm/Transforms/Utils/ValueMapper.h"
70
72
#include <cassert>
71
73
#include <numeric>
72
74
#include <optional>
75
+ #include <string>
73
76
74
77
using namespace llvm;
75
78
@@ -188,6 +191,7 @@ PIPE_OPERATOR(AAPointerInfo)
188
191
PIPE_OPERATOR(AAAssumptionInfo)
189
192
PIPE_OPERATOR(AAUnderlyingObjects)
190
193
PIPE_OPERATOR(AAAddressSpace)
194
+ PIPE_OPERATOR(AAIndirectCallInfo)
191
195
192
196
#undef PIPE_OPERATOR
193
197
@@ -10560,15 +10564,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
10560
10564
return Change;
10561
10565
}
10562
10566
10563
- // Process callee metadata if available.
10564
- if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) {
10565
- for (const auto &Op : MD->operands()) {
10566
- Function *Callee = mdconst::dyn_extract_or_null<Function>(Op);
10567
- if (Callee)
10568
- addCalledFunction(Callee, Change);
10569
- }
10570
- return Change;
10571
- }
10567
+ if (CB->isIndirectCall())
10568
+ if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>(
10569
+ *this, getIRPosition(), DepClassTy::OPTIONAL))
10570
+ if (IndirectCallAA->foreachCallee(
10571
+ [&](Function *Fn) { return VisitValue(*Fn, CB); }))
10572
+ return Change;
10572
10573
10573
10574
// The most simple case.
10574
10575
ProcessCalledOperand(CB->getCalledOperand(), CB);
@@ -12051,6 +12052,224 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl {
12051
12052
};
12052
12053
} // namespace
12053
12054
12055
+ /// ------------------------ Indirect Call Info -------------------------------
12056
+ namespace {
12057
+ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo {
12058
+ AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A)
12059
+ : AAIndirectCallInfo(IRP, A) {}
12060
+
12061
+ /// See AbstractAttribute::initialize(...).
12062
+ void initialize(Attributor &A) override {
12063
+ auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees);
12064
+ if (!MD)
12065
+ return;
12066
+ for (const auto &Op : MD->operands())
12067
+ if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op))
12068
+ PotentialCallees.insert(Callee);
12069
+ }
12070
+
12071
+ ChangeStatus updateImpl(Attributor &A) override {
12072
+ CallBase *CB = cast<CallBase>(getCtxI());
12073
+ Value *FP = CB->getCalledOperand();
12074
+
12075
+ SmallSetVector<Function *, 4> AssumedCalleesNow;
12076
+ bool AllCalleesKnownNow = AllCalleesKnown;
12077
+
12078
+ // Use simplification to find potential callees, if !callees was present,
12079
+ // fallback to that set if necessary.
12080
+ bool UsedAssumedInformation;
12081
+ SmallVector<AA::ValueAndContext> Values;
12082
+ if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values,
12083
+ AA::ValueScope::AnyScope,
12084
+ UsedAssumedInformation)) {
12085
+ if (PotentialCallees.empty())
12086
+ return indicatePessimisticFixpoint();
12087
+ AssumedCalleesNow.set_union(PotentialCallees);
12088
+ }
12089
+
12090
+ // Check simplification result, prune known UB callees, also restrict it to
12091
+ // the !callees set, if present.
12092
+ for (auto &VAC : Values) {
12093
+ if (isa<UndefValue>(VAC.getValue()))
12094
+ continue;
12095
+ if (isa<ConstantPointerNull>(VAC.getValue()) &&
12096
+ VAC.getValue()->getType()->getPointerAddressSpace() == 0)
12097
+ continue;
12098
+ // TODO: Check for known UB, e.g., poison + noundef.
12099
+ if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) {
12100
+ if (PotentialCallees.empty() || PotentialCallees.count(VACFn))
12101
+ AssumedCalleesNow.insert(VACFn);
12102
+ continue;
12103
+ }
12104
+ if (!PotentialCallees.empty()) {
12105
+ AssumedCalleesNow.set_union(PotentialCallees);
12106
+ break;
12107
+ }
12108
+ AllCalleesKnownNow = false;
12109
+ }
12110
+
12111
+ // If we can't specialize at all, give up now.
12112
+ if (!AllCalleesKnownNow && AssumedCalleesNow.empty())
12113
+ return indicatePessimisticFixpoint();
12114
+
12115
+ if (AssumedCalleesNow == AssumedCalles &&
12116
+ AllCalleesKnown == AllCalleesKnownNow)
12117
+ return ChangeStatus::UNCHANGED;
12118
+
12119
+ std::swap(AssumedCalles, AssumedCalleesNow);
12120
+ AllCalleesKnown = AllCalleesKnownNow;
12121
+ return ChangeStatus::CHANGED;
12122
+ }
12123
+
12124
+ /// See AbstractAttribute::manifest(...).
12125
+ ChangeStatus manifest(Attributor &A) override {
12126
+
12127
+ ChangeStatus Changed = ChangeStatus::UNCHANGED;
12128
+ CallBase *CB = cast<CallBase>(getCtxI());
12129
+ Value *FP = CB->getCalledOperand();
12130
+
12131
+ bool CBIsVoid = CB->getType()->isVoidTy();
12132
+ Instruction *IP = CB;
12133
+ FunctionType *CSFT = CB->getFunctionType();
12134
+ SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end());
12135
+
12136
+ // If we know all callees and there are none, the call site is (effectively)
12137
+ // dead (or UB).
12138
+ if (AssumedCalles.empty()) {
12139
+ assert(AllCalleesKnown &&
12140
+ "Expected all callees to be known if there are none.");
12141
+ A.changeToUnreachableAfterManifest(CB);
12142
+ return ChangeStatus::CHANGED;
12143
+ }
12144
+
12145
+ // Special handling for the single callee case.
12146
+ if (AllCalleesKnown && AssumedCalles.size() == 1) {
12147
+ auto *NewCallee = AssumedCalles.front();
12148
+ if (isLegalToPromote(*CB, NewCallee)) {
12149
+ promoteCall(*CB, NewCallee, nullptr);
12150
+ return ChangeStatus::CHANGED;
12151
+ }
12152
+ Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee),
12153
+ CSArgs, CB->getName(), CB);
12154
+ if (!CBIsVoid)
12155
+ A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall);
12156
+ A.deleteAfterManifest(*CB);
12157
+ return ChangeStatus::CHANGED;
12158
+ }
12159
+
12160
+ // For each potential value we create a conditional
12161
+ //
12162
+ // ```
12163
+ // if (ptr == value) value(args);
12164
+ // else ...
12165
+ // ```
12166
+ //
12167
+ ICmpInst *LastCmp = nullptr;
12168
+ SmallVector<std::pair<CallInst *, Instruction *>> NewCalls;
12169
+ for (Function *NewCallee : AssumedCalles) {
12170
+ LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee);
12171
+ Instruction *ThenTI =
12172
+ SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false);
12173
+ BasicBlock *CBBB = CB->getParent();
12174
+ auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode());
12175
+ BasicBlock *ElseBB;
12176
+ if (IP == CB) {
12177
+ ElseBB = BasicBlock::Create(ThenTI->getContext(), "",
12178
+ ThenTI->getFunction(), CBBB);
12179
+ IP = BranchInst::Create(CBBB, ElseBB);
12180
+ SplitTI->replaceUsesOfWith(CBBB, ElseBB);
12181
+ } else {
12182
+ ElseBB = IP->getParent();
12183
+ ThenTI->replaceUsesOfWith(ElseBB, CBBB);
12184
+ }
12185
+ CastInst *RetBC = nullptr;
12186
+ CallInst *NewCall = nullptr;
12187
+ if (isLegalToPromote(*CB, NewCallee)) {
12188
+ auto *CBClone = cast<CallBase>(CB->clone());
12189
+ CBClone->insertBefore(ThenTI);
12190
+ NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC));
12191
+ } else {
12192
+ NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs,
12193
+ CB->getName(), ThenTI);
12194
+ }
12195
+ NewCalls.push_back({NewCall, RetBC});
12196
+ }
12197
+
12198
+ // Check if we need the fallback indirect call still.
12199
+ if (AllCalleesKnown) {
12200
+ LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext()));
12201
+ LastCmp->eraseFromParent();
12202
+ new UnreachableInst(IP->getContext(), IP);
12203
+ IP->eraseFromParent();
12204
+ } else {
12205
+ auto *CBClone = cast<CallInst>(CB->clone());
12206
+ CBClone->setName(CB->getName());
12207
+ CBClone->insertBefore(IP);
12208
+ NewCalls.push_back({CBClone, nullptr});
12209
+ }
12210
+
12211
+ // Check if we need a PHI to merge the results.
12212
+ if (!CBIsVoid) {
12213
+ auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(),
12214
+ CB->getName() + ".phi",
12215
+ &*CB->getParent()->getFirstInsertionPt());
12216
+ for (auto &It : NewCalls) {
12217
+ CallBase *NewCall = It.first;
12218
+ Instruction *CallRet = It.second ? It.second : It.first;
12219
+ if (CallRet->getType() == CB->getType())
12220
+ PHI->addIncoming(CallRet, CallRet->getParent());
12221
+ else if (NewCall->getType()->isVoidTy())
12222
+ PHI->addIncoming(PoisonValue::get(CB->getType()),
12223
+ NewCall->getParent());
12224
+ else
12225
+ llvm_unreachable("Call return should match or be void!");
12226
+ }
12227
+ A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI);
12228
+ }
12229
+
12230
+ A.deleteAfterManifest(*CB);
12231
+ Changed = ChangeStatus::CHANGED;
12232
+
12233
+ return Changed;
12234
+ }
12235
+
12236
+ /// See AbstractAttribute::getAsStr().
12237
+ const std::string getAsStr(Attributor *A) const override {
12238
+ return std::string(AllCalleesKnown ? "eliminate" : "specialize") +
12239
+ " indirect call site with " + std::to_string(AssumedCalles.size()) +
12240
+ " functions";
12241
+ }
12242
+
12243
+ void trackStatistics() const override {
12244
+ if (AllCalleesKnown) {
12245
+ STATS_DECLTRACK(
12246
+ Eliminated, CallSites,
12247
+ "Number of indirect call sites eliminated via specialization")
12248
+ } else {
12249
+ STATS_DECLTRACK(Specialized, CallSites,
12250
+ "Number of indirect call sites specialized")
12251
+ }
12252
+ }
12253
+
12254
+ bool foreachCallee(function_ref<bool(Function *)> CB) const override {
12255
+ return isValidState() && AllCalleesKnown && all_of(AssumedCalles, CB);
12256
+ }
12257
+
12258
+ private:
12259
+ /// If the !callee metadata was present, this set will contain all potential
12260
+ /// callees (superset).
12261
+ SmallSetVector<Function *, 4> PotentialCallees;
12262
+
12263
+ /// This set contains all currently assumed calllees, which might grow over
12264
+ /// time.
12265
+ SmallSetVector<Function *, 4> AssumedCalles;
12266
+
12267
+ /// Flag to indicate if all possible callees are in the AssumedCalles set or
12268
+ /// if there could be others.
12269
+ bool AllCalleesKnown = true;
12270
+ };
12271
+ } // namespace
12272
+
12054
12273
/// ------------------------ Address Space ------------------------------------
12055
12274
namespace {
12056
12275
struct AAAddressSpaceImpl : public AAAddressSpace {
@@ -12259,6 +12478,7 @@ const char AAPointerInfo::ID = 0;
12259
12478
const char AAAssumptionInfo::ID = 0;
12260
12479
const char AAUnderlyingObjects::ID = 0;
12261
12480
const char AAAddressSpace::ID = 0;
12481
+ const char AAIndirectCallInfo::ID = 0;
12262
12482
12263
12483
// Macro magic to create the static generator function for attributes that
12264
12484
// follow the naming scheme.
@@ -12305,6 +12525,18 @@ const char AAAddressSpace::ID = 0;
12305
12525
return *AA; \
12306
12526
}
12307
12527
12528
+ #define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS) \
12529
+ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
12530
+ CLASS *AA = nullptr; \
12531
+ switch (IRP.getPositionKind()) { \
12532
+ SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX) \
12533
+ default: \
12534
+ llvm_unreachable("Cannot create " #CLASS " for position otherthan " #POS \
12535
+ " position!"); \
12536
+ } \
12537
+ return *AA; \
12538
+ }
12539
+
12308
12540
#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
12309
12541
CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
12310
12542
CLASS *AA = nullptr; \
@@ -12383,6 +12615,9 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
12383
12615
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
12384
12616
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)
12385
12617
12618
+ CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite,
12619
+ AAIndirectCallInfo)
12620
+
12386
12621
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
12387
12622
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
12388
12623
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
@@ -12396,5 +12631,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
12396
12631
#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
12397
12632
#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
12398
12633
#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
12634
+ #undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION
12399
12635
#undef SWITCH_PK_CREATE
12400
12636
#undef SWITCH_PK_INV
0 commit comments