Skip to content

Commit 13881f5

Browse files
committed
add caches to multiple type folders
1 parent 15ac698 commit 13881f5

File tree

8 files changed

+222
-22
lines changed

8 files changed

+222
-22
lines changed

compiler/rustc_infer/src/infer/relate/combine.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,15 @@ use crate::traits::{Obligation, PredicateObligation};
3636
#[derive(Clone)]
3737
pub struct CombineFields<'infcx, 'tcx> {
3838
pub infcx: &'infcx InferCtxt<'tcx>,
39+
// Immutable fields
3940
pub trace: TypeTrace<'tcx>,
4041
pub param_env: ty::ParamEnv<'tcx>,
41-
pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
4242
pub define_opaque_types: DefineOpaqueTypes,
43+
// Mutable fields
44+
//
45+
// Adding any additional field likely requires
46+
// changes to the cache of `TypeRelating`.
47+
pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
4348
}
4449

4550
impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {

compiler/rustc_infer/src/infer/relate/type_relating.rs

+40-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use rustc_middle::ty::relate::{
44
};
55
use rustc_middle::ty::{self, Ty, TyCtxt, TyVar};
66
use rustc_span::Span;
7+
use rustc_type_ir::data_structures::DelayedSet;
78
use tracing::{debug, instrument};
89

910
use super::combine::CombineFields;
@@ -13,9 +14,36 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin};
1314

1415
/// Enforce that `a` is equal to or a subtype of `b`.
1516
pub struct TypeRelating<'combine, 'a, 'tcx> {
17+
// Partially mutable.
1618
fields: &'combine mut CombineFields<'a, 'tcx>,
19+
20+
// Immutable fields.
1721
structurally_relate_aliases: StructurallyRelateAliases,
1822
ambient_variance: ty::Variance,
23+
24+
/// The cache has only tracks the `ambient_variance` as its the
25+
/// only field which is mutable and which meaningfully changes
26+
/// the result when relating types.
27+
///
28+
/// The cache does not track whether the state of the
29+
/// `InferCtxt` has been changed or whether we've added any
30+
/// obligations to `self.fields.goals`. Whether a goal is added
31+
/// once or multiple times is not really meaningful.
32+
///
33+
/// Changes in the inference state may delay some type inference to
34+
/// the next fulfillment loop. Given that this loop is already
35+
/// necessary, this is also not a meaningful change. Consider
36+
/// the following three relations:
37+
/// ```text
38+
/// Vec<?0> sub Vec<?1>
39+
/// ?0 eq u32
40+
/// Vec<?0> sub Vec<?1>
41+
/// ```
42+
/// Without a cache, the second `Vec<?0> sub Vec<?1>` would eagerly
43+
/// constrain `?1` to `u32`. When using the cache entry from the
44+
/// first time we've related these types, this only happens when
45+
/// later proving the `Subtype(?0, ?1)` goal from the first relation.
46+
cache: DelayedSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>,
1947
}
2048

2149
impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
@@ -24,7 +52,12 @@ impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
2452
structurally_relate_aliases: StructurallyRelateAliases,
2553
ambient_variance: ty::Variance,
2654
) -> TypeRelating<'combine, 'infcx, 'tcx> {
27-
TypeRelating { fields: f, structurally_relate_aliases, ambient_variance }
55+
TypeRelating {
56+
fields: f,
57+
structurally_relate_aliases,
58+
ambient_variance,
59+
cache: Default::default(),
60+
}
2861
}
2962
}
3063

@@ -78,6 +111,10 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
78111
let a = infcx.shallow_resolve(a);
79112
let b = infcx.shallow_resolve(b);
80113

114+
if self.cache.contains(&(self.ambient_variance, a, b)) {
115+
return Ok(a);
116+
}
117+
81118
match (a.kind(), b.kind()) {
82119
(&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
83120
match self.ambient_variance {
@@ -160,6 +197,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
160197
}
161198
}
162199

200+
assert!(self.cache.insert((self.ambient_variance, a, b)));
201+
163202
Ok(a)
164203
}
165204

compiler/rustc_infer/src/infer/resolve.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_middle::bug;
22
use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFolder, TypeSuperFoldable};
33
use rustc_middle::ty::visit::TypeVisitableExt;
44
use rustc_middle::ty::{self, Const, InferConst, Ty, TyCtxt, TypeFoldable};
5+
use rustc_type_ir::data_structures::DelayedMap;
56

