Skip to content

Commit 8be3ce9

Browse files
committed
Auto merge of rust-lang#102334 - compiler-errors:rpitit-substs-issue, r=cjgillot
Fix subst issues with return-position `impl Trait` in trait 1. Fix an issue where we were rebase impl substs onto trait method substs, instead of trait substs 2. Fix an issue where early-bound regions aren't being mapped correctly for RPITIT hidden types Fixes rust-lang#102301 Fixes rust-lang#102310 Fixes rust-lang#102334 Fixes rust-lang#102918
2 parents 11432fe + 4259f33 commit 8be3ce9

File tree

5 files changed

+117
-23
lines changed

5 files changed

+117
-23
lines changed

Diff for: compiler/rustc_hir_analysis/src/check/compare_method.rs

+64-19
Original file line numberDiff line numberDiff line change
@@ -465,30 +465,30 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
465465
let ocx = ObligationCtxt::new(infcx);
466466

467467
let norm_cause = ObligationCause::misc(return_span, impl_m_hir_id);
468-
let impl_return_ty = ocx.normalize(
468+
let impl_sig = ocx.normalize(
469469
norm_cause.clone(),
470470
param_env,
471-
infcx
472-
.replace_bound_vars_with_fresh_vars(
473-
return_span,
474-
infer::HigherRankedType,
475-
tcx.fn_sig(impl_m.def_id),
476-
)
477-
.output(),
471+
infcx.replace_bound_vars_with_fresh_vars(
472+
return_span,
473+
infer::HigherRankedType,
474+
tcx.fn_sig(impl_m.def_id),
475+
),
478476
);
477+
let impl_return_ty = impl_sig.output();
479478

480479
let mut collector = ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id);
481-
let unnormalized_trait_return_ty = tcx
480+
let unnormalized_trait_sig = tcx
482481
.liberate_late_bound_regions(
483482
impl_m.def_id,
484483
tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs),
485484
)
486-
.output()
487485
.fold_with(&mut collector);
488-
let trait_return_ty =
489-
ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_return_ty);
486+
let trait_sig = ocx.normalize(norm_cause.clone(), param_env, unnormalized_trait_sig);
487+
let trait_return_ty = trait_sig.output();
490488

491-
let wf_tys = FxHashSet::from_iter([unnormalized_trait_return_ty, trait_return_ty]);
489+
let wf_tys = FxHashSet::from_iter(
490+
unnormalized_trait_sig.inputs_and_output.iter().chain(trait_sig.inputs_and_output.iter()),
491+
);
492492

493493
match infcx.at(&cause, param_env).eq(trait_return_ty, impl_return_ty) {
494494
Ok(infer::InferOk { value: (), obligations }) => {
@@ -521,6 +521,26 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
521521
}
522522
}
523523

524+
// Unify the whole function signature. We need to do this to fully infer
525+
// the lifetimes of the return type, but do this after unifying just the
526+
// return types, since we want to avoid duplicating errors from
527+
// `compare_predicate_entailment`.
528+
match infcx
529+
.at(&cause, param_env)
530+
.eq(tcx.mk_fn_ptr(ty::Binder::dummy(trait_sig)), tcx.mk_fn_ptr(ty::Binder::dummy(impl_sig)))
531+
{
532+
Ok(infer::InferOk { value: (), obligations }) => {
533+
ocx.register_obligations(obligations);
534+
}
535+
Err(terr) => {
536+
let guar = tcx.sess.delay_span_bug(
537+
return_span,
538+
format!("could not unify `{trait_sig}` and `{impl_sig}`: {terr:?}"),
539+
);
540+
return Err(guar);
541+
}
542+
}
543+
524544
// Check that all obligations are satisfied by the implementation's
525545
// RPITs.
526546
let errors = ocx.select_all_or_error();
@@ -551,15 +571,40 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
551571
let id_substs = InternalSubsts::identity_for_item(tcx, def_id);
552572
debug!(?id_substs, ?substs);
553573
let map: FxHashMap<ty::GenericArg<'tcx>, ty::GenericArg<'tcx>> =
554-
substs.iter().enumerate().map(|(index, arg)| (arg, id_substs[index])).collect();
574+
std::iter::zip(substs, id_substs).collect();
555575
debug!(?map);
556576

