Skip to content

Commit d3ebb14

Browse files
authored
Merge pull request rust-lang#18774 from Veykril/push-ysppqxpuknnw
Implement parameter variance inference
2 parents 0337e79 + a102ea1 commit d3ebb14

File tree

23 files changed

+1335
-129
lines changed

23 files changed

+1335
-129
lines changed

src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -950,22 +950,33 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar
950950

951951
pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances {
952952
let callable_def: CallableDefId = from_chalk(db, fn_def_id);
953-
let generic_params =
954-
generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def));
955953
Variances::from_iter(
956954
Interner,
957-
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
955+
db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def))
956+
.as_deref()
957+
.unwrap_or_default()
958+
.iter()
959+
.map(|v| match v {
960+
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
961+
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
962+
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
963+
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
964+
}),
958965
)
959966
}
960967

961968
pub(crate) fn adt_variance_query(
962969
db: &dyn HirDatabase,
963970
chalk_ir::AdtId(adt_id): AdtId,
964971
) -> Variances {
965-
let generic_params = generics(db.upcast(), adt_id.into());
966972
Variances::from_iter(
967973
Interner,
968-
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
974+
db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v {
975+
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
976+
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
977+
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
978+
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
979+
}),
969980
)
970981
}
971982

src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,25 @@ impl ProjectionTyExt for ProjectionTy {
443443
}
444444

445445
pub trait DynTyExt {
446-
fn principal(&self) -> Option<&TraitRef>;
446+
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>>;
447+
fn principal_id(&self) -> Option<chalk_ir::TraitId<Interner>>;
447448
}
448449

449450
impl DynTyExt for DynTy {
450-
fn principal(&self) -> Option<&TraitRef> {
451+
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>> {
452+
self.bounds.as_ref().filter_map(|bounds| {
453+
bounds.interned().first().and_then(|b| {
454+
b.as_ref().filter_map(|b| match b {
455+
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref),
456+
_ => None,
457+
})
458+
})
459+
})
460+
}
461+
462+
fn principal_id(&self) -> Option<chalk_ir::TraitId<Interner>> {
451463
self.bounds.skip_binders().interned().first().and_then(|b| match b.skip_binders() {
452-
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref),
464+
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref.trait_id),
453465
_ => None,
454466
})
455467
}

src/tools/rust-analyzer/crates/hir-ty/src/db.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
271271
#[ra_salsa::invoke(chalk_db::adt_variance_query)]
272272
fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances;
273273

274+
#[ra_salsa::invoke(crate::variance::variances_of)]
275+
#[ra_salsa::cycle(crate::variance::variances_of_cycle)]
276+
fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>;
277+
274278
#[ra_salsa::invoke(chalk_db::associated_ty_value_query)]
275279
fn associated_ty_value(
276280
&self,

src/tools/rust-analyzer/crates/hir-ty/src/generics.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ use triomphe::Arc;
2626

2727
use crate::{db::HirDatabase, lt_to_placeholder_idx, to_placeholder_idx, Interner, Substitution};
2828

29-
pub(crate) fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
29+
pub fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
3030
let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def)));
3131
let params = db.generic_params(def);
3232
let has_trait_self_param = params.trait_self_param().is_some();
3333
Generics { def, params, parent_generics, has_trait_self_param }
3434
}
3535
#[derive(Clone, Debug)]
36-
pub(crate) struct Generics {
36+
pub struct Generics {
3737
def: GenericDefId,
3838
params: Arc<GenericParams>,
3939
parent_generics: Option<Box<Generics>>,
@@ -153,7 +153,7 @@ impl Generics {
153153
(parent_len, self_param, type_params, const_params, impl_trait_params, lifetime_params)
154154
}
155155

156-
pub(crate) fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option<usize> {
156+
pub fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option<usize> {
157157
self.find_type_or_const_param(param)
158158
}
159159

@@ -174,7 +174,7 @@ impl Generics {
174174
}
175175
}
176176

177-
pub(crate) fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option<usize> {
177+
pub fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option<usize> {
178178
self.find_lifetime(lifetime)
179179
}
180180

src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ impl InferenceContext<'_> {
9696
.map(|b| b.into_value_and_skipped_binders().0);
9797
self.deduce_closure_kind_from_predicate_clauses(clauses)
9898
}
99-
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
100-
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
99+
TyKind::Dyn(dyn_ty) => dyn_ty.principal_id().and_then(|trait_id| {
100+
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_id))
101101
}),
102102
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
103103
let clauses = self.clauses_for_self_ty(*ty);

