Skip to content

Commit 9c08e76

Browse files
committed
[Attributor] Introduce AAIndirectCallInfo
AAIndirectCallInfo will collect information and specialize indirect call sites. It is similar to our IndirectCallPromotion but runs as part of the Attributor (so with assumed callee information). It also expands more calls and let's the rest of the pipeline figure out what is UB, for now. We use existing call promotion logic to improve the result, otherwise we rely on the (implicit) function pointer cast. This effectively "fixes" llvm#60327 as it will undo the type punning early enough for the inliner to work with the (now specialized, thus direct) call. Fixes: llvm#60327
1 parent 18b211c commit 9c08e76

File tree

7 files changed

+434
-70
lines changed

7 files changed

+434
-70
lines changed

llvm/include/llvm/Transforms/IPO/Attributor.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6109,6 +6109,48 @@ struct AAAddressSpace : public StateWrapper<BooleanState, AbstractAttribute> {
61096109
static const char ID;
61106110
};
61116111

6112+
/// An abstract interface for indirect call information interference.
6113+
struct AAIndirectCallInfo
6114+
: public StateWrapper<BooleanState, AbstractAttribute> {
6115+
AAIndirectCallInfo(const IRPosition &IRP, Attributor &A)
6116+
: StateWrapper<BooleanState, AbstractAttribute>(IRP) {}
6117+
6118+
/// The point is to derive callees, after all.
6119+
static bool requiresCalleeForCallBase() { return false; }
6120+
6121+
/// See AbstractAttribute::isValidIRPositionForInit
6122+
static bool isValidIRPositionForInit(Attributor &A, const IRPosition &IRP) {
6123+
if (IRP.getPositionKind() != IRPosition::IRP_CALL_SITE)
6124+
return false;
6125+
auto *CB = cast<CallBase>(IRP.getCtxI());
6126+
return CB->getOpcode() == Instruction::Call && CB->isIndirectCall() &&
6127+
!CB->isMustTailCall();
6128+
}
6129+
6130+
/// Create an abstract attribute view for the position \p IRP.
6131+
static AAIndirectCallInfo &createForPosition(const IRPosition &IRP,
6132+
Attributor &A);
6133+
6134+
/// Call \CB on each potential callee value and return true if all were known
6135+
/// and \p CB returned true on all of them. Otherwise, return false.
6136+
virtual bool foreachCallee(function_ref<bool(Function *)> CB) const = 0;
6137+
6138+
/// See AbstractAttribute::getName()
6139+
const std::string getName() const override { return "AAIndirectCallInfo"; }
6140+
6141+
/// See AbstractAttribute::getIdAddr()
6142+
const char *getIdAddr() const override { return &ID; }
6143+
6144+
/// This function should return true if the type of the \p AA is
6145+
/// AAIndirectCallInfo
6146+
static bool classof(const AbstractAttribute *AA) {
6147+
return (AA->getIdAddr() == &ID);
6148+
}
6149+
6150+
/// Unique ID (due to the unique address)
6151+
static const char ID;
6152+
};
6153+
61126154
raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);
61136155

61146156
/// Run options, used by the pass manager.

llvm/lib/Transforms/IPO/Attributor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3460,8 +3460,10 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
34603460
Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
34613461
// TODO: Even if the callee is not known now we might be able to simplify
34623462
// the call/callee.
3463-
if (!Callee)
3463+
if (!Callee) {
3464+
getOrCreateAAFor<AAIndirectCallInfo>(CBFnPos);
34643465
return true;
3466+
}
34653467

34663468
// Every call site can track active assumptions.
34673469
getOrCreateAAFor<AAAssumptionInfo>(CBFnPos);

llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Lines changed: 245 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,14 @@
6565
#include "llvm/Support/GraphWriter.h"
6666
#include "llvm/Support/MathExtras.h"
6767
#include "llvm/Support/raw_ostream.h"
68+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
69+
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
6870
#include "llvm/Transforms/Utils/Local.h"
6971
#include "llvm/Transforms/Utils/ValueMapper.h"
7072
#include <cassert>
7173
#include <numeric>
7274
#include <optional>
75+
#include <string>
7376

7477
using namespace llvm;
7578