577+
// NOTE(compiler-errors): RPITITs, like all other RPITs, have early-bound
578+
// region substs that are synthesized during AST lowering. These are substs
579+
// that are appended to the parent substs (trait and trait method). However,
580+
// we're trying to infer the unsubstituted type value of the RPITIT inside
581+
// the *impl*, so we can later use the impl's method substs to normalize
582+
// an RPITIT to a concrete type (`confirm_impl_trait_in_trait_candidate`).
583+
//
584+
// Due to the design of RPITITs, during AST lowering, we have no idea that
585+
// an impl method corresponds to a trait method with RPITITs in it. Therefore,
586+
// we don't have a list of early-bound region substs for the RPITIT in the impl.
587+
// Since early region parameters are index-based, we can't just rebase these
588+
// (trait method) early-bound region substs onto the impl, and there's no
589+
// guarantee that the indices from the trait substs and impl substs line up.
590+
// So to fix this, we subtract the number of trait substs and add the number of
591+
// impl substs to *renumber* these early-bound regions to their corresponding
592+
// indices in the impl's substitutions list.
593+
//
594+
// Also, we only need to account for a difference in trait and impl substs,
595+
// since we previously enforce that the trait method and impl method have the
596+
// same generics.
597+
let num_trait_substs = trait_to_impl_substs.len();
598+
let num_impl_substs = tcx.generics_of(impl_m.container_id(tcx)).params.len();
557599
let ty = tcx.fold_regions(ty, |region, _| {
558-
if let ty::ReFree(_) = region.kind() {
559-
map[&region.into()].expect_region()
560-
} else {
561-
region
562-
}
600+
let ty::ReFree(_) = region.kind() else { return region; };
601+
let ty::ReEarlyBound(e) = map[&region.into()].expect_region().kind()
602+
else { bug!("expected ReFree to map to ReEarlyBound"); };
603+
tcx.mk_region(ty::ReEarlyBound(ty::EarlyBoundRegion {
604+
def_id: e.def_id,
605+
name: e.name,
606+
index: (e.index as usize - num_trait_substs + num_impl_substs) as u32,
607+
}))
563608
});
564609
debug!(%ty);
565610
collected_tys.insert(def_id, ty);

Diff for: compiler/rustc_middle/src/ty/subst.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -606,9 +606,21 @@ impl<'a, 'tcx> TypeFolder<'tcx> for SubstFolder<'a, 'tcx> {
606606
fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
607607
#[cold]
608608
#[inline(never)]
609-
fn region_param_out_of_range(data: ty::EarlyBoundRegion) -> ! {
609+
fn region_param_out_of_range(data: ty::EarlyBoundRegion, substs: &[GenericArg<'_>]) -> ! {
610610
bug!(
611-
"Region parameter out of range when substituting in region {} (index={})",
611+
"Region parameter out of range when substituting in region {} (index={}, substs = {:?})",
612+
data.name,
613+
data.index,
614+
substs,
615+
)
616+
}
617+
618+
#[cold]
619+
#[inline(never)]
620+
fn region_param_invalid(data: ty::EarlyBoundRegion, other: GenericArgKind<'_>) -> ! {
621+
bug!(
622+
"Unexpected parameter {:?} when substituting in region {} (index={})",
623+
other,
612624
data.name,
613625
data.index
614626
)
@@ -624,7 +636,8 @@ impl<'a, 'tcx> TypeFolder<'tcx> for SubstFolder<'a, 'tcx> {
624636
let rk = self.substs.get(data.index as usize).map(|k| k.unpack());
625637
match rk {
626638
Some(GenericArgKind::Lifetime(lt)) => self.shift_region_through_binders(lt),
627-
_ => region_param_out_of_range(data),
639+
Some(other) => region_param_invalid(data, other),
640+
None => region_param_out_of_range(data, self.substs),
628641
}
629642
}
630643
_ => r,

Diff for: compiler/rustc_trait_selection/src/traits/project.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,10 @@ fn confirm_impl_trait_in_trait_candidate<'tcx>(
22542254
}
22552255

22562256
let impl_fn_def_id = leaf_def.item.def_id;
2257-
let impl_fn_substs = obligation.predicate.substs.rebase_onto(tcx, trait_fn_def_id, data.substs);
2257+
// Rebase from {trait}::{fn}::{opaque} to {impl}::{fn}::{opaque},
2258+
// since `data.substs` are the impl substs.
2259+
let impl_fn_substs =
2260+
obligation.predicate.substs.rebase_onto(tcx, tcx.parent(trait_fn_def_id), data.substs);
22582261

22592262
let cause = ObligationCause::new(
22602263
obligation.cause.span,

Diff for: src/test/ui/async-await/in-trait/issue-102310.rs

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// check-pass
2+
// edition:2021
3+
4+
#![feature(async_fn_in_trait)]
5+
#![allow(incomplete_features)]
6+
7+
pub trait SpiDevice {
8+
async fn transaction<F, R>(&mut self);
9+
}
10+
11+
impl SpiDevice for () {
12+
async fn transaction<F, R>(&mut self) {}
13+
}
14+
15+
fn main() {}

Diff for: src/test/ui/impl-trait/in-trait/issue-102301.rs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// check-pass
2+
3+
#![feature(return_position_impl_trait_in_trait)]
4+
#![allow(incomplete_features)]
5+
6+
trait Foo<T> {
7+
fn foo<F2: Foo<T>>(self) -> impl Foo<T>;
8+
}
9+
10+
struct Bar;
11+
12+
impl Foo<u8> for Bar {
13+
fn foo<F2: Foo<u8>>(self) -> impl Foo<u8> {
14+
self
15+
}
16+
}
17+
18+
fn main() {}

0 commit comments

Comments
 (0)