Skip to content

Commit cdf7807

Browse files
Deeply check that method signatures match, and allow for nested RPITITs
1 parent 1f03ede commit cdf7807

File tree

13 files changed

+231
-52
lines changed

13 files changed

+231
-52
lines changed

compiler/rustc_ast_lowering/src/lib.rs

+2-7
Original file line numberDiff line numberDiff line change
@@ -1358,10 +1358,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
13581358
}
13591359
ImplTraitContext::InTrait => {
13601360
self.lower_impl_trait_in_trait(span, def_node_id, |lctx| {
1361-
lctx.lower_param_bounds(
1362-
bounds,
1363-
ImplTraitContext::Disallowed(ImplTraitPosition::Trait),
1364-
)
1361+
lctx.lower_param_bounds(bounds, ImplTraitContext::InTrait)
13651362
})
13661363
}
13671364
ImplTraitContext::Universal => {
@@ -1559,8 +1556,6 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
15591556
) -> hir::TyKind<'hir> {
15601557
let opaque_ty_def_id = self.local_def_id(opaque_ty_node_id);
15611558
self.with_hir_id_owner(opaque_ty_node_id, |lctx| {
1562-
// FIXME(RPITIT): This should be a more descriptive ImplTraitPosition, i.e. nested RPITIT
1563-
// FIXME(RPITIT): We _also_ should support this eventually
15641559
let hir_bounds = lower_bounds(lctx);
15651560
let rpitit_placeholder = hir::ImplTraitPlaceholder { bounds: hir_bounds };
15661561
let rpitit_item = hir::Item {
@@ -2073,7 +2068,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
20732068
let bound = lctx.lower_async_fn_output_type_to_future_bound(
20742069
output,
20752070
output.span(),
2076-
ImplTraitContext::Disallowed(ImplTraitPosition::TraitReturn),
2071+
ImplTraitContext::InTrait,
20772072
);
20782073
arena_vec![lctx; bound]
20792074
});

compiler/rustc_middle/src/arena.rs

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ macro_rules! arena_types {
101101
[decode] impl_source: rustc_middle::traits::ImplSource<'tcx, ()>,
102102

103103
[] dep_kind: rustc_middle::dep_graph::DepKindStruct<'tcx>,
104+
105+
[] trait_impl_trait_tys: rustc_data_structures::fx::FxHashMap<rustc_hir::def_id::DefId, rustc_middle::ty::Ty<'tcx>>,
104106
]);
105107
)
106108
}

compiler/rustc_middle/src/query/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ rustc_queries! {
161161
separate_provide_extern
162162
}
163163

164+
query compare_predicates_and_trait_impl_trait_tys(key: DefId)
165+
-> Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed>
166+
{
167+
desc { "better description please" }
168+
separate_provide_extern
169+
}
170+
164171
query analysis(key: ()) -> Result<(), ErrorGuaranteed> {
165172
eval_always
166173
desc { "running analysis passes on this crate" }

compiler/rustc_middle/src/ty/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -2484,6 +2484,14 @@ impl<'tcx> TyCtxt<'tcx> {
24842484
pub fn is_const_default_method(self, def_id: DefId) -> bool {
24852485
matches!(self.trait_of_item(def_id), Some(trait_id) if self.has_attr(trait_id, sym::const_trait))
24862486
}
2487+
2488+
pub fn impl_trait_in_trait_parent(self, mut def_id: DefId) -> DefId {
2489+
while let def_kind = self.def_kind(def_id) && def_kind != DefKind::AssocFn {
2490+
debug_assert_eq!(def_kind, DefKind::ImplTraitPlaceholder);
2491+
def_id = self.parent(def_id);
2492+
}
2493+
def_id
2494+
}
24872495
}
24882496

24892497
/// Yields the parent function's `LocalDefId` if `def_id` is an `impl Trait` definition.

compiler/rustc_middle/src/ty/util.rs

+7
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,13 @@ impl<'tcx> TyCtxt<'tcx> {
651651
ty::EarlyBinder(self.type_of(def_id))
652652
}
653653

654+
pub fn bound_trait_impl_trait_tys(
655+
self,
656+
def_id: DefId,
657+
) -> ty::EarlyBinder<Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed>> {
658+
ty::EarlyBinder(self.compare_predicates_and_trait_impl_trait_tys(def_id))
659+
}
660+
654661
pub fn bound_fn_sig(self, def_id: DefId) -> ty::EarlyBinder<ty::PolyFnSig<'tcx>> {
655662
ty::EarlyBinder(self.fn_sig(def_id))
656663
}

