Skip to content

Commit 118453c

Browse files
committed
readd the provisional cache
1 parent eb4d7c7 commit 118453c

File tree

6 files changed

+166
-62
lines changed

6 files changed

+166
-62
lines changed

compiler/rustc_middle/src/traits/solve/inspect.rs

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ pub struct CanonicalGoalEvaluation<'tcx> {
7373
pub enum CanonicalGoalEvaluationKind<'tcx> {
7474
Overflow,
7575
CycleInStack,
76+
ProvisionalCacheHit,
7677
Evaluation { revisions: &'tcx [GoalEvaluationStep<'tcx>] },
7778
}
7879
impl Debug for GoalEvaluation<'_> {

compiler/rustc_middle/src/traits/solve/inspect/format.rs

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ impl<'a, 'b> ProofTreeFormatter<'a, 'b> {
7777
CanonicalGoalEvaluationKind::CycleInStack => {
7878
writeln!(self.f, "CYCLE IN STACK: {:?}", eval.result)
7979
}
80+
CanonicalGoalEvaluationKind::ProvisionalCacheHit => {
81+
writeln!(self.f, "PROVISIONAL CACHE HIT: {:?}", eval.result)
82+
}
8083
CanonicalGoalEvaluationKind::Evaluation { revisions } => {
8184
for (n, step) in revisions.iter().enumerate() {
8285
writeln!(self.f, "REVISION {n}")?;

compiler/rustc_trait_selection/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#![feature(control_flow_enum)]
2020
#![feature(extract_if)]
2121
#![feature(let_chains)]
22+
#![feature(option_take_if)]
2223
#![feature(if_let_guard)]
2324
#![feature(never_type)]
2425
#![feature(type_alias_impl_trait)]

compiler/rustc_trait_selection/src/solve/inspect/analyse.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
171171
let mut candidates = vec![];
172172
let last_eval_step = match self.evaluation.evaluation.kind {
173173
inspect::CanonicalGoalEvaluationKind::Overflow
174-
| inspect::CanonicalGoalEvaluationKind::CycleInStack => {
174+
| inspect::CanonicalGoalEvaluationKind::CycleInStack
175+
| inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit => {
175176
warn!("unexpected root evaluation: {:?}", self.evaluation);
176177
return vec![];
177178
}

compiler/rustc_trait_selection/src/solve/inspect/build.rs

+5
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ pub(in crate::solve) enum WipGoalEvaluationKind<'tcx> {
118118
pub(in crate::solve) enum WipCanonicalGoalEvaluationKind<'tcx> {
119119
Overflow,
120120
CycleInStack,
121+
ProvisionalCacheHit,
121122
Interned { revisions: &'tcx [inspect::GoalEvaluationStep<'tcx>] },
122123
}
123124

@@ -126,6 +127,7 @@ impl std::fmt::Debug for WipCanonicalGoalEvaluationKind<'_> {
126127
match self {
127128
Self::Overflow => write!(f, "Overflow"),
128129
Self::CycleInStack => write!(f, "CycleInStack"),
130+
Self::ProvisionalCacheHit => write!(f, "ProvisionalCacheHit"),
129131
Self::Interned { revisions: _ } => f.debug_struct("Interned").finish_non_exhaustive(),
130132
}
131133
}
@@ -151,6 +153,9 @@ impl<'tcx> WipCanonicalGoalEvaluation<'tcx> {
151153
WipCanonicalGoalEvaluationKind::CycleInStack => {
152154
inspect::CanonicalGoalEvaluationKind::CycleInStack
153155
}
156+
WipCanonicalGoalEvaluationKind::ProvisionalCacheHit => {
157+
inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit
158+
}
154159
WipCanonicalGoalEvaluationKind::Interned { revisions } => {
155160
inspect::CanonicalGoalEvaluationKind::Evaluation { revisions }
156161
}

compiler/rustc_trait_selection/src/solve/search_graph.rs

+154-61
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, Qu
1111
use rustc_middle::ty;
1212
use rustc_middle::ty::TyCtxt;
1313
use rustc_session::Limit;
14-
use std::collections::hash_map::Entry;
1514
use std::mem;
1615

1716
rustc_index::newtype_index! {
@@ -30,7 +29,7 @@ struct StackEntry<'tcx> {
3029
///
3130
/// If so, it must not be moved to the global cache. See
3231
/// [SearchGraph::cycle_participants] for more details.
33-
non_root_cycle_participant: bool,
32+
non_root_cycle_participant: Option<StackDepth>,
3433

3534
encountered_overflow: bool,
3635
has_been_used: bool,
@@ -39,14 +38,42 @@ struct StackEntry<'tcx> {
3938
provisional_result: Option<QueryResult<'tcx>>,
4039
}
4140

41+
struct DetachedEntry<'tcx> {
42+
/// The head of the smallest non-trivial cycle involving this entry.
43+
///
44+
/// Given the following rules, when proving `A` the head for
45+
/// the provisional entry of `C` would be `B`.
46+
///
47+
/// A :- B
48+
/// B :- C
49+
/// C :- A + B + C
50+
head: StackDepth,
51+
result: QueryResult<'tcx>,
52+
}
53+
54+
#[derive(Default)]
55+
struct ProvisionalCacheEntry<'tcx> {
56+
stack_depth: Option<StackDepth>,
57+
with_inductive_stack: Option<DetachedEntry<'tcx>>,
58+
with_coinductive_stack: Option<DetachedEntry<'tcx>>,
59+
}
60+
61+
impl<'tcx> ProvisionalCacheEntry<'tcx> {
62+
fn is_empty(&self) -> bool {
63+
self.stack_depth.is_none()
64+
&& self.with_inductive_stack.is_none()
65+
&& self.with_coinductive_stack.is_none()
66+
}
67+
}
68+
4269
pub(super) struct SearchGraph<'tcx> {
4370
mode: SolverMode,
4471
local_overflow_limit: usize,
4572
/// The stack of goals currently being computed.
4673
///
4774
/// An element is *deeper* in the stack if its index is *lower*.
4875
stack: IndexVec<StackDepth, StackEntry<'tcx>>,
49-
stack_entries: FxHashMap<CanonicalInput<'tcx>, StackDepth>,
76+
provisional_cache: FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
5077
/// We put only the root goal of a coinductive cycle into the global cache.
5178
///
5279
/// If we were to use that result when later trying to prove another cycle
@@ -63,7 +90,7 @@ impl<'tcx> SearchGraph<'tcx> {
6390
mode,
6491
local_overflow_limit: tcx.recursion_limit().0.checked_ilog2().unwrap_or(0) as usize,
6592
stack: Default::default(),
66-
stack_entries: Default::default(),
93+
provisional_cache: Default::default(),
6794
cycle_participants: Default::default(),
6895
}
6996
}
@@ -93,7 +120,6 @@ impl<'tcx> SearchGraph<'tcx> {
93120
/// would cause us to not track overflow and recursion depth correctly.
94121
fn pop_stack(&mut self) -> StackEntry<'tcx> {
95122
let elem = self.stack.pop().unwrap();
96-
assert!(self.stack_entries.remove(&elem.input).is_some());
97123
if let Some(last) = self.stack.raw.last_mut() {
98124
last.reached_depth = last.reached_depth.max(elem.reached_depth);
99125
last.encountered_overflow |= elem.encountered_overflow;
@@ -114,7 +140,7 @@ impl<'tcx> SearchGraph<'tcx> {
114140

115141
pub(super) fn is_empty(&self) -> bool {
116142
if self.stack.is_empty() {
117-
debug_assert!(self.stack_entries.is_empty());
143+
debug_assert!(self.provisional_cache.is_empty());
118144
debug_assert!(self.cycle_participants.is_empty());
119145
true
120146
} else {
@@ -156,6 +182,40 @@ impl<'tcx> SearchGraph<'tcx> {
156182
}
157183
}
158184

185+
fn stack_coinductive_from(
186+
tcx: TyCtxt<'tcx>,
187+
stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
188+
head: StackDepth,
189+
) -> bool {
190+
stack.raw[head.index()..]
191+
.iter()
192+
.all(|entry| entry.input.value.goal.predicate.is_coinductive(tcx))
193+
}
194+
195+
fn tag_cycle_participants(
196+
stack: &mut IndexVec<StackDepth, StackEntry<'tcx>>,
197+
cycle_participants: &mut FxHashSet<CanonicalInput<'tcx>>,
198+
head: StackDepth,
199+
) {
200+
stack[head].has_been_used = true;
201+
for entry in &mut stack.raw[head.index() + 1..] {
202+
entry.non_root_cycle_participant = entry.non_root_cycle_participant.max(Some(head));
203+
cycle_participants.insert(entry.input);
204+
}
205+
}
206+
207+
fn clear_dependent_provisional_results(
208+
provisional_cache: &mut FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
209+
head: StackDepth,
210+
) {
211+
#[allow(rustc::potential_query_instability)]
212+
provisional_cache.retain(|_, entry| {
213+
entry.with_coinductive_stack.take_if(|p| p.head == head);
214+
entry.with_inductive_stack.take_if(|p| p.head == head);
215+
!entry.is_empty()
216+
});
217+
}
218+
159219
/// Probably the most involved method of the whole solver.
160220
///
161221
/// Given some goal which is proven via the `prove_goal` closure, this
@@ -210,23 +270,36 @@ impl<'tcx> SearchGraph<'tcx> {
210270
return result;
211271
}
212272

213-
// Check whether we're in a cycle.
214-
match self.stack_entries.entry(input) {
215-
// No entry, we push this goal on the stack and try to prove it.
216-
Entry::Vacant(v) => {
217-
let depth = self.stack.next_index();
218-
let entry = StackEntry {
219-
input,
220-
available_depth,
221-
reached_depth: depth,
222-
non_root_cycle_participant: false,
223-
encountered_overflow: false,
224-
has_been_used: false,
225-
provisional_result: None,
226-
};
227-
assert_eq!(self.stack.push(entry), depth);
228-
v.insert(depth);
229-
}
273+
// Check whether the goal is in the provisional cache.
274+
let cache_entry = self.provisional_cache.entry(input).or_default();
275+
if let Some(with_coinductive_stack) = &mut cache_entry.with_coinductive_stack
276+
&& Self::stack_coinductive_from(tcx, &self.stack, with_coinductive_stack.head)
277+
{
278+
// We have a nested goal which is already in the provisional cache, use
279+
// its result.
280+
inspect
281+
.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
282+
Self::tag_cycle_participants(
283+
&mut self.stack,
284+
&mut self.cycle_participants,
285+
with_coinductive_stack.head,
286+
);
287+
return with_coinductive_stack.result;
288+
} else if let Some(with_inductive_stack) = &mut cache_entry.with_inductive_stack
289+
&& !Self::stack_coinductive_from(tcx, &self.stack, with_inductive_stack.head)
290+
{
291+
// We have a nested goal which is already in the provisional cache, use
292+
// its result.
293+
inspect
294+
.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
295+
Self::tag_cycle_participants(
296+
&mut self.stack,
297+
&mut self.cycle_participants,
298+
with_inductive_stack.head,
299+
);
300+
return with_inductive_stack.result;
301+
} else if let Some(stack_depth) = cache_entry.stack_depth {
302+
debug!("encountered cycle with depth {stack_depth:?}");
230303
// We have a nested goal which relies on a goal `root` deeper in the stack.
231304
//
232305
// We first store that we may have to reprove `root` in case the provisional
@@ -236,40 +309,37 @@ impl<'tcx> SearchGraph<'tcx> {
236309
//
237310
// Finally we can return either the provisional response for that goal if we have a
238311
// coinductive cycle or an ambiguous result if the cycle is inductive.
239-
Entry::Occupied(entry) => {
240-
inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
241-
242-
let stack_depth = *entry.get();
243-
debug!("encountered cycle with depth {stack_depth:?}");
244-
// We start by tagging all non-root cycle participants.
245-
let participants = self.stack.raw.iter_mut().skip(stack_depth.as_usize() + 1);
246-
for entry in participants {
247-
entry.non_root_cycle_participant = true;
248-
self.cycle_participants.insert(entry.input);
249-
}
250-
251-
// If we're in a cycle, we have to retry proving the cycle head
252-
// until we reach a fixpoint. It is not enough to simply retry the
253-
// `root` goal of this cycle.
254-
//
255-
// See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
256-
// for an example.
257-
self.stack[stack_depth].has_been_used = true;
258-
return if let Some(result) = self.stack[stack_depth].provisional_result {
259-
result
312+
inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
313+
Self::tag_cycle_participants(
314+
&mut self.stack,
315+
&mut self.cycle_participants,
316+
stack_depth,
317+
);
318+
return if let Some(result) = self.stack[stack_depth].provisional_result {
319+
result
320+
} else {
321+
// If we don't have a provisional result yet we're in the first iteration,
322+
// so we start with no constraints.
323+
if Self::stack_coinductive_from(tcx, &self.stack, stack_depth) {
324+
Self::response_no_constraints(tcx, input, Certainty::Yes)
260325
} else {
261-
// If we don't have a provisional result yet we're in the first iteration,
262-
// so we start with no constraints.
263-
let is_inductive = self.stack.raw[stack_depth.index()..]
264-
.iter()
265-
.any(|entry| !entry.input.value.goal.predicate.is_coinductive(tcx));
266-
if is_inductive {
267-
Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
268-
} else {
269-
Self::response_no_constraints(tcx, input, Certainty::Yes)
270-
}
271-
};
272-
}
326+
Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
327+
}
328+
};
329+
} else {
330+
// No entry, we push this goal on the stack and try to prove it.
331+
let depth = self.stack.next_index();
332+
let entry = StackEntry {
333+
input,
334+
available_depth,
335+
reached_depth: depth,
336+
non_root_cycle_participant: None,
337+
encountered_overflow: false,
338+
has_been_used: false,
339+
provisional_result: None,
340+
};
341+
assert_eq!(self.stack.push(entry), depth);
342+
cache_entry.stack_depth = Some(depth);
273343
}
274344

275345
// This is for global caching, so we properly track query dependencies.
@@ -285,11 +355,22 @@ impl<'tcx> SearchGraph<'tcx> {
285355
for _ in 0..self.local_overflow_limit() {
286356
let result = prove_goal(self, inspect);
287357

288-
// Check whether the current goal is the root of a cycle and whether
289-
// we have to rerun because its provisional result differed from the
290-
// final result.
358+
// Check whether the current goal is the root of a cycle.
359+
// If so, we have to retry proving the cycle head
360+
// until its result reaches a fixpoint. We need to do so for
361+
// all cycle heads, not only for the root.
362+
//
363+
// See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
364+
// for an example.
291365
let stack_entry = self.pop_stack();
292366
debug_assert_eq!(stack_entry.input, input);
367+
if stack_entry.has_been_used {
368+
Self::clear_dependent_provisional_results(
369+
&mut self.provisional_cache,
370+
self.stack.next_index(),
371+
);
372+
}
373+
293374
if stack_entry.has_been_used
294375
&& stack_entry.provisional_result.map_or(true, |r| r != result)
295376
{
@@ -299,14 +380,15 @@ impl<'tcx> SearchGraph<'tcx> {
299380
provisional_result: Some(result),
300381
..stack_entry
301382
});
302-
assert_eq!(self.stack_entries.insert(input, depth), None);
383+
debug_assert_eq!(self.provisional_cache[&input].stack_depth, Some(depth));
303384
} else {
304385
return (stack_entry, result);
305386
}
306387
}
307388

308389
debug!("canonical cycle overflow");
309390
let current_entry = self.pop_stack();
391+
debug_assert!(!current_entry.has_been_used);
310392
let result = Self::response_no_constraints(tcx, input, Certainty::OVERFLOW);
311393
(current_entry, result)
312394
});
@@ -319,7 +401,17 @@ impl<'tcx> SearchGraph<'tcx> {
319401
//
320402
// It is not possible for any nested goal to depend on something deeper on the
321403
// stack, as this would have also updated the depth of the current goal.
322-
if !final_entry.non_root_cycle_participant {
404+
if let Some(head) = final_entry.non_root_cycle_participant {
405+
let coinductive_stack = Self::stack_coinductive_from(tcx, &self.stack, head);
406+
407+
let entry = self.provisional_cache.get_mut(&input).unwrap();
408+
entry.stack_depth = None;
409+
if coinductive_stack {
410+
entry.with_coinductive_stack = Some(DetachedEntry { head, result });
411+
} else {
412+
entry.with_inductive_stack = Some(DetachedEntry { head, result });
413+
}
414+
} else {
323415
// When encountering a cycle, both inductive and coinductive, we only
324416
// move the root into the global cache. We also store all other cycle
325417
// participants involved.
@@ -328,6 +420,7 @@ impl<'tcx> SearchGraph<'tcx> {
328420
// participant is on the stack. This is necessary to prevent unstable
329421
// results. See the comment of `SearchGraph::cycle_participants` for
330422
// more details.
423+
self.provisional_cache.remove(&input);
331424
let reached_depth = final_entry.reached_depth.as_usize() - self.stack.len();
332425
let cycle_participants = mem::take(&mut self.cycle_participants);
333426
self.global_cache(tcx).insert(

0 commit comments

Comments
 (0)