@@ -188,6 +191,7 @@ PIPE_OPERATOR(AAPointerInfo)
188191
PIPE_OPERATOR(AAAssumptionInfo)
189192
PIPE_OPERATOR(AAUnderlyingObjects)
190193
PIPE_OPERATOR(AAAddressSpace)
194+
PIPE_OPERATOR(AAIndirectCallInfo)
191195

192196
#undef PIPE_OPERATOR
193197

@@ -10560,15 +10564,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
1056010564
return Change;
1056110565
}
1056210566

10563-
// Process callee metadata if available.
10564-
if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) {
10565-
for (const auto &Op : MD->operands()) {
10566-
Function *Callee = mdconst::dyn_extract_or_null<Function>(Op);
10567-
if (Callee)
10568-
addCalledFunction(Callee, Change);
10569-
}
10570-
return Change;
10571-
}
10567+
if (CB->isIndirectCall())
10568+
if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>(
10569+
*this, getIRPosition(), DepClassTy::OPTIONAL))
10570+
if (IndirectCallAA->foreachCallee(
10571+
[&](Function *Fn) { return VisitValue(*Fn, CB); }))
10572+
return Change;
1057210573

1057310574
// The most simple case.
1057410575
ProcessCalledOperand(CB->getCalledOperand(), CB);
@@ -12051,6 +12052,224 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl {
1205112052
};
1205212053
} // namespace
1205312054