compiler/rustc_trait_selection/src/traits/project.rs

+6-7
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
13181318
) {
13191319
let tcx = selcx.tcx();
13201320
if tcx.def_kind(obligation.predicate.item_def_id) == DefKind::ImplTraitPlaceholder {
1321-
let trait_fn_def_id = tcx.parent(obligation.predicate.item_def_id);
1321+
let trait_fn_def_id = tcx.impl_trait_in_trait_parent(obligation.predicate.item_def_id);
13221322
let trait_def_id = tcx.parent(trait_fn_def_id);
13231323
let trait_substs =
13241324
obligation.predicate.substs.truncate_to(tcx, tcx.generics_of(trait_def_id));
@@ -2176,11 +2176,6 @@ fn confirm_impl_trait_in_trait_candidate<'tcx>(
21762176
let impl_fn_def_id = leaf_def.item.def_id;
21772177
let impl_fn_substs = obligation.predicate.substs.rebase_onto(tcx, trait_fn_def_id, data.substs);
21782178

2179-
let sig = tcx
2180-
.bound_fn_sig(impl_fn_def_id)
2181-
.map_bound(|fn_sig| tcx.liberate_late_bound_regions(impl_fn_def_id, fn_sig))
2182-
.subst(tcx, impl_fn_substs);
2183-
21842179
let cause = ObligationCause::new(
21852180
obligation.cause.span,
21862181
obligation.cause.body_id,
@@ -2217,7 +2212,11 @@ fn confirm_impl_trait_in_trait_candidate<'tcx>(
22172212
selcx,
22182213
obligation.param_env,
22192214
cause.clone(),
2220-
sig.output(),
2215+
tcx.bound_trait_impl_trait_tys(impl_fn_def_id)
2216+
.map_bound(|tys| {
2217+
tys.map_or_else(|_| tcx.ty_error(), |tys| tys[&obligation.predicate.item_def_id])
2218+
})
2219+
.subst(tcx, impl_fn_substs),
22212220
&mut obligations,
22222221
);
22232222

compiler/rustc_typeck/src/check/compare_method.rs

+110-37
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
use super::potentially_plural_count;
22
use crate::errors::LifetimesOrBoundsMismatchOnTrait;
3-
use rustc_data_structures::fx::FxHashSet;
3+
use hir::def_id::DefId;
4+
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
45
use rustc_errors::{pluralize, struct_span_err, Applicability, DiagnosticId, ErrorGuaranteed};
56
use rustc_hir as hir;
67
use rustc_hir::def::{DefKind, Res};
78
use rustc_hir::intravisit;
89
use rustc_hir::{GenericParamKind, ImplItemKind, TraitItemKind};
910
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
11+
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
1012
use rustc_infer::infer::{self, TyCtxtInferExt};
1113
use rustc_infer::traits::util;
1214
use rustc_middle::ty::error::{ExpectedFound, TypeError};
1315
use rustc_middle::ty::subst::{InternalSubsts, Subst};
1416
use rustc_middle::ty::util::ExplicitSelf;
15-
use rustc_middle::ty::{self, DefIdTree};
17+
use rustc_middle::ty::{
18+
self, DefIdTree, Ty, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable,
19+
};
1620
use rustc_middle::ty::{GenericParamDefKind, ToPredicate, TyCtxt};
1721
use rustc_span::Span;
1822
use rustc_trait_selection::traits::error_reporting::InferCtxtExt;
@@ -64,10 +68,7 @@ pub(crate) fn compare_impl_method<'tcx>(
6468
return;
6569
}
6670

67-
if let Err(_) = compare_predicate_entailment(tcx, impl_m, impl_m_span, trait_m, impl_trait_ref)
68-
{
69-
return;
70-
}
71+
tcx.ensure().compare_predicates_and_trait_impl_trait_tys(impl_m.def_id);
7172
}
7273

7374
/// This function is best explained by example. Consider a trait:
@@ -136,13 +137,15 @@ pub(crate) fn compare_impl_method<'tcx>(
136137
///
137138
/// Finally we register each of these predicates as an obligation and check that
138139
/// they hold.
139-
fn compare_predicate_entailment<'tcx>(
140+
pub(super) fn compare_predicates_and_trait_impl_trait_tys<'tcx>(
140141
tcx: TyCtxt<'tcx>,
141-
impl_m: &ty::AssocItem,
142-
impl_m_span: Span,
143-
trait_m: &ty::AssocItem,
144-
impl_trait_ref: ty::TraitRef<'tcx>,
145-
) -> Result<(), ErrorGuaranteed> {
142+
def_id: DefId,
143+
) -> Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed> {
144+
let impl_m = tcx.opt_associated_item(def_id).unwrap();
145+
let impl_m_span = tcx.def_span(def_id);
146+
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
147+
let impl_trait_ref = tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap();
148+
146149
let trait_to_impl_substs = impl_trait_ref.substs;
147150

148151
// This node-id should be used for the `body_id` field on each
@@ -161,6 +164,7 @@ fn compare_predicate_entailment<'tcx>(
161164
kind: impl_m.kind,
162165
},
163166
);
167+
let return_span = tcx.hir().fn_decl_by_hir_id(impl_m_hir_id).unwrap().output.span();
164168

