Skip to content

Commit dba4147

Browse files
Make SearchGraph fully generic
1 parent af3d100 commit dba4147

File tree

8 files changed

+149
-95
lines changed

8 files changed

+149
-95
lines changed

compiler/rustc_middle/src/traits/solve.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::ty::{
1010

1111
mod cache;
1212

13-
pub use cache::{CacheData, EvaluationCache};
13+
pub use cache::EvaluationCache;
1414

1515
pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>;
1616
pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>;

compiler/rustc_middle/src/traits/solve/cache.rs

+11-17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock;
55
use rustc_query_system::cache::WithDepNode;
66
use rustc_query_system::dep_graph::DepNodeIndex;
77
use rustc_session::Limit;
8+
use rustc_type_ir::solve::CacheData;
9+
810
/// The trait solver cache used by `-Znext-solver`.
911
///
1012
/// FIXME(@lcnr): link to some official documentation of how
@@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> {
1416
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
1517
}
1618

17-
#[derive(Debug, PartialEq, Eq)]
18-
pub struct CacheData<'tcx> {
19-
pub result: QueryResult<'tcx>,
20-
pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>,
21-
pub additional_depth: usize,
22-
pub encountered_overflow: bool,
23-
}
24-
25-
impl<'tcx> EvaluationCache<'tcx> {
19+
impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> {
2620
/// Insert a final result into the global cache.
27-
pub fn insert(
21+
fn insert(
2822
&self,
2923
tcx: TyCtxt<'tcx>,
3024
key: CanonicalInput<'tcx>,
@@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> {
4842
if cfg!(debug_assertions) {
4943
drop(map);
5044
let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
51-
let actual = self.get(tcx, key, [], Limit(additional_depth));
45+
let actual = self.get(tcx, key, [], additional_depth);
5246
if !actual.as_ref().is_some_and(|actual| expected == *actual) {
5347
bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
5448
}
@@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> {
5953
/// and handling root goals of coinductive cycles.
6054
///
6155
/// If this returns `Some` the cache result can be used.
62-
pub fn get(
56+
fn get(
6357
&self,
6458
tcx: TyCtxt<'tcx>,
6559
key: CanonicalInput<'tcx>,
6660
stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
67-
available_depth: Limit,
68-
) -> Option<CacheData<'tcx>> {
61+
available_depth: usize,
62+
) -> Option<CacheData<TyCtxt<'tcx>>> {
6963
let map = self.map.borrow();
7064
let entry = map.get(&key)?;
7165

@@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> {
7670
}
7771

7872
if let Some(ref success) = entry.success {
79-
if available_depth.value_within_limit(success.additional_depth) {
73+
if Limit(available_depth).value_within_limit(success.additional_depth) {
8074
let QueryData { result, proof_tree } = success.data.get(tcx);
8175
return Some(CacheData {
8276
result,
@@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> {
8781
}
8882
}
8983

90-
entry.with_overflow.get(&available_depth.0).map(|e| {
84+
entry.with_overflow.get(&available_depth).map(|e| {
9185
let QueryData { result, proof_tree } = e.get(tcx);
9286
CacheData {
9387
result,
9488
proof_tree,
95-
additional_depth: available_depth.0,
89+
additional_depth: available_depth,
9690
encountered_overflow: true,
9791
}
9892
})

compiler/rustc_middle/src/ty/context.rs

+21
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ use rustc_target::abi::{FieldIdx, Layout, LayoutS, TargetDataLayout, VariantIdx}
7171
use rustc_target::spec::abi;
7272
use rustc_type_ir::fold::TypeFoldable;
7373
use rustc_type_ir::lang_items::TraitSolverLangItem;
74+
use rustc_type_ir::solve::SolverMode;
7475
use rustc_type_ir::TyKind::*;
7576
use rustc_type_ir::{CollectAndApply, Interner, TypeFlags, WithCachedTypeInfo};
7677
use tracing::{debug, instrument};
@@ -139,10 +140,30 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
139140
type Clause = Clause<'tcx>;
140141
type Clauses = ty::Clauses<'tcx>;
141142

143+
type DepNodeIndex = DepNodeIndex;
144+
fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, DepNodeIndex) {
145+
self.dep_graph.with_anon_task(self, crate::dep_graph::dep_kinds::TraitSelect, task)
146+
}
147+
148+
type EvaluationCache = &'tcx solve::EvaluationCache<'tcx>;
149+
fn evaluation_cache(self, mode: SolverMode) -> &'tcx solve::EvaluationCache<'tcx> {
150+
match mode {
151+
SolverMode::Normal => &self.new_solver_evaluation_cache,
152+
SolverMode::Coherence => &self.new_solver_coherence_evaluation_cache,
153+
}
154+
}
155+
142156
fn expand_abstract_consts<T: TypeFoldable<TyCtxt<'tcx>>>(self, t: T) -> T {
143157
self.expand_abstract_consts(t)
144158
}
145159

160+
fn mk_external_constraints(
161+
self,
162+
data: ExternalConstraintsData<Self>,
163+
) -> ExternalConstraints<'tcx> {
164+
self.mk_external_constraints(data)
165+
}
166+
146167
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
147168
self.mk_canonical_var_infos(infos)
148169
}

compiler/rustc_trait_selection/src/solve/mod.rs

+9-21
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@
1414
//! FIXME(@lcnr): Write that section. If you read this before then ask me
1515
//! about it on zulip.
1616
use rustc_hir::def_id::DefId;
17-
use rustc_infer::infer::canonical::{Canonical, CanonicalVarValues};
17+
use rustc_infer::infer::canonical::Canonical;
1818
use rustc_infer::infer::InferCtxt;
1919
use rustc_infer::traits::query::NoSolution;
2020
use rustc_macros::extension;
2121
use rustc_middle::bug;
22-
use rustc_middle::infer::canonical::CanonicalVarInfos;
2322
use rustc_middle::traits::solve::{
2423
CanonicalResponse, Certainty, ExternalConstraintsData, Goal, GoalSource, QueryResult, Response,
2524
};
2625
use rustc_middle::ty::{
2726
self, AliasRelationDirection, CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, Ty,
2827
TyCtxt, TypeOutlivesPredicate, UniverseIndex,
2928
};
29+
use rustc_type_ir::solve::SolverMode;
30+
use rustc_type_ir::{self as ir, Interner};
3031

3132
mod alias_relate;
3233
mod assembly;
@@ -57,19 +58,6 @@ pub use select::InferCtxtSelectExt;
5758
/// recursion limit again. However, this feels very unlikely.
5859
const FIXPOINT_STEP_LIMIT: usize = 8;
5960

60-
#[derive(Debug, Clone, Copy)]
61-
enum SolverMode {
62-
/// Ordinary trait solving, using everywhere except for coherence.
63-
Normal,
64-
/// Trait solving during coherence. There are a few notable differences
65-
/// between coherence and ordinary trait solving.
66-
///
67-
/// Most importantly, trait solving during coherence must not be incomplete,
68-
/// i.e. return `Err(NoSolution)` for goals for which a solution exists.
69-
/// This means that we must not make any guesses or arbitrary choices.
70-
Coherence,
71-
}
72-
7361
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
7462
enum GoalEvaluationKind {
7563
Root,
@@ -314,17 +302,17 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
314302
}
315303
}
316304

317-
fn response_no_constraints_raw<'tcx>(
318-
tcx: TyCtxt<'tcx>,
305+
fn response_no_constraints_raw<I: Interner>(
306+
tcx: I,
319307
max_universe: UniverseIndex,
320-
variables: CanonicalVarInfos<'tcx>,
308+
variables: I::CanonicalVars,
321309
certainty: Certainty,
322-
) -> CanonicalResponse<'tcx> {
323-
Canonical {
310+
) -> ir::solve::CanonicalResponse<I> {
311+
ir::Canonical {
324312
max_universe,
325313
variables,
326314
value: Response {
327-
var_values: CanonicalVarValues::make_identity(tcx, variables),
315+
var_values: ir::CanonicalVarValues::make_identity(tcx, variables),
328316
// FIXME: maybe we should store the "no response" version in tcx, like
329317
// we do for tcx.types and stuff.
330318
external_constraints: tcx.mk_external_constraints(ExternalConstraintsData::default()),

compiler/rustc_trait_selection/src/solve/search_graph.rs

+40-52
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@ use std::mem;
33
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
44
use rustc_index::Idx;
55
use rustc_index::IndexVec;
6-
use rustc_infer::infer::InferCtxt;
7-
use rustc_middle::dep_graph::dep_kinds;
8-
use rustc_middle::traits::solve::CacheData;
9-
use rustc_middle::traits::solve::EvaluationCache;
10-
use rustc_middle::ty::TyCtxt;
6+
use rustc_next_trait_solver::solve::CacheData;
117
use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult};
128
use rustc_session::Limit;
139
use rustc_type_ir::inherent::*;
10+
use rustc_type_ir::InferCtxtLike;
1411
use rustc_type_ir::Interner;
1512

1613
use super::inspect;
@@ -240,34 +237,26 @@ impl<I: Interner> SearchGraph<I> {
240237
!entry.is_empty()
241238
});
242239
}
243-
}
244240

245-
impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
246241
/// The trait solver behavior is different for coherence
247242
/// so we use a separate cache. Alternatively we could use
248243
/// a single cache and share it between coherence and ordinary
249244
/// trait solving.
250-
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
251-
match self.mode {
252-
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
253-
SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
254-
}
245+
pub(super) fn global_cache(&self, tcx: I) -> I::EvaluationCache {
246+
tcx.evaluation_cache(self.mode)
255247
}
256248

257249
/// Probably the most involved method of the whole solver.
258250
///
259251
/// Given some goal which is proven via the `prove_goal` closure, this
260252
/// handles caching, overflow, and coinductive cycles.
261-
pub(super) fn with_new_goal(
253+
pub(super) fn with_new_goal<Infcx: InferCtxtLike<Interner = I>>(
262254
&mut self,
263-
tcx: TyCtxt<'tcx>,
264-
input: CanonicalInput<TyCtxt<'tcx>>,
265-
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
266-
mut prove_goal: impl FnMut(
267-
&mut Self,
268-
&mut ProofTreeBuilder<InferCtxt<'tcx>>,
269-
) -> QueryResult<TyCtxt<'tcx>>,
270-
) -> QueryResult<TyCtxt<'tcx>> {
255+
tcx: I,
256+
input: CanonicalInput<I>,
257+
inspect: &mut ProofTreeBuilder<Infcx>,
258+
mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
259+
) -> QueryResult<I> {
271260
self.check_invariants();
272261
// Check for overflow.
273262
let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
@@ -361,21 +350,20 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
361350
// not tracked by the cache key and from outside of this anon task, it
362351
// must not be added to the global cache. Notably, this is the case for
363352
// trait solver cycles participants.
364-
let ((final_entry, result), dep_node) =
365-
tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || {
366-
for _ in 0..FIXPOINT_STEP_LIMIT {
367-
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
368-
StepResult::Done(final_entry, result) => return (final_entry, result),
369-
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
370-
}
353+
let ((final_entry, result), dep_node) = tcx.with_cached_task(|| {
354+
for _ in 0..FIXPOINT_STEP_LIMIT {
355+
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
356+
StepResult::Done(final_entry, result) => return (final_entry, result),
357+
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
371358
}
359+
}
372360

373-
debug!("canonical cycle overflow");
374-
let current_entry = self.pop_stack();
375-
debug_assert!(current_entry.has_been_used.is_empty());
376-
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
377-
(current_entry, result)
378-
});
361+
debug!("canonical cycle overflow");
362+
let current_entry = self.pop_stack();
363+
debug_assert!(current_entry.has_been_used.is_empty());
364+
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
365+
(current_entry, result)
366+
});
379367

380368
let proof_tree = inspect.finalize_canonical_goal_evaluation(tcx);
381369

@@ -423,16 +411,17 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
423411
/// Try to fetch a previously computed result from the global cache,
424412
/// making sure to only do so if it would match the result of reevaluating
425413
/// this goal.
426-
fn lookup_global_cache(
414+
fn lookup_global_cache<Infcx: InferCtxtLike<Interner = I>>(
427415
&mut self,
428-
tcx: TyCtxt<'tcx>,
429-
input: CanonicalInput<TyCtxt<'tcx>>,
416+
tcx: I,
417+
input: CanonicalInput<I>,
430418
available_depth: Limit,
431-
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
432-
) -> Option<QueryResult<TyCtxt<'tcx>>> {
419+
inspect: &mut ProofTreeBuilder<Infcx>,
420+
) -> Option<QueryResult<I>> {
433421
let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self
434422
.global_cache(tcx)
435-
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?;
423+
// TODO: Awkward `Limit -> usize -> Limit`.
424+
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth.0)?;
436425

437426
// If we're building a proof tree and the current cache entry does not
438427
// contain a proof tree, we do not use the entry but instead recompute
@@ -465,21 +454,22 @@ enum StepResult<I: Interner> {
465454
HasChanged,
466455
}
467456

468-
impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
457+
impl<I: Interner> SearchGraph<I> {
469458
/// When we encounter a coinductive cycle, we have to fetch the
470459
/// result of that cycle while we are still computing it. Because
471460
/// of this we continuously recompute the cycle until the result
472461
/// of the previous iteration is equal to the final result, at which
473462
/// point we are done.
474-
fn fixpoint_step_in_task<F>(
463+
fn fixpoint_step_in_task<Infcx, F>(
475464
&mut self,
476-
tcx: TyCtxt<'tcx>,
477-
input: CanonicalInput<TyCtxt<'tcx>>,
478-
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
465+
tcx: I,
466+
input: CanonicalInput<I>,
467+
inspect: &mut ProofTreeBuilder<Infcx>,
479468
prove_goal: &mut F,
480-
) -> StepResult<TyCtxt<'tcx>>
469+
) -> StepResult<I>
481470
where
482-
F: FnMut(&mut Self, &mut ProofTreeBuilder<InferCtxt<'tcx>>) -> QueryResult<TyCtxt<'tcx>>,
471+
Infcx: InferCtxtLike<Interner = I>,
472+
F: FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
483473
{
484474
let result = prove_goal(self, inspect);
485475
let stack_entry = self.pop_stack();
@@ -533,15 +523,13 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
533523
}
534524

535525
fn response_no_constraints(
536-
tcx: TyCtxt<'tcx>,
537-
goal: CanonicalInput<TyCtxt<'tcx>>,
526+
tcx: I,
527+
goal: CanonicalInput<I>,
538528
certainty: Certainty,
539-
) -> QueryResult<TyCtxt<'tcx>> {
529+
) -> QueryResult<I> {
540530
Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
541531
}
542-
}
543532

544-
impl<I: Interner> SearchGraph<I> {
545533
#[allow(rustc::potential_query_instability)]
546534
fn check_invariants(&self) {
547535
if !cfg!(debug_assertions) {

0 commit comments

Comments
 (0)