Skip to content

Commit 91535ad

Browse files
committed
remove sub_relations from infcx, recompute in diagnostics
we don't track them when canonicalizing or when freshening, resulting in instable caching in the old solver, and issues when instantiating query responses in the new one.
1 parent 1bb3a9f commit 91535ad

27 files changed

+180
-280
lines changed

Diff for: compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -1522,10 +1522,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
15221522
if self.next_trait_solver()
15231523
&& let ty::Alias(..) = ty.kind()
15241524
{
1525-
match self
1525+
// We need to use a separate variable here as otherwise the temporary for
1526+
// `self.fulfillment_cx.borrow_mut()` is alive in the `Err` branch, resulting
1527+
// in a reentrant borrow, causing an ICE.
1528+
let result = self
15261529
.at(&self.misc(sp), self.param_env)
1527-
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut())
1528-
{
1530+
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut());
1531+
match result {
15291532
Ok(normalized_ty) => normalized_ty,
15301533
Err(errors) => {
15311534
let guar = self.err_ctxt().report_fulfillment_errors(errors);

Diff for: compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustc_hir as hir;
1111
use rustc_hir::def_id::{DefId, LocalDefId};
1212
use rustc_hir_analysis::astconv::AstConv;
1313
use rustc_infer::infer;
14+
use rustc_infer::infer::error_reporting::sub_relations::SubRelations;
1415
use rustc_infer::infer::error_reporting::TypeErrCtxt;
1516
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
1617
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
@@ -155,8 +156,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
155156
///
156157
/// [`InferCtxt::err_ctxt`]: infer::InferCtxt::err_ctxt
157158
pub fn err_ctxt(&'a self) -> TypeErrCtxt<'a, 'tcx> {
159+
let mut sub_relations = SubRelations::default();
160+
sub_relations.add_constraints(
161+
self,
162+
self.fulfillment_cx.borrow_mut().pending_obligations().iter().map(|o| o.predicate),
163+
);
158164
TypeErrCtxt {
159165
infcx: &self.infcx,
166+
sub_relations: RefCell::new(sub_relations),
160167
typeck_results: Some(self.typeck_results.borrow()),
161168
fallback_has_occurred: self.fallback_has_occurred.get(),
162169
normalize_fn_sig: Box::new(|fn_sig| {

Diff for: compiler/rustc_infer/src/infer/error_reporting/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ mod note_and_explain;
8888
mod suggest;
8989

9090
pub(crate) mod need_type_info;
91+
pub mod sub_relations;
9192
pub use need_type_info::TypeAnnotationNeeded;
9293

9394
pub mod nice_region_error;
@@ -123,6 +124,8 @@ fn escape_literal(s: &str) -> String {
123124
/// methods which should not be used during the happy path.
124125
pub struct TypeErrCtxt<'a, 'tcx> {
125126
pub infcx: &'a InferCtxt<'tcx>,
127+
pub sub_relations: std::cell::RefCell<sub_relations::SubRelations>,
128+
126129
pub typeck_results: Option<std::cell::Ref<'a, ty::TypeckResults<'tcx>>>,
127130
pub fallback_has_occurred: bool,
128131

Diff for: compiler/rustc_infer/src/infer/error_reporting/need_type_info.rs

+20-23
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
502502
parent_name,
503503
});
504504

505-
let args = if self.infcx.tcx.get_diagnostic_item(sym::iterator_collect_fn)
505+
let args = if self.tcx.get_diagnostic_item(sym::iterator_collect_fn)
506506
== Some(generics_def_id)
507507
{
508508
"Vec<_>".to_string()
@@ -710,7 +710,7 @@ struct InsertableGenericArgs<'tcx> {
710710
/// While doing so, the currently best spot is stored in `infer_source`.
711711
/// For details on how we rank spots, see [Self::source_cost]
712712
struct FindInferSourceVisitor<'a, 'tcx> {
713-
infcx: &'a InferCtxt<'tcx>,
713+
tecx: &'a TypeErrCtxt<'a, 'tcx>,
714714
typeck_results: &'a TypeckResults<'tcx>,
715715

716716
target: GenericArg<'tcx>,
@@ -722,12 +722,12 @@ struct FindInferSourceVisitor<'a, 'tcx> {
722722

723723
impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
724724
fn new(
725-
infcx: &'a InferCtxt<'tcx>,
725+
tecx: &'a TypeErrCtxt<'a, 'tcx>,
726726
typeck_results: &'a TypeckResults<'tcx>,
727727
target: GenericArg<'tcx>,
728728
) -> Self {
729729
FindInferSourceVisitor {
730-
infcx,
730+
tecx,
731731
typeck_results,
732732

733733
target,
@@ -778,7 +778,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
778778
}
779779

780780
// The sources are listed in order of preference here.
781-
let tcx = self.infcx.tcx;
781+
let tcx = self.tecx.tcx;
782782
let ctx = CostCtxt { tcx };
783783
match source.kind {
784784
InferSourceKind::LetBinding { ty, .. } => ctx.ty_cost(ty),
@@ -829,12 +829,12 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
829829

830830
fn node_args_opt(&self, hir_id: HirId) -> Option<GenericArgsRef<'tcx>> {
831831
let args = self.typeck_results.node_args_opt(hir_id);
832-
self.infcx.resolve_vars_if_possible(args)
832+
self.tecx.resolve_vars_if_possible(args)
833833
}
834834

835835
fn opt_node_type(&self, hir_id: HirId) -> Option<Ty<'tcx>> {
836836
let ty = self.typeck_results.node_type_opt(hir_id);
837-
self.infcx.resolve_vars_if_possible(ty)
837+
self.tecx.resolve_vars_if_possible(ty)
838838
}
839839

840840
// Check whether this generic argument is the inference variable we
@@ -849,20 +849,17 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
849849
use ty::{Infer, TyVar};
850850
match (inner_ty.kind(), target_ty.kind()) {
851851
(&Infer(TyVar(a_vid)), &Infer(TyVar(b_vid))) => {
852-
self.infcx.inner.borrow_mut().type_variables().sub_unified(a_vid, b_vid)
852+
self.tecx.sub_relations.borrow_mut().unified(self.tecx, a_vid, b_vid)
853853
}
854854
_ => false,
855855
}
856856
}
857857
(GenericArgKind::Const(inner_ct), GenericArgKind::Const(target_ct)) => {
858858
use ty::InferConst::*;
859859
match (inner_ct.kind(), target_ct.kind()) {
860-
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => self
861-
.infcx
862-
.inner
863-
.borrow_mut()
864-
.const_unification_table()
865-
.unioned(a_vid, b_vid),
860+
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => {
861+
self.tecx.inner.borrow_mut().const_unification_table().unioned(a_vid, b_vid)
862+
}
866863
_ => false,
867864
}
868865
}
@@ -917,7 +914,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
917914
&self,
918915
expr: &'tcx hir::Expr<'tcx>,
919916
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
920-
let tcx = self.infcx.tcx;
917+
let tcx = self.tecx.tcx;
921918
match expr.kind {
922919
hir::ExprKind::Path(ref path) => {
923920
if let Some(args) = self.node_args_opt(expr.hir_id) {
@@ -980,7 +977,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
980977
path: &'tcx hir::Path<'tcx>,
981978
args: GenericArgsRef<'tcx>,
982979
) -> impl Iterator<Item = InsertableGenericArgs<'tcx>> + 'a {
983-
let tcx = self.infcx.tcx;
980+
let tcx = self.tecx.tcx;
984981
let have_turbofish = path.segments.iter().any(|segment| {
985982
segment.args.is_some_and(|args| args.args.iter().any(|arg| arg.is_ty_or_const()))
986983
});
@@ -1034,7 +1031,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
10341031
args: GenericArgsRef<'tcx>,
10351032
qpath: &'tcx hir::QPath<'tcx>,
10361033
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
1037-
let tcx = self.infcx.tcx;
1034+
let tcx = self.tecx.tcx;
10381035
match qpath {
10391036
hir::QPath::Resolved(_self_ty, path) => {
10401037
Box::new(self.resolved_path_inferred_arg_iter(path, args))
@@ -1107,7 +1104,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
11071104
type NestedFilter = nested_filter::OnlyBodies;
11081105

11091106
fn nested_visit_map(&mut self) -> Self::Map {
1110-
self.infcx.tcx.hir()
1107+
self.tecx.tcx.hir()
11111108
}
11121109

11131110
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
@@ -1163,7 +1160,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
11631160

11641161
#[instrument(level = "debug", skip(self))]
11651162
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
1166-
let tcx = self.infcx.tcx;
1163+
let tcx = self.tecx.tcx;
11671164
match expr.kind {
11681165
// When encountering `func(arg)` first look into `arg` and then `func`,
11691166
// as `arg` is "more specific".
@@ -1194,7 +1191,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
11941191
if generics.parent.is_none() && generics.has_self {
11951192
argument_index += 1;
11961193
}
1197-
let args = self.infcx.resolve_vars_if_possible(args);
1194+
let args = self.tecx.resolve_vars_if_possible(args);
11981195
let generic_args =
11991196
&generics.own_args_no_defaults(tcx, args)[generics.own_counts().lifetimes..];
12001197
let span = match expr.kind {
@@ -1224,7 +1221,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
12241221
{
12251222
let output = args.as_closure().sig().output().skip_binder();
12261223
if self.generic_arg_contains_target(output.into()) {
1227-
let body = self.infcx.tcx.hir().body(body);
1224+
let body = self.tecx.tcx.hir().body(body);
12281225
let should_wrap_expr = if matches!(body.value.kind, ExprKind::Block(..)) {
12291226
None
12301227
} else {
@@ -1252,12 +1249,12 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
12521249
&& let Some(args) = self.node_args_opt(expr.hir_id)
12531250
&& args.iter().any(|arg| self.generic_arg_contains_target(arg))
12541251
&& let Some(def_id) = self.typeck_results.type_dependent_def_id(expr.hir_id)
1255-
&& self.infcx.tcx.trait_of_item(def_id).is_some()
1252+
&& self.tecx.tcx.trait_of_item(def_id).is_some()
12561253
&& !has_impl_trait(def_id)
12571254
{
12581255
let successor =
12591256
method_args.get(0).map_or_else(|| (")", span.hi()), |arg| (", ", arg.span.lo()));
1260-
let args = self.infcx.resolve_vars_if_possible(args);
1257+
let args = self.tecx.resolve_vars_if_possible(args);
12611258
self.update_infer_source(InferSource {
12621259
span: path.ident.span,
12631260
kind: InferSourceKind::FullyQualifiedMethodCall {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use rustc_data_structures::fx::FxHashMap;
2+
use rustc_data_structures::undo_log::NoUndo;
3+
use rustc_data_structures::unify as ut;
4+
use rustc_middle::ty;
5+
6+
use crate::infer::InferCtxt;
7+
8+
#[derive(Debug, Copy, Clone, PartialEq)]
9+
struct SubId(u32);
10+
impl ut::UnifyKey for SubId {
11+
type Value = ();
12+
#[inline]
13+
fn index(&self) -> u32 {
14+
self.0
15+
}
16+
#[inline]
17+
fn from_index(i: u32) -> SubId {
18+
SubId(i)
19+
}
20+
fn tag() -> &'static str {
21+
"SubId"
22+
}
23+
}
24+
25+
/// When reporting ambiguity errors, we sometimes want to
26+
/// treat all inference vars which are subtypes of each
27+
/// others as if they are equal. For this case we compute
28+
/// the transitive closure of our subtype obligations here.
29+
///
30+
/// E.g. when encountering ambiguity errors, we want to suggest
31+
/// specifying some method argument or to add a type annotation
32+
/// to a local variable. Because subtyping cannot change the
33+
/// shape of a type, it's fine if the cause of the ambiguity error
34+
/// is only related to the suggested variable via subtyping.
35+
///
36+
/// Even for something like `let x = returns_arg(); x.method();` the
37+
/// type of `x` is only a supertype of the argument of `returns_arg`. We
38+
/// still want to suggest specifying the type of the argument.
39+
#[derive(Default)]
40+
pub struct SubRelations {
41+
map: FxHashMap<ty::TyVid, SubId>,
42+
table: ut::UnificationTableStorage<SubId>,
43+
}
44+
45+
impl SubRelations {
46+
fn get_id<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, vid: ty::TyVid) -> SubId {
47+
let root_vid = infcx.root_var(vid);
48+
*self.map.entry(root_vid).or_insert_with(|| self.table.with_log(&mut NoUndo).new_key(()))
49+
}
50+
51+
pub fn add_constraints<'tcx>(
52+
&mut self,
53+
infcx: &InferCtxt<'tcx>,
54+
obls: impl IntoIterator<Item = ty::Predicate<'tcx>>,
55+
) {
56+
for p in obls {
57+
let (a, b) = match p.kind().skip_binder() {
58+
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
59+
(a, b)
60+
}
61+
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
62+
_ => continue,
63+
};
64+
65+
match (a.kind(), b.kind()) {
66+
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
67+
let a = self.get_id(infcx, a_vid);
68+
let b = self.get_id(infcx, b_vid);
69+
self.table.with_log(&mut NoUndo).unify_var_var(a, b).unwrap();
70+
}
71+
_ => continue,
72+
}
73+
}
74+
}
75+
76+
pub fn unified<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, a: ty::TyVid, b: ty::TyVid) -> bool {
77+
let a = self.get_id(infcx, a);
78+
let b = self.get_id(infcx, b);
79+
self.table.with_log(&mut NoUndo).unioned(a, b)
80+
}
81+
}

Diff for: compiler/rustc_infer/src/infer/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ impl<'tcx> InferCtxt<'tcx> {
762762
pub fn err_ctxt(&self) -> TypeErrCtxt<'_, 'tcx> {
763763
TypeErrCtxt {
764764
infcx: self,
765+
sub_relations: Default::default(),
765766
typeck_results: None,
766767
fallback_has_occurred: false,
767768
normalize_fn_sig: Box::new(|fn_sig| fn_sig),
@@ -1029,7 +1030,6 @@ impl<'tcx> InferCtxt<'tcx> {
10291030
let r_b = self.shallow_resolve(predicate.skip_binder().b);
10301031
match (r_a.kind(), r_b.kind()) {
10311032
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
1032-
self.inner.borrow_mut().type_variables().sub(a_vid, b_vid);
10331033
return Err((a_vid, b_vid));
10341034
}
10351035
_ => {}

Diff for: compiler/rustc_infer/src/infer/relate/generalize.rs

+4-12
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,9 @@ impl<'tcx> InferCtxt<'tcx> {
217217
) -> RelateResult<'tcx, Generalization<T>> {
218218
assert!(!source_term.has_escaping_bound_vars());
219219
let (for_universe, root_vid) = match target_vid.into() {
220-
ty::TermVid::Ty(ty_vid) => (
221-
self.probe_ty_var(ty_vid).unwrap_err(),
222-
ty::TermVid::Ty(self.inner.borrow_mut().type_variables().sub_root_var(ty_vid)),
223-
),
220+
ty::TermVid::Ty(ty_vid) => {
221+
(self.probe_ty_var(ty_vid).unwrap_err(), ty::TermVid::Ty(self.root_var(ty_vid)))
222+
}
224223
ty::TermVid::Const(ct_vid) => (
225224
self.probe_const_var(ct_vid).unwrap_err(),
226225
ty::TermVid::Const(
@@ -424,9 +423,7 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
424423
ty::Infer(ty::TyVar(vid)) => {
425424
let mut inner = self.infcx.inner.borrow_mut();
426425
let vid = inner.type_variables().root_var(vid);
427-
let sub_vid = inner.type_variables().sub_root_var(vid);
428-
429-
if ty::TermVid::Ty(sub_vid) == self.root_vid {
426+
if ty::TermVid::Ty(vid) == self.root_vid {
430427
// If sub-roots are equal, then `root_vid` and
431428
// `vid` are related via subtyping.
432429
Err(self.cyclic_term_error())
@@ -461,11 +458,6 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
461458
let new_var_id =
462459
inner.type_variables().new_var(self.for_universe, origin);
463460
let u = Ty::new_var(self.tcx(), new_var_id);
464-
465-
// Record that we replaced `vid` with `new_var_id` as part of a generalization
466-
// operation. This is needed to detect cyclic types. To see why, see the
467-
// docs in the `type_variables` module.
468-
inner.type_variables().sub(vid, new_var_id);
469461
debug!("replacing original vid={:?} with new={:?}", vid, u);
470462
Ok(u)
471463
}

0 commit comments

Comments
 (0)