Skip to content

Commit 2dacf7a

Browse files
Collect relevant item bounds from trait clauses for nested rigid projections, GATs
1 parent b511753 commit 2dacf7a

File tree

4 files changed

+286
-10
lines changed

4 files changed

+286
-10
lines changed

compiler/rustc_hir_analysis/src/collect/item_bounds.rs

+216-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use rustc_data_structures::fx::FxIndexSet;
1+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
22
use rustc_hir as hir;
33
use rustc_infer::traits::util;
4+
use rustc_middle::ty::fold::shift_vars;
45
use rustc_middle::ty::{
5-
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
6+
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
67
};
78
use rustc_middle::{bug, span_bug};
89
use rustc_span::Span;
@@ -42,14 +43,110 @@ fn associated_type_bounds<'tcx>(
4243
let trait_def_id = tcx.local_parent(assoc_item_def_id);
4344
let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);
4445

45-
let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
46-
match pred.kind().skip_binder() {
47-
ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
48-
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty,
49-
ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
50-
_ => false,
51-
}
52-
});
46+
let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
47+
let bounds_from_parent =
48+
trait_predicates.predicates.iter().copied().filter_map(|(pred, span)| {
49+
let mut clause_ty = match pred.kind().skip_binder() {
50+
ty::ClauseKind::Trait(tr) => tr.self_ty(),
51+
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(),
52+
ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
53+
_ => return None,
54+
};
55+
56+
// The code below is quite involved, so let me explain.
57+
//
58+
// We loop here, because we also want to collect vars for nested associated items as
59+
// well. For example, given a clause like `Self::A::B`, we want to add that to the
60+
// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
61+
// rigid.
62+
//
63+
// Secondly, regarding bound vars, when we see a where clause that mentions a GAT
64+
// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
65+
// an item bound on the GAT, where all of the GAT args are substituted with the GAT's
66+
// param regions, and then keep all of the other late-bound vars in the bound around.
67+
// We need to "compress" the binder so that it doesn't mention any of those vars that
68+
// were mapped to params.
69+
let gat_vars = loop {
70+
if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
71+
if alias_ty.trait_ref(tcx) == item_trait_ref
72+
&& alias_ty.def_id == assoc_item_def_id.to_def_id()
73+
{
74+
break &alias_ty.args[item_trait_ref.args.len()..];
75+
} else {
76+
// Only collect *self* type bounds if the filter is for self.
77+
match filter {
78+
PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => {
79+
return None;
80+
}
81+
PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {
82+
}
83+
}
84+
85+
clause_ty = alias_ty.self_ty();
86+
continue;
87+
}
88+
}
89+
90+
return None;
91+
};
92+
// Special-case: No GAT vars, no mapping needed.
93+
if gat_vars.is_empty() {
94+
return Some((pred, span));
95+
}
96+
97+
// First, check that all of the GAT args are substituted with a unique late-bound arg.
98+
// If we find a duplicate, then it can't be mapped to the definition's params.
99+
let mut mapping = FxIndexMap::default();
100+
let generics = tcx.generics_of(assoc_item_def_id);
101+
for (param, var) in std::iter::zip(&generics.own_params, gat_vars) {
102+
let existing = match var.unpack() {
103+
ty::GenericArgKind::Lifetime(re) => {
104+
if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
105+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
106+
} else {
107+
return None;
108+
}
109+
}
110+
ty::GenericArgKind::Type(ty) => {
111+
if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
112+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
113+
} else {
114+
return None;
115+
}
116+
}
117+
ty::GenericArgKind::Const(ct) => {
118+
if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
119+
mapping.insert(bv, tcx.mk_param_from_def(param))
120+
} else {
121+
return None;
122+
}
123+
}
124+
};
125+
126+
if existing.is_some() {
127+
return None;
128+
}
129+
}
130+
131+
// Finally, map all of the args in the GAT to the params we expect, and compress
132+
// the remaining late-bound vars so that they count up from var 0.
133+
let mut folder = MapAndCompressBoundVars {
134+
tcx,
135+
binder: ty::INNERMOST,
136+
still_bound_vars: vec![],
137+
mapping,
138+
};
139+
let pred = pred.kind().skip_binder().fold_with(&mut folder);
140+
141+
Some((
142+
ty::Binder::bind_with_vars(
143+
pred,
144+
tcx.mk_bound_variable_kinds(&folder.still_bound_vars),
145+
)
146+
.upcast(tcx),
147+
span,
148+
))
149+
});
53150

54151
let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent));
55152
debug!(
@@ -63,6 +160,115 @@ fn associated_type_bounds<'tcx>(
63160
all_bounds
64161
}
65162

