Skip to content

Commit feefbe7

Browse files
committed
Auto merge of #13475 - lowr:fix/lookup-impl-method-trait-ref, r=flodiebold
fix: Test all generic args for trait when finding matching impl Addresses rust-lang/rust-analyzer#13463 (comment) When finding matching impl for a trait method, we've been testing the unifiability of self type. However, there can be multiple impl of a trait for the same type with different generic arguments for the trait. This patch takes it into account and tests the unifiability of all type arguments for the trait (the first being the self type) thus enables rust-analyzer to find the correct impl even in such cases.
2 parents 53b6d69 + 67f1d8f commit feefbe7

File tree

4 files changed

+170
-73
lines changed

4 files changed

+170
-73
lines changed

crates/hir-ty/src/infer/unify.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,8 @@ impl<'a> InferenceTable<'a> {
340340
self.resolve_with_fallback(t, &|_, _, d, _| d)
341341
}
342342

343-
/// Unify two types and register new trait goals that arise from that.
344-
pub(crate) fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
343+
/// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that.
344+
pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool {
345345
let result = match self.try_unify(ty1, ty2) {
346346
Ok(r) => r,
347347
Err(_) => return false,
@@ -350,9 +350,13 @@ impl<'a> InferenceTable<'a> {
350350
true
351351
}
352352

353-
/// Unify two types and return new trait goals arising from it, so the
353+
/// Unify two relatable values (e.g. `Ty`) and return new trait goals arising from it, so the
354354
/// caller needs to deal with them.
355-
pub(crate) fn try_unify<T: Zip<Interner>>(&mut self, t1: &T, t2: &T) -> InferResult<()> {
355+
pub(crate) fn try_unify<T: ?Sized + Zip<Interner>>(
356+
&mut self,
357+
t1: &T,
358+
t2: &T,
359+
) -> InferResult<()> {
356360
match self.var_unification_table.relate(
357361
Interner,
358362
&self.db,

crates/hir-ty/src/method_resolution.rs

+54-30
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ use crate::{
2222
from_foreign_def_id,
2323
infer::{unify::InferenceTable, Adjust, Adjustment, AutoBorrow, OverloadedDeref, PointerCast},
2424
primitive::{FloatTy, IntTy, UintTy},
25-
static_lifetime,
25+
static_lifetime, to_chalk_trait_id,
2626
utils::all_super_traits,
2727
AdtId, Canonical, CanonicalVarKinds, DebruijnIndex, ForeignDefId, InEnvironment, Interner,
28-
Scalar, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
28+
Scalar, Substitution, TraitEnvironment, TraitRef, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
2929
};
3030

3131
/// This is used as a key for indexing impls.
@@ -624,52 +624,76 @@ pub(crate) fn iterate_method_candidates<T>(
624624
slot
625625
}
626626

627+
/// Looks up the impl method that actually runs for the trait method `func`.
628+
///
629+
/// Returns `func` if it's not a method defined in a trait or the lookup failed.
627630
pub fn lookup_impl_method(
628-
self_ty: &Ty,
629631
db: &dyn HirDatabase,
630632
env: Arc<TraitEnvironment>,
631-
trait_: TraitId,
633+
func: FunctionId,
634+
fn_subst: Substitution,
635+
) -> FunctionId {
636+
let trait_id = match func.lookup(db.upcast()).container {
637+
ItemContainerId::TraitId(id) => id,
638+
_ => return func,
639+
};
640+
let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
641+
let fn_params = fn_subst.len(Interner) - trait_params;
642+
let trait_ref = TraitRef {
643+
trait_id: to_chalk_trait_id(trait_id),
644+
substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)),
645+
};
646+
647+
let name = &db.function_data(func).name;
648+
lookup_impl_method_for_trait_ref(trait_ref, db, env, name).unwrap_or(func)
649+
}
650+
651+
fn lookup_impl_method_for_trait_ref(
652+
trait_ref: TraitRef,
653+
db: &dyn HirDatabase,
654+
env: Arc<TraitEnvironment>,
632655
name: &Name,
633656
) -> Option<FunctionId> {
634-
let self_ty_fp = TyFingerprint::for_trait_impl(self_ty)?;
635-
let trait_impls = db.trait_impls_in_deps(env.krate);
636-
let impls = trait_impls.for_trait_and_self_ty(trait_, self_ty_fp);
637-
let mut table = InferenceTable::new(db, env.clone());
638-
find_matching_impl(impls, &mut table, &self_ty).and_then(|data| {
639-
data.items.iter().find_map(|it| match it {
640-
AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
641-
_ => None,
642-
})
657+
let self_ty = trait_ref.self_type_parameter(Interner);
658+
let self_ty_fp = TyFingerprint::for_trait_impl(&self_ty)?;
659+
let impls = db.trait_impls_in_deps(env.krate);
660+
let impls = impls.for_trait_and_self_ty(trait_ref.hir_trait_id(), self_ty_fp);
661+
662+
let table = InferenceTable::new(db, env);
663+
664+
let impl_data = find_matching_impl(impls, table, trait_ref)?;
665+
impl_data.items.iter().find_map(|it| match it {
666+
AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
667+
_ => None,
643668
})
644669
}
645670

646671
fn find_matching_impl(
647672
mut impls: impl Iterator<Item = ImplId>,
648-
table: &mut InferenceTable<'_>,
649-
self_ty: &Ty,
673+
mut table: InferenceTable<'_>,
674+
actual_trait_ref: TraitRef,
650675
) -> Option<Arc<ImplData>> {
651676
let db = table.db;
652677
loop {
653678
let impl_ = impls.next()?;
654679
let r = table.run_in_snapshot(|table| {
655680
let impl_data = db.impl_data(impl_);
656-
let substs =
681+
let impl_substs =
657682
TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build();
658-
let impl_ty = db.impl_self_ty(impl_).substitute(Interner, &substs);
659-
660-
table
661-
.unify(self_ty, &impl_ty)
662-
.then(|| {
663-
let wh_goals =
664-
crate::chalk_db::convert_where_clauses(db, impl_.into(), &substs)
665-
.into_iter()
666-
.map(|b| b.cast(Interner));
683+
let trait_ref = db
684+
.impl_trait(impl_)
685+
.expect("non-trait method in find_matching_impl")
686+
.substitute(Interner, &impl_substs);
667687

668-
let goal = crate::Goal::all(Interner, wh_goals);
688+
if !table.unify(&trait_ref, &actual_trait_ref) {
689+
return None;
690+
}
669691

670-
table.try_obligation(goal).map(|_| impl_data)
671-
})
672-
.flatten()
692+
let wcs = crate::chalk_db::convert_where_clauses(db, impl_.into(), &impl_substs)
693+
.into_iter()
694+
.map(|b| b.cast(Interner));
695+
let goal = crate::Goal::all(Interner, wcs);
696+
table.try_obligation(goal).map(|_| impl_data)
673697
});
674698
if r.is_some() {
675699
break r;
@@ -1214,7 +1238,7 @@ fn is_valid_fn_candidate(
12141238
let expected_receiver =
12151239
sig.map(|s| s.params()[0].clone()).substitute(Interner, &fn_subst);
12161240

1217-
check_that!(table.unify(&receiver_ty, &expected_receiver));
1241+
check_that!(table.unify(receiver_ty, &expected_receiver));
12181242
}
12191243

12201244
if let ItemContainerId::ImplId(impl_id) = container {

crates/hir/src/source_analyzer.rs

+26-39
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ impl SourceAnalyzer {
270270
let expr_id = self.expr_id(db, &call.clone().into())?;
271271
let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?;
272272

273-
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs))
273+
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs))
274274
}
275275

276276
pub(crate) fn resolve_await_to_poll(
@@ -311,7 +311,7 @@ impl SourceAnalyzer {
311311
// HACK: subst for `poll()` coincides with that for `Future` because `poll()` itself
312312
// doesn't have any generic parameters, so we skip building another subst for `poll()`.
313313
let substs = hir_ty::TyBuilder::subst_for_def(db, future_trait, None).push(ty).build();
314-
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, &substs))
314+
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, substs))
315315
}
316316

317317
pub(crate) fn resolve_prefix_expr(
@@ -331,7 +331,7 @@ impl SourceAnalyzer {
331331
// don't have any generic parameters, so we skip building another subst for the methods.
332332
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
333333

334-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
334+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
335335
}
336336

337337
pub(crate) fn resolve_index_expr(
@@ -351,7 +351,7 @@ impl SourceAnalyzer {
351351
.push(base_ty.clone())
352352
.push(index_ty.clone())
353353
.build();
354-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
354+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
355355
}
356356

357357
pub(crate) fn resolve_bin_expr(
@@ -372,7 +372,7 @@ impl SourceAnalyzer {
372372
.push(rhs.clone())
373373
.build();
374374

375-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
375+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
376376
}
377377

378378
pub(crate) fn resolve_try_expr(
@@ -392,7 +392,7 @@ impl SourceAnalyzer {
392392
// doesn't have any generic parameters, so we skip building another subst for `branch()`.
393393
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
394394

395-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
395+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
396396
}
397397

398398
pub(crate) fn resolve_field(
@@ -487,19 +487,22 @@ impl SourceAnalyzer {
487487

488488
let mut prefer_value_ns = false;
489489
let resolved = (|| {
490+
let infer = self.infer.as_deref()?;
490491
if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) {
491492
let expr_id = self.expr_id(db, &path_expr.into())?;
492-
let infer = self.infer.as_ref()?;
493493
if let Some(assoc) = infer.assoc_resolutions_for_expr(expr_id) {
494494
let assoc = match assoc {
495495
AssocItemId::FunctionId(f_in_trait) => {
496496
match infer.type_of_expr.get(expr_id) {
497497
None => assoc,
498498
Some(func_ty) => {
499499
if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) {
500-
self.resolve_impl_method(db, f_in_trait, subs)
501-
.map(AssocItemId::FunctionId)
502-
.unwrap_or(assoc)
500+
self.resolve_impl_method_or_trait_def(
501+
db,
502+
f_in_trait,
503+
subs.clone(),
504+
)
505+
.into()
503506
} else {
504507
assoc
505508
}
@@ -520,18 +523,18 @@ impl SourceAnalyzer {
520523
prefer_value_ns = true;
521524
} else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) {
522525
let pat_id = self.pat_id(&path_pat.into())?;
523-
if let Some(assoc) = self.infer.as_ref()?.assoc_resolutions_for_pat(pat_id) {
526+
if let Some(assoc) = infer.assoc_resolutions_for_pat(pat_id) {
524527
return Some(PathResolution::Def(AssocItem::from(assoc).into()));
525528
}
526529
if let Some(VariantId::EnumVariantId(variant)) =
527-
self.infer.as_ref()?.variant_resolution_for_pat(pat_id)
530+
infer.variant_resolution_for_pat(pat_id)
528531
{
529532
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
530533
}
531534
} else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) {
532535
let expr_id = self.expr_id(db, &rec_lit.into())?;
533536
if let Some(VariantId::EnumVariantId(variant)) =
534-
self.infer.as_ref()?.variant_resolution_for_expr(expr_id)
537+
infer.variant_resolution_for_expr(expr_id)
535538
{
536539
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
537540
}
@@ -541,8 +544,7 @@ impl SourceAnalyzer {
541544
|| parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from);
542545
if let Some(pat) = record_pat.or_else(tuple_struct_pat) {
543546
let pat_id = self.pat_id(&pat)?;
544-
let variant_res_for_pat =
545-
self.infer.as_ref()?.variant_resolution_for_pat(pat_id);
547+
let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id);
546548
if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat {
547549
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
548550
}
@@ -780,37 +782,22 @@ impl SourceAnalyzer {
780782
false
781783
}
782784

783-
fn resolve_impl_method(
785+
fn resolve_impl_method_or_trait_def(
784786
&self,
785787
db: &dyn HirDatabase,
786788
func: FunctionId,
787-
substs: &Substitution,
788-
) -> Option<FunctionId> {
789-
let impled_trait = match func.lookup(db.upcast()).container {
790-
ItemContainerId::TraitId(trait_id) => trait_id,
791-
_ => return None,
792-
};
793-
if substs.is_empty(Interner) {
794-
return None;
795-
}
796-
let self_ty = substs.at(Interner, 0).ty(Interner)?;
789+
substs: Substitution,
790+
) -> FunctionId {
797791
let krate = self.resolver.krate();
798-
let trait_env = self.resolver.body_owner()?.as_generic_def_id().map_or_else(
792+
let owner = match self.resolver.body_owner() {
793+
Some(it) => it,
794+
None => return func,
795+
};
796+
let env = owner.as_generic_def_id().map_or_else(
799797
|| Arc::new(hir_ty::TraitEnvironment::empty(krate)),
800798
|d| db.trait_environment(d),
801799
);
802-
803-
let fun_data = db.function_data(func);
804-
method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name)
805-
}
806-
807-
fn resolve_impl_method_or_trait_def(
808-
&self,
809-
db: &dyn HirDatabase,
810-
func: FunctionId,
811-
substs: &Substitution,
812-
) -> FunctionId {
813-
self.resolve_impl_method(db, func, substs).unwrap_or(func)
800+
method_resolution::lookup_impl_method(db, env, func, substs)
814801
}
815802

816803
fn lang_trait_fn(

crates/ide/src/goto_definition.rs

+82
Original file line numberDiff line numberDiff line change
@@ -1834,4 +1834,86 @@ fn f() {
18341834
"#,
18351835
);
18361836
}
1837+
1838+
#[test]
1839+
fn goto_bin_op_multiple_impl() {
1840+
check(
1841+
r#"
1842+
//- minicore: add
1843+
struct S;
1844+
impl core::ops::Add for S {
1845+
fn add(
1846+
//^^^
1847+
) {}
1848+
}
1849+
impl core::ops::Add<usize> for S {
1850+
fn add(
1851+
) {}
1852+
}
1853+
1854+
fn f() {
1855+
S +$0 S
1856+
}
1857+
"#,
1858+
);
1859+
1860+
check(
1861+
r#"
1862+
//- minicore: add
1863+
struct S;
1864+
impl core::ops::Add for S {
1865+
fn add(
1866+
) {}
1867+
}
1868+
impl core::ops::Add<usize> for S {
1869+
fn add(
1870+
//^^^
1871+
) {}
1872+
}
1873+
1874+
fn f() {
1875+
S +$0 0usize
1876+
}
1877+
"#,
1878+
);
1879+
}
1880+
1881+
#[test]
1882+
fn path_call_multiple_trait_impl() {
1883+
check(
1884+
r#"
1885+
trait Trait<T> {
1886+
fn f(_: T);
1887+
}
1888+
impl Trait<i32> for usize {
1889+
fn f(_: i32) {}
1890+
//^
1891+
}
1892+
impl Trait<i64> for usize {
1893+
fn f(_: i64) {}
1894+
}
1895+
fn main() {
1896+
usize::f$0(0i32);
1897+
}
1898+
"#,
1899+
);
1900+
1901+
check(
1902+
r#"
1903+
trait Trait<T> {
1904+
fn f(_: T);
1905+
}
1906+
impl Trait<i32> for usize {
1907+
fn f(_: i32) {}
1908+
}
1909+
impl Trait<i64> for usize {
1910+
fn f(_: i64) {}
1911+
//^
1912+
}
1913+
fn main() {
1914+
usize::f$0(0i64);
1915+
}
1916+
"#,
1917+
)
1918+
}
18371919
}

0 commit comments

Comments
 (0)