67
use super::{FixupError, FixupResult, InferCtxt};
78

@@ -15,12 +16,15 @@ use super::{FixupError, FixupResult, InferCtxt};
1516
/// points for correctness.
1617
pub struct OpportunisticVarResolver<'a, 'tcx> {
1718
infcx: &'a InferCtxt<'tcx>,
19+
/// We're able to use a cache here as the folder does
20+
/// not have any mutable state.
21+
cache: DelayedMap<Ty<'tcx>, Ty<'tcx>>,
1822
}
1923

2024
impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> {
2125
#[inline]
2226
pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self {
23-
OpportunisticVarResolver { infcx }
27+
OpportunisticVarResolver { infcx, cache: Default::default() }
2428
}
2529
}
2630

@@ -33,9 +37,13 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
3337
fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
3438
if !t.has_non_region_infer() {
3539
t // micro-optimize -- if there is nothing in this type that this fold affects...
40+
} else if let Some(&ty) = self.cache.get(&t) {
41+
return ty;
3642
} else {
37-
let t = self.infcx.shallow_resolve(t);
38-
t.super_fold_with(self)
43+
let shallow = self.infcx.shallow_resolve(t);
44+
let res = shallow.super_fold_with(self);
45+
assert!(self.cache.insert(t, res));
46+
res
3947
}
4048
}
4149

compiler/rustc_middle/src/ty/fold.rs

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use rustc_data_structures::fx::FxIndexMap;
22
use rustc_hir::def_id::DefId;
3+
use rustc_type_ir::data_structures::DelayedMap;
34
pub use rustc_type_ir::fold::{
45
FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable, shift_region, shift_vars,
56
};
@@ -131,12 +132,20 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for RegionFolder<'a, 'tcx> {
131132
///////////////////////////////////////////////////////////////////////////
132133
// Bound vars replacer
133134

135+
/// A delegate used when instantiating bound vars.
136+
///
137+
/// Any implementation must make sure that each bound variable always
138+
/// gets mapped to the same result. `BoundVarReplacer` caches by using
139+
/// a `DelayedMap` which does not cache the first few types it encounters.
134140
pub trait BoundVarReplacerDelegate<'tcx> {
135141
fn replace_region(&mut self, br: ty::BoundRegion) -> ty::Region<'tcx>;
136142
fn replace_ty(&mut self, bt: ty::BoundTy) -> Ty<'tcx>;
137143
fn replace_const(&mut self, bv: ty::BoundVar) -> ty::Const<'tcx>;
138144
}
139145

146+
/// A simple delegate taking 3 mutable functions. The used functions must
147+
/// always return the same result for each bound variable, no matter how
148+
/// frequently they are called.
140149
pub struct FnMutDelegate<'a, 'tcx> {
141150
pub regions: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a),
142151
pub types: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a),
@@ -164,11 +173,15 @@ struct BoundVarReplacer<'tcx, D> {
164173
current_index: ty::DebruijnIndex,
165174

166175
delegate: D,
176+
177+
/// This cache only tracks the `DebruijnIndex` and assumes that it does not matter
178+
/// for the delegate how often its methods get used.
179+
cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
167180
}
168181

169182
impl<'tcx, D: BoundVarReplacerDelegate<'tcx>> BoundVarReplacer<'tcx, D> {
170183
fn new(tcx: TyCtxt<'tcx>, delegate: D) -> Self {
171-
BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate }
184+
BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate, cache: Default::default() }
172185
}
173186
}
174187

@@ -197,7 +210,15 @@ where
197210
debug_assert!(!ty.has_vars_bound_above(ty::INNERMOST));
198211
ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32())
199212
}
200-
_ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
213+
_ if t.has_vars_bound_at_or_above(self.current_index) => {
214+
if let Some(&ty) = self.cache.get(&(self.current_index, t)) {
215+
return ty;
216+
}
217+
218+
let res = t.super_fold_with(self);
219+
assert!(self.cache.insert((self.current_index, t), res));
220+
res
221+
}
201222
_ => t,
202223
}
203224
}

compiler/rustc_next_trait_solver/src/resolve.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use rustc_type_ir::data_structures::DelayedMap;
12
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
23
use rustc_type_ir::inherent::*;
34
use rustc_type_ir::visit::TypeVisitableExt;
@@ -15,11 +16,12 @@ where
1516
I: Interner,
1617
{
1718
delegate: &'a D,
19+
cache: DelayedMap<I::Ty, I::Ty>,
1820
}
1921