165169
// Create mapping from impl to placeholder.
166170
let impl_to_placeholder_substs = InternalSubsts::identity_for_item(tcx, impl_m.def_id);
@@ -266,6 +270,13 @@ fn compare_predicate_entailment<'tcx>(
266270

267271
let trait_sig = tcx.bound_fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs);
268272
let trait_sig = tcx.liberate_late_bound_regions(impl_m.def_id, trait_sig);
273+
let mut collector =
274+
ImplTraitInTraitCollector::new(&ocx, return_span, param_env, impl_m_hir_id);
275+
// FIXME(RPITIT): This should only be needed on the output type, but
276+
// RPITIT placeholders shouldn't show up anywhere except for there,
277+
// so I think this is fine.
278+
let trait_sig = trait_sig.fold_with(&mut collector);
279+
269280
// Next, add all inputs and output as well-formed tys. Importantly,
270281
// we have to do this before normalization, since the normalized ty may
271282
// not contain the input parameters. See issue #87748.
@@ -391,30 +402,6 @@ fn compare_predicate_entailment<'tcx>(
391402
return Err(diag.emit());
392403
}
393404

394-
// Check that an impl's fn return satisfies the bounds of the
395-
// FIXME(RPITIT): Generalize this to nested impl traits
396-
if let ty::Projection(proj) = tcx.fn_sig(trait_m.def_id).skip_binder().output().kind()
397-
&& tcx.def_kind(proj.item_def_id) == DefKind::ImplTraitPlaceholder
398-
{
399-
let return_span = tcx.hir().fn_decl_by_hir_id(impl_m_hir_id).unwrap().output.span();
400-
401-
for (predicate, span) in tcx
402-
.bound_explicit_item_bounds(proj.item_def_id)
403-
.transpose_iter()
404-
.map(|pred| pred.map_bound(|pred| *pred).subst(tcx, trait_to_placeholder_substs))
405-
{
406-
ocx.register_obligation(traits::Obligation::new(
407-
traits::ObligationCause::new(
408-
return_span,
409-
impl_m_hir_id,
410-
ObligationCauseCode::BindingObligation(proj.item_def_id, span),
411-
),
412-
param_env,
413-
predicate,
414-
));
415-
}
416-
}
417-
418405
// Check that all obligations are satisfied by the implementation's
419406
// version.
420407
let errors = ocx.select_all_or_error();
@@ -435,10 +422,96 @@ fn compare_predicate_entailment<'tcx>(
435422
&outlives_environment,
436423
);
437424

438-
Ok(())
425+
let mut collected_tys = FxHashMap::default();
426+
for (def_id, ty) in collector.types {
427+
match infcx.fully_resolve(ty) {
428+
Ok(ty) => {
429+
collected_tys.insert(def_id, ty);
430+
}
431+
Err(err) => {
432+
tcx.sess.delay_span_bug(
433+
return_span,
434+
format!("could not fully resolve: {ty} => {err:?}"),
435+
);
436+
collected_tys.insert(def_id, tcx.ty_error());
437+
}
438+
}
439+
}
440+
441+
Ok(&*tcx.arena.alloc(collected_tys))
439442
})
440443
}
441444