12055+
/// ------------------------ Indirect Call Info -------------------------------
12056+
namespace {
12057+
struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo {
12058+
AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A)
12059+
: AAIndirectCallInfo(IRP, A) {}
12060+
12061+
/// See AbstractAttribute::initialize(...).
12062+
void initialize(Attributor &A) override {
12063+
auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees);
12064+
if (!MD)
12065+
return;
12066+
for (const auto &Op : MD->operands())
12067+
if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op))
12068+
PotentialCallees.insert(Callee);
12069+
}
12070+
12071+
ChangeStatus updateImpl(Attributor &A) override {
12072+
CallBase *CB = cast<CallBase>(getCtxI());
12073+
Value *FP = CB->getCalledOperand();
12074+
12075+
SmallSetVector<Function *, 4> AssumedCalleesNow;
12076+
bool AllCalleesKnownNow = AllCalleesKnown;
12077+
12078+
// Use simplification to find potential callees, if !callees was present,
12079+
// fallback to that set if necessary.
12080+
bool UsedAssumedInformation;
12081+
SmallVector<AA::ValueAndContext> Values;
12082+
if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values,
12083+
AA::ValueScope::AnyScope,
12084+
UsedAssumedInformation)) {
12085+
if (PotentialCallees.empty())
12086+
return indicatePessimisticFixpoint();
12087+
AssumedCalleesNow.set_union(PotentialCallees);
12088+
}
12089+
12090+
// Check simplification result, prune known UB callees, also restrict it to
12091+
// the !callees set, if present.
12092+
for (auto &VAC : Values) {
12093+
if (isa<UndefValue>(VAC.getValue()))
12094+
continue;
12095+
if (isa<ConstantPointerNull>(VAC.getValue()) &&
12096+
VAC.getValue()->getType()->getPointerAddressSpace() == 0)
12097+
continue;
12098+
// TODO: Check for known UB, e.g., poison + noundef.
12099+
if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) {
12100+
if (PotentialCallees.empty() || PotentialCallees.count(VACFn))
12101+
AssumedCalleesNow.insert(VACFn);
12102+
continue;
12103+
}
12104+
if (!PotentialCallees.empty()) {
12105+
AssumedCalleesNow.set_union(PotentialCallees);
12106+
break;
12107+
}
12108+
AllCalleesKnownNow = false;
12109+
}
12110+
12111+
// If we can't specialize at all, give up now.
12112+
if (!AllCalleesKnownNow && AssumedCalleesNow.empty())
12113+
return indicatePessimisticFixpoint();
12114+
12115+
if (AssumedCalleesNow == AssumedCalles &&
12116+
AllCalleesKnown == AllCalleesKnownNow)
12117+
return ChangeStatus::UNCHANGED;
12118+
12119+
std::swap(AssumedCalles, AssumedCalleesNow);
12120+
AllCalleesKnown = AllCalleesKnownNow;
12121+
return ChangeStatus::CHANGED;
12122+
}
12123+
12124+
/// See AbstractAttribute::manifest(...).
12125+
ChangeStatus manifest(Attributor &A) override {
12126+
12127+
ChangeStatus Changed = ChangeStatus::UNCHANGED;
12128+
CallBase *CB = cast<CallBase>(getCtxI());
12129+
Value *FP = CB->getCalledOperand();
12130+
12131+
bool CBIsVoid = CB->getType()->isVoidTy();
12132+
Instruction *IP = CB;
12133+
FunctionType *CSFT = CB->getFunctionType();
12134+
SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end());
12135+
12136+
// If we know all callees and there are none, the call site is (effectively)
12137+
// dead (or UB).
12138+
if (AssumedCalles.empty()) {
12139+
assert(AllCalleesKnown &&
12140+
"Expected all callees to be known if there are none.");
12141+
A.changeToUnreachableAfterManifest(CB);
12142+
return ChangeStatus::CHANGED;
12143+
}
12144+
12145+
// Special handling for the single callee case.
12146+
if (AllCalleesKnown && AssumedCalles.size() == 1) {
12147+
auto *NewCallee = AssumedCalles.front();
12148+
if (isLegalToPromote(*CB, NewCallee)) {
12149+
promoteCall(*CB, NewCallee, nullptr);
12150+
return ChangeStatus::CHANGED;
12151+
}
12152+
Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee),
12153+
CSArgs, CB->getName(), CB);
12154+
if (!CBIsVoid)
12155+
A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall);
12156+
A.deleteAfterManifest(*CB);
12157+
return ChangeStatus::CHANGED;
12158+
}
12159+
12160+
// For each potential value we create a conditional
12161+
//
12162+
// ```
12163+
// if (ptr == value) value(args);
12164+
// else ...
12165+
// ```
12166+
//
12167+
ICmpInst *LastCmp = nullptr;
12168+
SmallVector<std::pair<CallInst *, Instruction *>> NewCalls;
12169+
for (Function *NewCallee : AssumedCalles) {
12170+
LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee);
12171+
Instruction *ThenTI =
12172+
SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false);
12173+
BasicBlock *CBBB = CB->getParent();
12174+
auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode());
12175+
BasicBlock *ElseBB;
12176+
if (IP == CB) {
12177+
ElseBB = BasicBlock::Create(ThenTI->getContext(), "",
12178+
ThenTI->getFunction(), CBBB);
12179+
IP = BranchInst::Create(CBBB, ElseBB);
12180+
SplitTI->replaceUsesOfWith(CBBB, ElseBB);
12181+
} else {
12182+
ElseBB = IP->getParent();
12183+
ThenTI->replaceUsesOfWith(ElseBB, CBBB);
12184+
}
12185+
CastInst *RetBC = nullptr;
12186+
CallInst *NewCall = nullptr;
12187+
if (isLegalToPromote(*CB, NewCallee)) {
12188+
auto *CBClone = cast<CallBase>(CB->clone());
12189+
CBClone->insertBefore(ThenTI);
12190+
NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC));
12191+
} else {
12192+
NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs,
12193+
CB->getName(), ThenTI);
12194+
}
12195+
NewCalls.push_back({NewCall, RetBC});
12196+
}
12197+
12198+
// Check if we need the fallback indirect call still.
12199+
if (AllCalleesKnown) {
12200+
LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext()));
12201+
LastCmp->eraseFromParent();
12202+
new UnreachableInst(IP->getContext(), IP);
12203+
IP->eraseFromParent();
12204+
} else {
12205+
auto *CBClone = cast<CallInst>(CB->clone());
12206+
CBClone->setName(CB->getName());
12207+
CBClone->insertBefore(IP);
12208+
NewCalls.push_back({CBClone, nullptr});
12209+
}
12210+
12211+
// Check if we need a PHI to merge the results.
12212+
if (!CBIsVoid) {
12213+
auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(),
12214+
CB->getName() + ".phi",
12215+
&*CB->getParent()->getFirstInsertionPt());
12216+
for (auto &It : NewCalls) {
12217+
CallBase *NewCall = It.first;
12218+
Instruction *CallRet = It.second ? It.second : It.first;
12219+
if (CallRet->getType() == CB->getType())
12220+
PHI->addIncoming(CallRet, CallRet->getParent());
12221+
else if (NewCall->getType()->isVoidTy())
12222+
PHI->addIncoming(PoisonValue::get(CB->getType()),
12223+
NewCall->getParent());
12224+
else
12225+
llvm_unreachable("Call return should match or be void!");
12226+
}
12227+
A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI);
12228+
}
12229+
12230+
A.deleteAfterManifest(*CB);
12231+
Changed = ChangeStatus::CHANGED;
12232+
12233+
return Changed;
12234+
}
12235+
12236+
/// See AbstractAttribute::getAsStr().
12237+
const std::string getAsStr(Attributor *A) const override {
12238+
return std::string(AllCalleesKnown ? "eliminate" : "specialize") +
12239+
" indirect call site with " + std::to_string(AssumedCalles.size()) +
12240+
" functions";
12241+
}
12242+
12243+
void trackStatistics() const override {
12244+
if (AllCalleesKnown) {
12245+
STATS_DECLTRACK(
12246+
Eliminated, CallSites,
12247+
"Number of indirect call sites eliminated via specialization")
12248+
} else {
12249+
STATS_DECLTRACK(Specialized, CallSites,
12250+
"Number of indirect call sites specialized")
12251+
}
12252+
}
12253+
12254+
bool foreachCallee(function_ref<bool(Function *)> CB) const override {
12255+
return isValidState() && AllCalleesKnown && all_of(AssumedCalles, CB);
12256+
}
12257+
12258+
private:
12259+
/// If the !callee metadata was present, this set will contain all potential
12260+
/// callees (superset).
12261+
SmallSetVector<Function *, 4> PotentialCallees;
12262+
12263+
/// This set contains all currently assumed calllees, which might grow over
12264+
/// time.
12265+
SmallSetVector<Function *, 4> AssumedCalles;
12266+
12267+
/// Flag to indicate if all possible callees are in the AssumedCalles set or
12268+
/// if there could be others.
12269+
bool AllCalleesKnown = true;
12270+
};
12271+
} // namespace
12272+
1205412273
/// ------------------------ Address Space ------------------------------------
1205512274
namespace {
1205612275
struct AAAddressSpaceImpl : public AAAddressSpace {
@@ -12259,6 +12478,7 @@ const char AAPointerInfo::ID = 0;
1225912478
const char AAAssumptionInfo::ID = 0;
1226012479
const char AAUnderlyingObjects::ID = 0;
1226112480
const char AAAddressSpace::ID = 0;
12481+
const char AAIndirectCallInfo::ID = 0;
1226212482

1226312483
// Macro magic to create the static generator function for attributes that
1226412484
// follow the naming scheme.
@@ -12305,6 +12525,18 @@ const char AAAddressSpace::ID = 0;
1230512525
return *AA; \
1230612526
}
1230712527

12528+
#define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS) \
12529+
CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
12530+
CLASS *AA = nullptr; \
12531+
switch (IRP.getPositionKind()) { \
12532+
SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX) \
12533+
default: \
12534+
llvm_unreachable("Cannot create " #CLASS " for position otherthan " #POS \
12535+
" position!"); \
12536+
} \
12537+
return *AA; \
12538+
}
12539+
1230812540
#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
1230912541
CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
1231012542
CLASS *AA = nullptr; \
@@ -12383,6 +12615,9 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
1238312615
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
1238412616
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)
1238512617

12618+
CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite,
12619+
AAIndirectCallInfo)
12620+
1238612621
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
1238712622
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
1238812623
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
@@ -12396,5 +12631,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
1239612631
#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
1239712632
#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
1239812633
#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
12634+
#undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION
1239912635
#undef SWITCH_PK_CREATE
1240012636
#undef SWITCH_PK_INV

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5419,6 +5419,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
54195419
UsedAssumedInformation, AA::Interprocedural);
54205420
continue;
54215421
}
5422+
if (auto *CI = dyn_cast<CallBase>(&I)) {
5423+
if (CI->isIndirectCall())
5424+
A.getOrCreateAAFor<AAIndirectCallInfo>(
5425+
IRPosition::callsite_function(*CI));
5426+
}
54225427
if (auto *SI = dyn_cast<StoreInst>(&I)) {
54235428
A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
54245429
continue;

0 commit comments

Comments
 (0)