Skip to content

Commit b7d870e

Browse files
committed
[AssumptionCache] Avoid dangling llvm.assume calls in the cache
PR49043 exposed a problem when it comes to RAUW llvm.assumes. While D96106 would fix it for GVNSink, it seems a more general concern. To avoid future problems this patch moves away from the vector of weak reference model used in the assumption cache. Instead, we track the llvm.assume calls with a callback handle which will remove itself from the cache if the call is deleted. Fixes PR49043. Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D96168
1 parent 378f4e5 commit b7d870e

File tree

8 files changed

+61
-67
lines changed

8 files changed

+61
-67
lines changed

llvm/include/llvm/Analysis/AssumptionCache.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#include "llvm/ADT/ArrayRef.h"
1919
#include "llvm/ADT/DenseMap.h"
2020
#include "llvm/ADT/DenseMapInfo.h"
21+
#include "llvm/ADT/SmallSet.h"
2122
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/IR/IntrinsicInst.h"
2224
#include "llvm/IR/PassManager.h"
2325
#include "llvm/IR/ValueHandle.h"
2426
#include "llvm/Pass.h"
@@ -44,6 +46,22 @@ class AssumptionCache {
4446
/// llvm.assume.
4547
enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
4648

49+
/// Callback handle to ensure we do not have dangling pointers to llvm.assume
50+
/// calls in our cache.
51+
class AssumeHandle final : public CallbackVH {
52+
AssumptionCache *AC;
53+
54+
/// Make sure llvm.assume calls that are deleted are removed from the cache.
55+
void deleted() override;
56+
57+
public:
58+
AssumeHandle(Value *V, AssumptionCache *AC = nullptr)
59+
: CallbackVH(V), AC(AC) {}
60+
61+
operator Value *() const { return getValPtr(); }
62+
CallInst *getAssumeCI() const { return cast<CallInst>(getValPtr()); }
63+
};
64+
4765
struct ResultElem {
4866
WeakVH Assume;
4967

@@ -59,9 +77,9 @@ class AssumptionCache {
5977
/// We track this to lazily populate our assumptions.
6078
Function &F;
6179

62-
/// Vector of weak value handles to calls of the \@llvm.assume
63-
/// intrinsic.
64-
SmallVector<ResultElem, 4> AssumeHandles;
80+
/// Set of value handles for calls of the \@llvm.assume intrinsic.
81+
using AssumeHandleSet = DenseSet<AssumeHandle, DenseMapInfo<Value *>>;
82+
AssumeHandleSet AssumeHandles;
6583

6684
class AffectedValueCallbackVH final : public CallbackVH {
6785
AssumptionCache *AC;
@@ -137,13 +155,7 @@ class AssumptionCache {
137155

138156
/// Access the list of assumption handles currently tracked for this
139157
/// function.
140-
///
141-
/// Note that these produce weak handles that may be null. The caller must
142-
/// handle that case.
143-
/// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
144-
/// when we can write that to filter out the null values. Then caller code
145-
/// will become simpler.
146-
MutableArrayRef<ResultElem> assumptions() {
158+
AssumeHandleSet &assumptions() {
147159
if (!Scanned)
148160
scanFunction();
149161
return AssumeHandles;

llvm/lib/Analysis/AssumptionCache.cpp

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,12 @@ void AssumptionCache::unregisterAssumption(CallInst *CI) {
163163
AffectedValues.erase(AVI);
164164
}
165165

166-
erase_value(AssumeHandles, CI);
166+
AssumeHandles.erase({CI, this});
167+
}
168+
169+
void AssumptionCache::AssumeHandle::deleted() {
170+
AC->AssumeHandles.erase(*this);
171+
// 'this' now dangles!
167172
}
168173

169174
void AssumptionCache::AffectedValueCallbackVH::deleted() {
@@ -204,14 +209,14 @@ void AssumptionCache::scanFunction() {
204209
for (BasicBlock &B : F)
205210
for (Instruction &II : B)
206211
if (match(&II, m_Intrinsic<Intrinsic::assume>()))
207-
AssumeHandles.push_back({&II, ExprResultIdx});
212+
AssumeHandles.insert({&II, this});
208213

209214
// Mark the scan as complete.
210215
Scanned = true;
211216

212217
// Update affected values.
213-
for (auto &A : AssumeHandles)
214-
updateAffectedValues(cast<CallInst>(A));
218+
for (auto &AssumeVH : AssumeHandles)
219+
updateAffectedValues(AssumeVH.getAssumeCI());
215220
}
216221

217222
void AssumptionCache::registerAssumption(CallInst *CI) {
@@ -223,28 +228,19 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
223228
if (!Scanned)
224229
return;
225230

226-
AssumeHandles.push_back({CI, ExprResultIdx});
231+
AssumeHandles.insert({CI, this});
227232

228233
#ifndef NDEBUG
229234
assert(CI->getParent() &&
230235
"Cannot register @llvm.assume call not in a basic block");
231236
assert(&F == CI->getParent()->getParent() &&
232237
"Cannot register @llvm.assume call not in this function");
233238

234-
// We expect the number of assumptions to be small, so in an asserts build
235-
// check that we don't accumulate duplicates and that all assumptions point
236-
// to the same function.
237-
SmallPtrSet<Value *, 16> AssumptionSet;
238-
for (auto &VH : AssumeHandles) {
239-
if (!VH)
240-
continue;
241-
242-
assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
239+
for (auto &AssumeVH : AssumeHandles) {
240+
assert(&F == AssumeVH.getAssumeCI()->getCaller() &&
243241
"Cached assumption not inside this function!");
244-
assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
242+
assert(match(AssumeVH.getAssumeCI(), m_Intrinsic<Intrinsic::assume>()) &&
245243
"Cached something other than a call to @llvm.assume!");
246-
assert(AssumptionSet.insert(VH).second &&
247-
"Cache contains multiple copies of a call!");
248244
}
249245
#endif
250246

@@ -258,9 +254,8 @@ PreservedAnalyses AssumptionPrinterPass::run(Function &F,
258254
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
259255

260256
OS << "Cached assumptions for function: " << F.getName() << "\n";
261-
for (auto &VH : AC.assumptions())
262-
if (VH)
263-
OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
257+
for (auto &AssumeVH : AC.assumptions())
258+
OS << " " << *AssumeVH.getAssumeCI()->getArgOperand(0) << "\n";
264259

265260
return PreservedAnalyses::all();
266261
}
@@ -306,9 +301,8 @@ void AssumptionCacheTracker::verifyAnalysis() const {
306301

307302
SmallPtrSet<const CallInst *, 4> AssumptionSet;
308303
for (const auto &I : AssumptionCaches) {
309-
for (auto &VH : I.second->assumptions())
310-
if (VH)
311-
AssumptionSet.insert(cast<CallInst>(VH));
304+
for (auto &AssumeVH : I.second->assumptions())
305+
AssumptionSet.insert(AssumeVH.getAssumeCI());
312306

313307
for (const BasicBlock &B : cast<Function>(*I.first))
314308
for (const Instruction &II : B)

llvm/lib/Analysis/CodeMetrics.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ void CodeMetrics::collectEphemeralValues(
7373
SmallVector<const Value *, 16> Worklist;
7474

7575
for (auto &AssumeVH : AC->assumptions()) {
76-
if (!AssumeVH)
77-
continue;
78-
Instruction *I = cast<Instruction>(AssumeVH);
76+
Instruction *I = AssumeVH.getAssumeCI();
7977

8078
// Filter out call sites outside of the loop so we don't do a function's
8179
// worth of work for each of its loops (and, in the common case, ephemeral
@@ -97,9 +95,7 @@ void CodeMetrics::collectEphemeralValues(
9795
SmallVector<const Value *, 16> Worklist;
9896

9997
for (auto &AssumeVH : AC->assumptions()) {
100-
if (!AssumeVH)
101-
continue;
102-
Instruction *I = cast<Instruction>(AssumeVH);
98+
Instruction *I = AssumeVH.getAssumeCI();
10399
assert(I->getParent()->getParent() == F &&
104100
"Found assumption for the wrong function!");
105101

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,9 +1704,9 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
17041704
getZeroExtendExpr(Step, Ty, Depth + 1), L,
17051705
AR->getNoWrapFlags());
17061706
}
1707-
1707+
17081708
// For a negative step, we can extend the operands iff doing so only
1709-
// traverses values in the range zext([0,UINT_MAX]).
1709+
// traverses values in the range zext([0,UINT_MAX]).
17101710
if (isKnownNegative(Step)) {
17111711
const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
17121712
getSignedRangeMin(Step));
@@ -9927,9 +9927,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
99279927

99289928
// Check conditions due to any @llvm.assume intrinsics.
99299929
for (auto &AssumeVH : AC.assumptions()) {
9930-
if (!AssumeVH)
9931-
continue;
9932-
auto *CI = cast<CallInst>(AssumeVH);
9930+
auto *CI = AssumeVH.getAssumeCI();
99339931
if (!DT.dominates(CI, Latch->getTerminator()))
99349932
continue;
99359933

@@ -10076,9 +10074,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
1007610074

1007710075
// Check conditions due to any @llvm.assume intrinsics.
1007810076
for (auto &AssumeVH : AC.assumptions()) {
10079-
if (!AssumeVH)
10080-
continue;
10081-
auto *CI = cast<CallInst>(AssumeVH);
10077+
auto *CI = AssumeVH.getAssumeCI();
1008210078
if (!DT.dominates(CI, BB))
1008310079
continue;
1008410080

@@ -13358,9 +13354,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1335813354

1335913355
// Also collect information from assumptions dominating the loop.
1336013356
for (auto &AssumeVH : AC.assumptions()) {
13361-
if (!AssumeVH)
13362-
continue;
13363-
auto *AssumeI = cast<CallInst>(AssumeVH);
13357+
auto *AssumeI = AssumeVH.getAssumeCI();
1336413358
auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0));
1336513359
if (!Cmp || !DT.dominates(AssumeI, L->getHeader()))
1336613360
continue;

llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,11 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
331331
DT = DT_;
332332

333333
bool Changed = false;
334-
for (auto &AssumeVH : AC.assumptions())
335-
if (AssumeVH) {
336-
CallInst *Call = cast<CallInst>(AssumeVH);
337-
for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
338-
Changed |= processAssumption(Call, Idx);
339-
}
334+
for (auto &AssumeVH : AC.assumptions()) {
335+
CallInst *Call = AssumeVH.getAssumeCI();
336+
for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
337+
Changed |= processAssumption(Call, Idx);
338+
}
340339

341340
return Changed;
342341
}

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,10 +1781,8 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
17811781
bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
17821782
const Function &NewFunc,
17831783
AssumptionCache *AC) {
1784-
for (auto AssumeVH : AC->assumptions()) {
1785-
auto *I = dyn_cast_or_null<CallInst>(AssumeVH);
1786-
if (!I)
1787-
continue;
1784+
for (auto &AssumeVH : AC->assumptions()) {
1785+
auto *I = AssumeVH.getAssumeCI();
17881786

17891787
// There shouldn't be any llvm.assume intrinsics in the new function.
17901788
if (I->getFunction() != &OldFunc)

llvm/lib/Transforms/Utils/PredicateInfo.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,11 @@ void PredicateInfoBuilder::buildPredicateInfo() {
532532
processSwitch(SI, BranchBB, OpsToRename);
533533
}
534534
}
535-
for (auto &Assume : AC.assumptions()) {
536-
if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume))
537-
if (DT.isReachableFromEntry(II->getParent()))
538-
processAssume(II, II->getParent(), OpsToRename);
535+
for (auto &AssumeVH : AC.assumptions()) {
536+
CallInst *AssumeCI = AssumeVH.getAssumeCI();
537+
if (DT.isReachableFromEntry(AssumeCI->getParent()))
538+
processAssume(cast<IntrinsicInst>(AssumeCI), AssumeCI->getParent(),
539+
OpsToRename);
539540
}
540541
// Now rename all our operations.
541542
renameUses(OpsToRename);

llvm/test/Analysis/AssumptionCache/basic.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ declare void @llvm.assume(i1)
66

77
define void @test1(i32 %a) {
88
; CHECK-LABEL: Cached assumptions for function: test1
9-
; CHECK-NEXT: icmp ne i32 %{{.*}}, 0
10-
; CHECK-NEXT: icmp slt i32 %{{.*}}, 0
11-
; CHECK-NEXT: icmp sgt i32 %{{.*}}, 0
9+
; CHECK-DAG: icmp ne i32 %{{.*}}, 0
10+
; CHECK-DAG: icmp slt i32 %{{.*}}, 0
11+
; CHECK-DAG: icmp sgt i32 %{{.*}}, 0
1212

1313
entry:
1414
%cond1 = icmp ne i32 %a, 0

0 commit comments

Comments
 (0)