src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub fn is_box(db: &dyn HirDatabase, adt: AdtId) -> bool {
1313

1414
pub fn is_unsafe_cell(db: &dyn HirDatabase, adt: AdtId) -> bool {
1515
let AdtId::StructId(id) = adt else { return false };
16+
1617
db.struct_data(id).flags.contains(StructFlags::IS_UNSAFE_CELL)
1718
}
1819

src/tools/rust-analyzer/crates/hir-ty/src/lib.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ extern crate ra_ap_rustc_pattern_analysis as rustc_pattern_analysis;
2424
mod builder;
2525
mod chalk_db;
2626
mod chalk_ext;
27-
mod generics;
2827
mod infer;
2928
mod inhabitedness;
3029
mod interner;
@@ -39,6 +38,7 @@ pub mod db;
3938
pub mod diagnostics;
4039
pub mod display;
4140
pub mod dyn_compatibility;
41+
pub mod generics;
4242
pub mod lang_items;
4343
pub mod layout;
4444
pub mod method_resolution;
@@ -50,6 +50,7 @@ pub mod traits;
5050
mod test_db;
5151
#[cfg(test)]
5252
mod tests;
53+
mod variance;
5354

5455
use std::hash::Hash;
5556

@@ -88,10 +89,9 @@ pub use infer::{
8889
PointerCast,
8990
};
9091
pub use interner::Interner;
91-
pub use lower::diagnostics::*;
9292
pub use lower::{
93-
associated_type_shorthand_candidates, ImplTraitLoweringMode, ParamLoweringMode, TyDefId,
94-
TyLoweringContext, ValueTyDefId,
93+
associated_type_shorthand_candidates, diagnostics::*, ImplTraitLoweringMode, ParamLoweringMode,
94+
TyDefId, TyLoweringContext, ValueTyDefId,
9595
};
9696
pub use mapping::{
9797
from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, from_placeholder_idx,
@@ -101,6 +101,7 @@ pub use mapping::{
101101
pub use method_resolution::check_orphan_rules;
102102
pub use traits::TraitEnvironment;
103103
pub use utils::{all_super_traits, direct_super_traits, is_fn_unsafe_to_call};
104+
pub use variance::Variance;
104105

105106
pub use chalk_ir::{
106107
cast::Cast,

src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,8 @@ fn is_inherent_impl_coherent(
805805
| TyKind::Scalar(_) => def_map.is_rustc_coherence_is_core(),
806806

807807
&TyKind::Adt(AdtId(adt), _) => adt.module(db.upcast()).krate() == def_map.krate(),
808-
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
809-
from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate() == def_map.krate()
808+
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
809+
from_chalk_trait_id(trait_id).module(db.upcast()).krate() == def_map.krate()
810810
}),
811811

812812
_ => true,
@@ -834,9 +834,8 @@ fn is_inherent_impl_coherent(
834834
.contains(StructFlags::IS_RUSTC_HAS_INCOHERENT_INHERENT_IMPL),
835835
hir_def::AdtId::EnumId(it) => db.enum_data(it).rustc_has_incoherent_inherent_impls,
836836
},
837-
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
838-
db.trait_data(from_chalk_trait_id(trait_ref.trait_id))
839-
.rustc_has_incoherent_inherent_impls
837+
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
838+
db.trait_data(from_chalk_trait_id(trait_id)).rustc_has_incoherent_inherent_impls
840839
}),
841840

842841
_ => false,
@@ -896,8 +895,8 @@ pub fn check_orphan_rules(db: &dyn HirDatabase, impl_: ImplId) -> bool {
896895
match unwrap_fundamental(ty).kind(Interner) {
897896
&TyKind::Adt(AdtId(id), _) => is_local(id.module(db.upcast()).krate()),
898897
TyKind::Error => true,
899-
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
900-
is_local(from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate())
898+
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
899+
is_local(from_chalk_trait_id(trait_id).module(db.upcast()).krate())
901900
}),
902901
_ => false,
903902
}