445+
struct ImplTraitInTraitCollector<'a, 'tcx> {
446+
ocx: &'a ObligationCtxt<'a, 'tcx>,
447+
types: FxHashMap<DefId, Ty<'tcx>>,
448+
span: Span,
449+
param_env: ty::ParamEnv<'tcx>,
450+
body_id: hir::HirId,
451+
}
452+
453+
impl<'a, 'tcx> ImplTraitInTraitCollector<'a, 'tcx> {
454+
fn new(
455+
ocx: &'a ObligationCtxt<'a, 'tcx>,
456+
span: Span,
457+
param_env: ty::ParamEnv<'tcx>,
458+
body_id: hir::HirId,
459+
) -> Self {
460+
ImplTraitInTraitCollector { ocx, types: FxHashMap::default(), span, param_env, body_id }
461+
}
462+
}
463+
464+
impl<'tcx> TypeFolder<'tcx> for ImplTraitInTraitCollector<'_, 'tcx> {
465+
fn tcx<'a>(&'a self) -> TyCtxt<'tcx> {
466+
self.ocx.infcx.tcx
467+
}
468+
469+
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
470+
if let ty::Projection(proj) = ty.kind()
471+
&& self.tcx().def_kind(proj.item_def_id) == DefKind::ImplTraitPlaceholder
472+
{
473+
if let Some(ty) = self.types.get(&proj.item_def_id) {
474+
return *ty;
475+
}
476+
//FIXME(RPITIT): Deny nested RPITIT in substs too
477+
if proj.substs.has_escaping_bound_vars() {
478+
bug!("FIXME(RPITIT): error here");
479+
}
480+
// Replace with infer var
481+
let infer_ty = self.ocx.infcx.next_ty_var(TypeVariableOrigin {
482+
span: self.span,
483+
kind: TypeVariableOriginKind::MiscVariable,
484+
});
485+
self.types.insert(proj.item_def_id, infer_ty);
486+
// Recurse into bounds
487+
for pred in self.tcx().bound_explicit_item_bounds(proj.item_def_id).transpose_iter() {
488+
let pred_span = pred.0.1;
489+
490+
let pred = pred.map_bound(|(pred, _)| *pred).subst(self.tcx(), proj.substs);
491+
let pred = pred.fold_with(self);
492+
let pred = self.ocx.normalize(
493+
ObligationCause::misc(self.span, self.body_id),
494+
self.param_env,
495+
pred,
496+
);
497+
498+
self.ocx.register_obligation(traits::Obligation::new(
499+
ObligationCause::new(
500+
self.span,
501+
self.body_id,
502+
ObligationCauseCode::BindingObligation(proj.item_def_id, pred_span),
503+
),
504+
self.param_env,
505+
pred,
506+
));
507+
}
508+
infer_ty
509+
} else {
510+
ty.super_fold_with(self)
511+
}
512+
}
513+
}
514+
442515
fn check_region_bounds_on_impl_item<'tcx>(
443516
tcx: TyCtxt<'tcx>,
444517
impl_m: &ty::AssocItem,

compiler/rustc_typeck/src/check/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ use crate::require_c_abi_if_c_variadic;
132132
use crate::util::common::indenter;
133133

134134
use self::coercion::DynamicCoerceMany;
135+
use self::compare_method::compare_predicates_and_trait_impl_trait_tys;
135136
use self::region::region_scope_tree;
136137
pub use self::Expectation::*;
137138

@@ -249,6 +250,7 @@ pub fn provide(providers: &mut Providers) {
249250
used_trait_imports,
250251
check_mod_item_types,
251252
region_scope_tree,
253+
compare_predicates_and_trait_impl_trait_tys,
252254
..*providers
253255
};
254256
}

compiler/rustc_typeck/src/collect.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,10 @@ fn generics_of(tcx: TyCtxt<'_>, def_id: DefId) -> ty::Generics {
16021602
}
16031603
ItemKind::ImplTraitPlaceholder(_) => {
16041604
let parent_id = tcx.hir().get_parent_item(hir_id).to_def_id();
1605-
assert_eq!(tcx.def_kind(parent_id), DefKind::AssocFn);
1605+
assert!(matches!(
1606+
tcx.def_kind(parent_id),
1607+
DefKind::AssocFn | DefKind::ImplTraitPlaceholder
1608+
));
16061609
Some(parent_id)
16071610
}
16081611
_ => None,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// check-pass
2+
3+
#![feature(return_position_impl_trait_in_trait)]
4+
#![allow(incomplete_features)]
5+
6+
struct Wrapper<T>(T);
7+
8+
trait Foo {
9+
fn bar() -> Wrapper<impl Sized>;
10+
}
11+
12+
impl Foo for () {
13+
fn bar() -> Wrapper<i32> { Wrapper(0) }
14+
}
15+
16+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#![feature(return_position_impl_trait_in_trait)]
2+
#![allow(incomplete_features)]
3+
4+
struct Wrapper<T>(T);
5+
6+
trait Foo {
7+
fn bar() -> Wrapper<impl Sized>;
8+
}
9+
10+
impl Foo for () {
11+
fn bar() -> i32 { 0 }
12+
//~^ ERROR method `bar` has an incompatible type for trait
13+
}
14+
15+
fn main() {}

0 commit comments

Comments
 (0)