2022
impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
2123
pub fn new(delegate: &'a D) -> Self {
22-
EagerResolver { delegate }
24+
EagerResolver { delegate, cache: Default::default() }
2325
}
2426
}
2527

@@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
4244
ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid),
4345
_ => {
4446
if t.has_infer() {
45-
t.super_fold_with(self)
47+
if let Some(&ty) = self.cache.get(&t) {
48+
return ty;
49+
}
50+
let res = t.super_fold_with(self);
51+
assert!(self.cache.insert(t, res));
52+
res
4653
} else {
4754
t
4855
}

compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs

+38-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::ops::ControlFlow;
33
use derive_where::derive_where;
44
#[cfg(feature = "nightly")]
55
use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
6-
use rustc_type_ir::data_structures::ensure_sufficient_stack;
6+
use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack};
77
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
88
use rustc_type_ir::inherent::*;
99
use rustc_type_ir::relate::Relate;
@@ -579,18 +579,16 @@ where
579579

580580
#[instrument(level = "trace", skip(self))]
581581
pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) {
582-
goal.predicate = goal
583-
.predicate
584-
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
582+
goal.predicate =
583+
goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
585584
self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal);
586585
self.nested_goals.normalizes_to_goals.push(goal);
587586
}
588587

589588
#[instrument(level = "debug", skip(self))]
590589
pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) {
591-
goal.predicate = goal
592-
.predicate
593-
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
590+
goal.predicate =
591+
goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
594592
self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal);
595593
self.nested_goals.goals.push((source, goal));
596594
}
@@ -654,6 +652,7 @@ where
654652
term: I::Term,
655653
universe_of_term: ty::UniverseIndex,
656654
delegate: &'a D,
655+
cache: HashSet<I::Ty>,
657656
}
658657

659658
impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> {
@@ -671,6 +670,10 @@ where
671670
{
672671
type Result = ControlFlow<()>;
673672
fn visit_ty(&mut self, t: I::Ty) -> Self::Result {
673+
if self.cache.contains(&t) {
674+
return ControlFlow::Continue(());
675+
}
676+
674677
match t.kind() {
675678
ty::Infer(ty::TyVar(vid)) => {
676679
if let ty::TermKind::Ty(term) = self.term.kind() {
@@ -683,17 +686,18 @@ where
683686
}
684687
}
685688

686-
self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())
689+
self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?;
687690
}
688-
ty::Placeholder(p) => self.check_nameable(p.universe()),
691+
ty::Placeholder(p) => self.check_nameable(p.universe())?,
689692
_ => {
690693
if t.has_non_region_infer() || t.has_placeholders() {
691-
t.super_visit_with(self)
692-
} else {
693-
ControlFlow::Continue(())
694+
t.super_visit_with(self)?
694695
}
695696
}
696697
}
698+
699+
assert!(self.cache.insert(t));
700+
ControlFlow::Continue(())
697701
}
698702

699703
fn visit_const(&mut self, c: I::Const) -> Self::Result {
@@ -728,6 +732,7 @@ where
728732
delegate: self.delegate,
729733
universe_of_term,
730734
term: goal.predicate.term,
735+
cache: Default::default(),
731736
};
732737
goal.predicate.alias.visit_with(&mut visitor).is_continue()
733738
&& goal.param_env.visit_with(&mut visitor).is_continue()
@@ -1015,6 +1020,17 @@ where
10151020
{
10161021
ecx: &'me mut EvalCtxt<'a, D>,
10171022
param_env: I::ParamEnv,
1023+
cache: HashMap<I::Ty, I::Ty>,
1024+
}
1025+
1026+
impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I>
1027+
where
1028+
D: SolverDelegate<Interner = I>,
1029+
I: Interner,
1030+
{
1031+
fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self {
1032+
ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() }
1033+
}
10181034
}
10191035

10201036
impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I>
@@ -1041,7 +1057,16 @@ where
10411057
);
10421058
infer_ty
10431059
}
1044-
_ => ty.super_fold_with(self),
1060+
_ if ty.has_aliases() => {
1061+
if let Some(&entry) = self.cache.get(&ty) {
1062+
return entry;
1063+
}
1064+
1065+
let res = ty.super_fold_with(self);
1066+
assert!(self.cache.insert(ty, res).is_none());
1067+
res
1068+
}
1069+
_ => ty,
10451070
}
10461071
}
10471072

0 commit comments

Comments
 (0)