163+
struct MapAndCompressBoundVars<'tcx> {
164+
tcx: TyCtxt<'tcx>,
165+
/// How deep are we? Makes sure we don't touch the vars of nested binders.
166+
binder: ty::DebruijnIndex,
167+
/// List of bound vars that remain unsubstituted because they were not
168+
/// mentioned in the GAT's args.
169+
still_bound_vars: Vec<ty::BoundVariableKind>,
170+
/// Subtle invariant: If the `GenericArg` is bound, then it should be
171+
/// stored with the debruijn index of `INNERMOST` so it can be shifted
172+
/// correctly during substitution.
173+
mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
174+
}
175+
176+
impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
177+
fn cx(&self) -> TyCtxt<'tcx> {
178+
self.tcx
179+
}
180+
181+
fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
182+
where
183+
ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
184+
{
185+
self.binder.shift_in(1);
186+
let out = t.super_fold_with(self);
187+
self.binder.shift_out(1);
188+
out
189+
}
190+
191+
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
192+
if !ty.has_bound_vars() {
193+
return ty;
194+
}
195+
196+
if let ty::Bound(binder, old_bound) = *ty.kind()
197+
&& self.binder == binder
198+
{
199+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
200+
mapped.expect_ty()
201+
} else {
202+
// If we didn't find a mapped generic, then make a new one.
203+
// Allocate a new var idx, and insert a new bound ty.
204+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
205+
self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
206+
let mapped = Ty::new_bound(self.tcx, ty::INNERMOST, ty::BoundTy {
207+
var,
208+
kind: old_bound.kind,
209+
});
210+
self.mapping.insert(old_bound.var, mapped.into());
211+
mapped
212+
};
213+
214+
shift_vars(self.tcx, mapped, self.binder.as_u32())
215+
} else {
216+
ty.super_fold_with(self)
217+
}
218+
}
219+
220+
fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
221+
if let ty::ReBound(binder, old_bound) = re.kind()
222+
&& self.binder == binder
223+
{
224+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
225+
mapped.expect_region()
226+
} else {
227+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
228+
self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
229+
let mapped = ty::Region::new_bound(self.tcx, ty::INNERMOST, ty::BoundRegion {
230+
var,
231+
kind: old_bound.kind,
232+
});
233+
self.mapping.insert(old_bound.var, mapped.into());
234+
mapped
235+
};
236+
237+
shift_vars(self.tcx, mapped, self.binder.as_u32())
238+
} else {
239+
re
240+
}
241+
}
242+
243+
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
244+
if !ct.has_bound_vars() {
245+
return ct;
246+
}
247+
248+
if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
249+
&& self.binder == binder
250+
{
251+
let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
252+
mapped.expect_const()
253+
} else {
254+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
255+
self.still_bound_vars.push(ty::BoundVariableKind::Const);
256+
let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var);
257+
self.mapping.insert(old_var, mapped.into());
258+
mapped
259+
};
260+
261+
shift_vars(self.tcx, mapped, self.binder.as_u32())
262+
} else {
263+
ct.super_fold_with(self)
264+
}
265+
}
266+
267+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
268+
if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
269+
}
270+
}
271+
66272
/// Opaque types don't inherit bounds from their parent: for return position
67273
/// impl trait it isn't possible to write a suitable predicate on the
68274
/// containing function and for type-alias impl trait we don't have a backwards
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Trait
6+
where
7+
Self::Assoc: Clone,
8+
{
9+
type Assoc;
10+
}
11+
12+
fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
13+
x.clone()
14+
}
15+
16+
trait Trait2
17+
where
18+
Self::Assoc: Iterator,
19+
<Self::Assoc as Iterator>::Item: Clone,
20+
{
21+
type Assoc;
22+
}
23+
24+
fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
25+
x.clone()
26+
}
27+
28+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ check-pass
2+
3+
// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
4+
// just as it would be if it weren't a GAT but just a regular associated type.
5+
6+
use std::fmt::Debug;
7+
8+
trait Foo
9+
where
10+
for<'a> Self::Gat<'a>: Debug,
11+
{
12+
type Gat<'a>;
13+
}
14+
15+
fn test<T: Foo>(x: T::Gat<'static>) {
16+
println!("{:?}", x);
17+
}
18+
19+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Foo
6+
where
7+
Self::Iterator: Iterator,
8+
<Self::Iterator as Iterator>::Item: Bar,
9+
{
10+
type Iterator;
11+
12+
fn iter() -> Self::Iterator;
13+
}
14+
15+
trait Bar {
16+
fn bar(&self);
17+
}
18+
19+
fn x<T: Foo>() {
20+
T::iter().next().unwrap().bar();
21+
}
22+
23+
fn main() {}

0 commit comments

Comments
 (0)