Skip to content

Commit c6e1e17

Browse files
committed
Auto merge of rust-lang#13494 - lowr:feat/gats, r=flodiebold
feat: type inference for generic associated types This PR implements type inference for generic associated types. Basically, this PR lowers generic arguments for associated types in valid places and creates `Substitution`s for them. I focused on the inference for correct Rust programs, so there are cases where we *accidentally* manage to infer things that are actually invalid (which would then be reported by flycheck so I deem them non-fatal). See the following tests and FIXME notes on them: `gats_with_dyn`, `gats_with_impl_trait`. The added tests are rather arbitrary. Let me know if there are cases I'm missing or I should add. Closes rust-lang#9673
2 parents 0340b51 + 5fc18ad commit c6e1e17

File tree

13 files changed

+381
-98
lines changed

13 files changed

+381
-98
lines changed

crates/hir-def/src/item_tree/lower.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,12 @@ fn desugar_future_path(orig: TypeRef) -> Path {
662662
let mut generic_args: Vec<_> =
663663
std::iter::repeat(None).take(path.segments().len() - 1).collect();
664664
let mut last = GenericArgs::empty();
665-
let binding =
666-
AssociatedTypeBinding { name: name![Output], type_ref: Some(orig), bounds: Vec::new() };
665+
let binding = AssociatedTypeBinding {
666+
name: name![Output],
667+
args: None,
668+
type_ref: Some(orig),
669+
bounds: Vec::new(),
670+
};
667671
last.bindings.push(binding);
668672
generic_args.push(Some(Interned::new(last)));
669673

crates/hir-def/src/path.rs

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ pub struct GenericArgs {
6868
pub struct AssociatedTypeBinding {
6969
/// The name of the associated type.
7070
pub name: Name,
71+
/// The generic arguments to the associated type. e.g. For `Trait<Assoc<'a, T> = &'a T>`, this
72+
/// would be `['a, T]`.
73+
pub args: Option<Interned<GenericArgs>>,
7174
/// The type bound to this associated type (in `Item = T`, this would be the
7275
/// `T`). This can be `None` if there are bounds instead.
7376
pub type_ref: Option<TypeRef>,

crates/hir-def/src/path/lower.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ pub(super) fn lower_generic_args(
163163
ast::GenericArg::AssocTypeArg(assoc_type_arg) => {
164164
if let Some(name_ref) = assoc_type_arg.name_ref() {
165165
let name = name_ref.as_name();
166+
let args = assoc_type_arg
167+
.generic_arg_list()
168+
.and_then(|args| lower_generic_args(lower_ctx, args))
169+
.map(Interned::new);
166170
let type_ref = assoc_type_arg.ty().map(|it| TypeRef::from_ast(lower_ctx, it));
167171
let bounds = if let Some(l) = assoc_type_arg.type_bound_list() {
168172
l.bounds()
@@ -171,7 +175,7 @@ pub(super) fn lower_generic_args(
171175
} else {
172176
Vec::new()
173177
};
174-
bindings.push(AssociatedTypeBinding { name, type_ref, bounds });
178+
bindings.push(AssociatedTypeBinding { name, args, type_ref, bounds });
175179
}
176180
}
177181
ast::GenericArg::LifetimeArg(lifetime_arg) => {
@@ -214,6 +218,7 @@ fn lower_generic_args_from_fn_path(
214218
let type_ref = TypeRef::from_ast_opt(ctx, ret_type.ty());
215219
bindings.push(AssociatedTypeBinding {
216220
name: name![Output],
221+
args: None,
217222
type_ref: Some(type_ref),
218223
bounds: Vec::new(),
219224
});
@@ -222,6 +227,7 @@ fn lower_generic_args_from_fn_path(
222227
let type_ref = TypeRef::Tuple(Vec::new());
223228
bindings.push(AssociatedTypeBinding {
224229
name: name![Output],
230+
args: None,
225231
type_ref: Some(type_ref),
226232
bounds: Vec::new(),
227233
});

crates/hir-ty/src/chalk_ext.rs

+10-7
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ use syntax::SmolStr;
1111

1212
use crate::{
1313
db::HirDatabase, from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id,
14-
from_placeholder_idx, to_chalk_trait_id, AdtId, AliasEq, AliasTy, Binders, CallableDefId,
15-
CallableSig, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy, QuantifiedWhereClause,
16-
Substitution, TraitRef, Ty, TyBuilder, TyKind, WhereClause,
14+
from_placeholder_idx, to_chalk_trait_id, utils::generics, AdtId, AliasEq, AliasTy, Binders,
15+
CallableDefId, CallableSig, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy,
16+
QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, WhereClause,
1717
};
1818

1919
pub trait TyExt {
@@ -338,10 +338,13 @@ pub trait ProjectionTyExt {
338338

339339
impl ProjectionTyExt for ProjectionTy {
340340
fn trait_ref(&self, db: &dyn HirDatabase) -> TraitRef {
341-
TraitRef {
342-
trait_id: to_chalk_trait_id(self.trait_(db)),
343-
substitution: self.substitution.clone(),
344-
}
341+
// FIXME: something like `Split` trait from chalk-solve might be nice.
342+
let generics = generics(db.upcast(), from_assoc_type_id(self.associated_ty_id).into());
343+
let substitution = Substitution::from_iter(
344+
Interner,
345+
self.substitution.iter(Interner).skip(generics.len_self()),
346+
);
347+
TraitRef { trait_id: to_chalk_trait_id(self.trait_(db)), substitution }
345348
}
346349

347350
fn trait_(&self, db: &dyn HirDatabase) -> TraitId {

crates/hir-ty/src/display.rs

+26-10
Original file line numberDiff line numberDiff line change
@@ -289,16 +289,18 @@ impl HirDisplay for ProjectionTy {
289289
return write!(f, "{}", TYPE_HINT_TRUNCATION);
290290
}
291291

292-
let trait_ = f.db.trait_data(self.trait_(f.db));
292+
let trait_ref = self.trait_ref(f.db);
293293
write!(f, "<")?;
294-
self.self_type_parameter(f.db).hir_fmt(f)?;
295-
write!(f, " as {}", trait_.name)?;
296-
if self.substitution.len(Interner) > 1 {
294+
fmt_trait_ref(&trait_ref, f, true)?;
295+
write!(f, ">::{}", f.db.type_alias_data(from_assoc_type_id(self.associated_ty_id)).name)?;
296+
let proj_params_count =
297+
self.substitution.len(Interner) - trait_ref.substitution.len(Interner);
298+
let proj_params = &self.substitution.as_slice(Interner)[..proj_params_count];
299+
if !proj_params.is_empty() {
297300
write!(f, "<")?;
298-
f.write_joined(&self.substitution.as_slice(Interner)[1..], ", ")?;
301+
f.write_joined(proj_params, ", ")?;
299302
write!(f, ">")?;
300303
}
301-
write!(f, ">::{}", f.db.type_alias_data(from_assoc_type_id(self.associated_ty_id)).name)?;
302304
Ok(())
303305
}
304306
}
@@ -641,9 +643,12 @@ impl HirDisplay for Ty {
641643
// Use placeholder associated types when the target is test (https://rust-lang.github.io/chalk/book/clauses/type_equality.html#placeholder-associated-types)
642644
if f.display_target.is_test() {
643645
write!(f, "{}::{}", trait_.name, type_alias_data.name)?;
646+
// Note that the generic args for the associated type come before those for the
647+
// trait (including the self type).
648+
// FIXME: reconsider the generic args order upon formatting?
644649
if parameters.len(Interner) > 0 {
645650
write!(f, "<")?;
646-
f.write_joined(&*parameters.as_slice(Interner), ", ")?;
651+
f.write_joined(parameters.as_slice(Interner), ", ")?;
647652
write!(f, ">")?;
648653
}
649654
} else {
@@ -972,9 +977,20 @@ fn write_bounds_like_dyn_trait(
972977
angle_open = true;
973978
}
974979
if let AliasTy::Projection(proj) = alias {
975-
let type_alias =
976-
f.db.type_alias_data(from_assoc_type_id(proj.associated_ty_id));
977-
write!(f, "{} = ", type_alias.name)?;
980+
let assoc_ty_id = from_assoc_type_id(proj.associated_ty_id);
981+
let type_alias = f.db.type_alias_data(assoc_ty_id);
982+
write!(f, "{}", type_alias.name)?;
983+
984+
let proj_arg_count = generics(f.db.upcast(), assoc_ty_id.into()).len_self();
985+
if proj_arg_count > 0 {
986+
write!(f, "<")?;
987+
f.write_joined(
988+
&proj.substitution.as_slice(Interner)[..proj_arg_count],
989+
", ",
990+
)?;
991+
write!(f, ">")?;
992+
}
993+
write!(f, " = ")?;
978994
}
979995
ty.hir_fmt(f)?;
980996
}

crates/hir-ty/src/infer/path.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ impl<'a> InferenceContext<'a> {
157157
remaining_segments_for_ty,
158158
true,
159159
);
160-
if let TyKind::Error = ty.kind(Interner) {
160+
if ty.is_unknown() {
161161
return None;
162162
}
163163

crates/hir-ty/src/lib.rs

-9
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,6 @@ pub type ConstrainedSubst = chalk_ir::ConstrainedSubst<Interner>;
124124
pub type Guidance = chalk_solve::Guidance<Interner>;
125125
pub type WhereClause = chalk_ir::WhereClause<Interner>;
126126

127-
// FIXME: get rid of this
128-
pub fn subst_prefix(s: &Substitution, n: usize) -> Substitution {
129-
Substitution::from_iter(
130-
Interner,
131-
s.as_slice(Interner)[..std::cmp::min(s.len(Interner), n)].iter().cloned(),
132-
)
133-
}
134-
135127
/// Return an index of a parameter in the generic type parameter list by it's id.
136128
pub fn param_idx(db: &dyn HirDatabase, id: TypeOrConstParamId) -> Option<usize> {
137129
generics(db.upcast(), id.parent).param_idx(id)
@@ -382,7 +374,6 @@ pub(crate) fn fold_tys_and_consts<T: HasInterner<Interner = Interner> + TypeFold
382374
pub fn replace_errors_with_variables<T>(t: &T) -> Canonical<T>
383375
where
384376
T: HasInterner<Interner = Interner> + TypeFoldable<Interner> + Clone,
385-
T: HasInterner<Interner = Interner>,
386377
{
387378
use chalk_ir::{
388379
fold::{FallibleTypeFolder, TypeSuperFoldable},

crates/hir-ty/src/lower.rs

+91-33
Original file line numberDiff line numberDiff line change
@@ -447,12 +447,31 @@ impl<'a> TyLoweringContext<'a> {
447447
.db
448448
.trait_data(trait_ref.hir_trait_id())
449449
.associated_type_by_name(segment.name);
450+
450451
match found {
451452
Some(associated_ty) => {
452-
// FIXME handle type parameters on the segment
453+
// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
454+
// generic params. It's inefficient to splice the `Substitution`s, so we may want
455+
// that method to optionally take parent `Substitution` as we already know them at
456+
// this point (`trait_ref.substitution`).
457+
let substitution = self.substs_from_path_segment(
458+
segment,
459+
Some(associated_ty.into()),
460+
false,
461+
None,
462+
);
463+
let len_self =
464+
generics(self.db.upcast(), associated_ty.into()).len_self();
465+
let substitution = Substitution::from_iter(
466+
Interner,
467+
substitution
468+
.iter(Interner)
469+
.take(len_self)
470+
.chain(trait_ref.substitution.iter(Interner)),
471+
);
453472
TyKind::Alias(AliasTy::Projection(ProjectionTy {
454473
associated_ty_id: to_assoc_type_id(associated_ty),
455-
substitution: trait_ref.substitution,
474+
substitution,
456475
}))
457476
.intern(Interner)
458477
}
@@ -590,36 +609,48 @@ impl<'a> TyLoweringContext<'a> {
590609
res,
591610
Some(segment.name.clone()),
592611
move |name, t, associated_ty| {
593-
if name == segment.name {
594-
let substs = match self.type_param_mode {
595-
ParamLoweringMode::Placeholder => {
596-
// if we're lowering to placeholders, we have to put
597-
// them in now
598-
let generics = generics(
599-
self.db.upcast(),
600-
self.resolver
601-
.generic_def()
602-
.expect("there should be generics if there's a generic param"),
603-
);
604-
let s = generics.placeholder_subst(self.db);
605-
s.apply(t.substitution.clone(), Interner)
606-
}
607-
ParamLoweringMode::Variable => t.substitution.clone(),
608-
};
609-
// We need to shift in the bound vars, since
610-
// associated_type_shorthand_candidates does not do that
611-
let substs = substs.shifted_in_from(Interner, self.in_binders);
612-
// FIXME handle type parameters on the segment
613-
Some(
614-
TyKind::Alias(AliasTy::Projection(ProjectionTy {
615-
associated_ty_id: to_assoc_type_id(associated_ty),
616-
substitution: substs,
617-
}))
618-
.intern(Interner),
619-
)
620-
} else {
621-
None
612+
if name != segment.name {
613+
return None;
622614
}
615+
616+
// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
617+
// generic params. It's inefficient to splice the `Substitution`s, so we may want
618+
// that method to optionally take parent `Substitution` as we already know them at
619+
// this point (`t.substitution`).
620+
let substs = self.substs_from_path_segment(
621+
segment.clone(),
622+
Some(associated_ty.into()),
623+
false,
624+
None,
625+
);
626+
627+
let len_self = generics(self.db.upcast(), associated_ty.into()).len_self();
628+
629+
let substs = Substitution::from_iter(
630+
Interner,
631+
substs.iter(Interner).take(len_self).chain(t.substitution.iter(Interner)),
632+
);
633+
634+
let substs = match self.type_param_mode {
635+
ParamLoweringMode::Placeholder => {
636+
// if we're lowering to placeholders, we have to put
637+
// them in now
638+
let generics = generics(self.db.upcast(), def);
639+
let s = generics.placeholder_subst(self.db);
640+
s.apply(substs, Interner)
641+
}
642+
ParamLoweringMode::Variable => substs,
643+
};
644+
// We need to shift in the bound vars, since
645+
// associated_type_shorthand_candidates does not do that
646+
let substs = substs.shifted_in_from(Interner, self.in_binders);
647+
Some(
648+
TyKind::Alias(AliasTy::Projection(ProjectionTy {
649+
associated_ty_id: to_assoc_type_id(associated_ty),
650+
substitution: substs,
651+
}))
652+
.intern(Interner),
653+
)
623654
},
624655
);
625656

@@ -777,7 +808,15 @@ impl<'a> TyLoweringContext<'a> {
777808
// handle defaults. In expression or pattern path segments without
778809
// explicitly specified type arguments, missing type arguments are inferred
779810
// (i.e. defaults aren't used).
780-
if !infer_args || had_explicit_args {
811+
// Generic parameters for associated types are not supposed to have defaults, so we just
812+
// ignore them.
813+
let is_assoc_ty = if let GenericDefId::TypeAliasId(id) = def {
814+
let container = id.lookup(self.db.upcast()).container;
815+
matches!(container, ItemContainerId::TraitId(_))
816+
} else {
817+
false
818+
};
819+
if !is_assoc_ty && (!infer_args || had_explicit_args) {
781820
let defaults = self.db.generic_defaults(def);
782821
assert_eq!(total_len, defaults.len());
783822
let parent_from = item_len - substs.len();
@@ -966,9 +1005,28 @@ impl<'a> TyLoweringContext<'a> {
9661005
None => return SmallVec::new(),
9671006
Some(t) => t,
9681007
};
1008+
// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
1009+
// generic params. It's inefficient to splice the `Substitution`s, so we may want
1010+
// that method to optionally take parent `Substitution` as we already know them at
1011+
// this point (`super_trait_ref.substitution`).
1012+
let substitution = self.substs_from_path_segment(
1013+
// FIXME: This is hack. We shouldn't really build `PathSegment` directly.
1014+
PathSegment { name: &binding.name, args_and_bindings: binding.args.as_deref() },
1015+
Some(associated_ty.into()),
1016+
false, // this is not relevant
1017+
Some(super_trait_ref.self_type_parameter(Interner)),
1018+
);
1019+
let self_params = generics(self.db.upcast(), associated_ty.into()).len_self();
1020+
let substitution = Substitution::from_iter(
1021+
Interner,
1022+
substitution
1023+
.iter(Interner)
1024+
.take(self_params)
1025+
.chain(super_trait_ref.substitution.iter(Interner)),
1026+
);
9691027
let projection_ty = ProjectionTy {
9701028
associated_ty_id: to_assoc_type_id(associated_ty),
971-
substitution: super_trait_ref.substitution,
1029+
substitution,
9721030
};
9731031
let mut preds: SmallVec<[_; 1]> = SmallVec::with_capacity(
9741032
binding.type_ref.as_ref().map_or(0, |_| 1) + binding.bounds.len(),

crates/hir-ty/src/tests/display_source_code.rs

+31
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,34 @@ fn test(
196196
"#,
197197
);
198198
}
199+
200+
#[test]
201+
fn projection_type_correct_arguments_order() {
202+
check_types_source_code(
203+
r#"
204+
trait Foo<T> {
205+
type Assoc<U>;
206+
}
207+
fn f<T: Foo<i32>>(a: T::Assoc<usize>) {
208+
a;
209+
//^ <T as Foo<i32>>::Assoc<usize>
210+
}
211+
"#,
212+
);
213+
}
214+
215+
#[test]
216+
fn generic_associated_type_binding_in_impl_trait() {
217+
check_types_source_code(
218+
r#"
219+
//- minicore: sized
220+
trait Foo<T> {
221+
type Assoc<U>;
222+
}
223+
fn f(a: impl Foo<i8, Assoc<i16> = i32>) {
224+
a;
225+
//^ impl Foo<i8, Assoc<i16> = i32>
226+
}
227+
"#,
228+
);
229+
}

0 commit comments

Comments
 (0)