src/tools/rust-analyzer/crates/hir-ty/src/tests.rs

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
127127
None => continue,
128128
};
129129
let def_map = module.def_map(&db);
130-
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
130+
visit_module(&db, &def_map, module.local_id, &mut |it| {
131+
defs.push(match it {
132+
ModuleDefId::FunctionId(it) => it.into(),
133+
ModuleDefId::EnumVariantId(it) => it.into(),
134+
ModuleDefId::ConstId(it) => it.into(),
135+
ModuleDefId::StaticId(it) => it.into(),
136+
_ => return,
137+
})
138+
});
131139
}
132140
defs.sort_by_key(|def| match def {
133141
DefWithBodyId::FunctionId(it) => {
@@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
375383
let def_map = module.def_map(&db);
376384

377385
let mut defs: Vec<DefWithBodyId> = Vec::new();
378-
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
386+
visit_module(&db, &def_map, module.local_id, &mut |it| {
387+
defs.push(match it {
388+
ModuleDefId::FunctionId(it) => it.into(),
389+
ModuleDefId::EnumVariantId(it) => it.into(),
390+
ModuleDefId::ConstId(it) => it.into(),
391+
ModuleDefId::StaticId(it) => it.into(),
392+
_ => return,
393+
})
394+
});
379395
defs.sort_by_key(|def| match def {
380396
DefWithBodyId::FunctionId(it) => {
381397
let loc = it.lookup(&db);
@@ -405,30 +421,30 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
405421
buf
406422
}
407423

408-
fn visit_module(
424+
pub(crate) fn visit_module(
409425
db: &TestDB,
410426
crate_def_map: &DefMap,
411427
module_id: LocalModuleId,
412-
cb: &mut dyn FnMut(DefWithBodyId),
428+
cb: &mut dyn FnMut(ModuleDefId),
413429
) {
414430
visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
415431
for impl_id in crate_def_map[module_id].scope.impls() {
416432
let impl_data = db.impl_data(impl_id);
417433
for &item in impl_data.items.iter() {
418434
match item {
419435
AssocItemId::FunctionId(it) => {
420-
let def = it.into();
421-
cb(def);
422-
let body = db.body(def);
436+
let body = db.body(it.into());
437+
cb(it.into());
423438
visit_body(db, &body, cb);
424439
}
425440
AssocItemId::ConstId(it) => {
426-
let def = it.into();
427-
cb(def);
428-
let body = db.body(def);
441+
let body = db.body(it.into());
442+
cb(it.into());
429443
visit_body(db, &body, cb);
430444
}
431-
AssocItemId::TypeAliasId(_) => (),
445+
AssocItemId::TypeAliasId(it) => {
446+
cb(it.into());
447+
}
432448
}
433449
}
434450
}
@@ -437,33 +453,27 @@ fn visit_module(
437453
db: &TestDB,
438454
crate_def_map: &DefMap,
439455
scope: &ItemScope,
440-
cb: &mut dyn FnMut(DefWithBodyId),
456+
cb: &mut dyn FnMut(ModuleDefId),
441457
) {
442458
for decl in scope.declarations() {
459+
cb(decl);
443460
match decl {
444461
ModuleDefId::FunctionId(it) => {
445-
let def = it.into();
446-
cb(def);
447-
let body = db.body(def);
462+
let body = db.body(it.into());
448463
visit_body(db, &body, cb);
449464
}
450465
ModuleDefId::ConstId(it) => {
451-
let def = it.into();
452-
cb(def);
453-
let body = db.body(def);
466+
let body = db.body(it.into());
454467
visit_body(db, &body, cb);
455468
}
456469
ModuleDefId::StaticId(it) => {
457-
let def = it.into();
458-
cb(def);
459-
let body = db.body(def);
470+
let body = db.body(it.into());
460471
visit_body(db, &body, cb);
461472
}
462473
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
463474
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
464-
let def = it.into();
465-
cb(def);
466-
let body = db.body(def);
475+
let body = db.body(it.into());
476+
cb(it.into());
467477
visit_body(db, &body, cb);
468478
});
469479
}
@@ -473,7 +483,7 @@ fn visit_module(
473483
match item {
474484
AssocItemId::FunctionId(it) => cb(it.into()),
475485
AssocItemId::ConstId(it) => cb(it.into()),
476-
AssocItemId::TypeAliasId(_) => (),
486+
AssocItemId::TypeAliasId(it) => cb(it.into()),
477487
}
478488
}
479489
}
@@ -483,7 +493,7 @@ fn visit_module(
483493
}
484494
}
485495

486-
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
496+
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) {
487497
for (_, def_map) in body.blocks(db) {
488498
for (mod_id, _) in def_map.modules() {
489499
visit_module(db, &def_map, mod_id, cb);
@@ -553,7 +563,13 @@ fn salsa_bug() {
553563
let module = db.module_for_file(pos.file_id);
554564
let crate_def_map = module.def_map(&db);
555565
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
556-
db.infer(def);
566+
db.infer(match def {
567+
ModuleDefId::FunctionId(it) => it.into(),
568+
ModuleDefId::EnumVariantId(it) => it.into(),
569+
ModuleDefId::ConstId(it) => it.into(),
570+
ModuleDefId::StaticId(it) => it.into(),
571+
_ => return,
572+
});
557573
});
558574

559575
let new_text = "
@@ -586,6 +602,12 @@ fn salsa_bug() {
586602
let module = db.module_for_file(pos.file_id);
587603
let crate_def_map = module.def_map(&db);
588604
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
589-
db.infer(def);
605+
db.infer(match def {
606+
ModuleDefId::FunctionId(it) => it.into(),
607+
ModuleDefId::EnumVariantId(it) => it.into(),
608+
ModuleDefId::ConstId(it) => it.into(),
609+
ModuleDefId::StaticId(it) => it.into(),
610+
_ => return,
611+
});
590612
});
591613
}

src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) {
2424

2525
let mut captures_info = Vec::new();
2626
for def in defs {
27+
let def = match def {
28+
hir_def::ModuleDefId::FunctionId(it) => it.into(),
29+
hir_def::ModuleDefId::EnumVariantId(it) => it.into(),
30+
hir_def::ModuleDefId::ConstId(it) => it.into(),
31+
hir_def::ModuleDefId::StaticId(it) => it.into(),
32+
_ => continue,
33+
};
2734
let infer = db.infer(def);
2835
let db = &db;
2936
captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| {

0 commit comments

Comments
 (0)