1
1
use super :: potentially_plural_count;
2
2
use crate :: errors:: LifetimesOrBoundsMismatchOnTrait ;
3
- use rustc_data_structures:: fx:: FxHashSet ;
3
+ use hir:: def_id:: DefId ;
4
+ use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
4
5
use rustc_errors:: { pluralize, struct_span_err, Applicability , DiagnosticId , ErrorGuaranteed } ;
5
6
use rustc_hir as hir;
6
7
use rustc_hir:: def:: { DefKind , Res } ;
7
8
use rustc_hir:: intravisit;
8
9
use rustc_hir:: { GenericParamKind , ImplItemKind , TraitItemKind } ;
9
10
use rustc_infer:: infer:: outlives:: env:: OutlivesEnvironment ;
11
+ use rustc_infer:: infer:: type_variable:: { TypeVariableOrigin , TypeVariableOriginKind } ;
10
12
use rustc_infer:: infer:: { self , TyCtxtInferExt } ;
11
13
use rustc_infer:: traits:: util;
12
14
use rustc_middle:: ty:: error:: { ExpectedFound , TypeError } ;
13
15
use rustc_middle:: ty:: subst:: { InternalSubsts , Subst } ;
14
16
use rustc_middle:: ty:: util:: ExplicitSelf ;
15
- use rustc_middle:: ty:: { self , DefIdTree } ;
17
+ use rustc_middle:: ty:: {
18
+ self , DefIdTree , Ty , TypeFoldable , TypeFolder , TypeSuperFoldable , TypeVisitable ,
19
+ } ;
16
20
use rustc_middle:: ty:: { GenericParamDefKind , ToPredicate , TyCtxt } ;
17
21
use rustc_span:: Span ;
18
22
use rustc_trait_selection:: traits:: error_reporting:: InferCtxtExt ;
@@ -64,10 +68,7 @@ pub(crate) fn compare_impl_method<'tcx>(
64
68
return ;
65
69
}
66
70
67
- if let Err ( _) = compare_predicate_entailment ( tcx, impl_m, impl_m_span, trait_m, impl_trait_ref)
68
- {
69
- return ;
70
- }
71
+ tcx. ensure ( ) . compare_predicates_and_trait_impl_trait_tys ( impl_m. def_id ) ;
71
72
}
72
73
73
74
/// This function is best explained by example. Consider a trait:
@@ -136,13 +137,15 @@ pub(crate) fn compare_impl_method<'tcx>(
136
137
///
137
138
/// Finally we register each of these predicates as an obligation and check that
138
139
/// they hold.
139
- fn compare_predicate_entailment < ' tcx > (
140
+ pub ( super ) fn compare_predicates_and_trait_impl_trait_tys < ' tcx > (
140
141
tcx : TyCtxt < ' tcx > ,
141
- impl_m : & ty:: AssocItem ,
142
- impl_m_span : Span ,
143
- trait_m : & ty:: AssocItem ,
144
- impl_trait_ref : ty:: TraitRef < ' tcx > ,
145
- ) -> Result < ( ) , ErrorGuaranteed > {
142
+ def_id : DefId ,
143
+ ) -> Result < & ' tcx FxHashMap < DefId , Ty < ' tcx > > , ErrorGuaranteed > {
144
+ let impl_m = tcx. opt_associated_item ( def_id) . unwrap ( ) ;
145
+ let impl_m_span = tcx. def_span ( def_id) ;
146
+ let trait_m = tcx. opt_associated_item ( impl_m. trait_item_def_id . unwrap ( ) ) . unwrap ( ) ;
147
+ let impl_trait_ref = tcx. impl_trait_ref ( impl_m. impl_container ( tcx) . unwrap ( ) ) . unwrap ( ) ;
148
+
146
149
let trait_to_impl_substs = impl_trait_ref. substs ;
147
150
148
151
// This node-id should be used for the `body_id` field on each
@@ -161,6 +164,7 @@ fn compare_predicate_entailment<'tcx>(
161
164
kind : impl_m. kind ,
162
165
} ,
163
166
) ;
167
+ let return_span = tcx. hir ( ) . fn_decl_by_hir_id ( impl_m_hir_id) . unwrap ( ) . output . span ( ) ;
164
168
165
169
// Create mapping from impl to placeholder.
166
170
let impl_to_placeholder_substs = InternalSubsts :: identity_for_item ( tcx, impl_m. def_id ) ;
@@ -266,6 +270,13 @@ fn compare_predicate_entailment<'tcx>(
266
270
267
271
let trait_sig = tcx. bound_fn_sig ( trait_m. def_id ) . subst ( tcx, trait_to_placeholder_substs) ;
268
272
let trait_sig = tcx. liberate_late_bound_regions ( impl_m. def_id , trait_sig) ;
273
+ let mut collector =
274
+ ImplTraitInTraitCollector :: new ( & ocx, return_span, param_env, impl_m_hir_id) ;
275
+ // FIXME(RPITIT): This should only be needed on the output type, but
276
+ // RPITIT placeholders shouldn't show up anywhere except for there,
277
+ // so I think this is fine.
278
+ let trait_sig = trait_sig. fold_with ( & mut collector) ;
279
+
269
280
// Next, add all inputs and output as well-formed tys. Importantly,
270
281
// we have to do this before normalization, since the normalized ty may
271
282
// not contain the input parameters. See issue #87748.
@@ -391,30 +402,6 @@ fn compare_predicate_entailment<'tcx>(
391
402
return Err ( diag. emit ( ) ) ;
392
403
}
393
404
394
- // Check that an impl's fn return satisfies the bounds of the
395
- // FIXME(RPITIT): Generalize this to nested impl traits
396
- if let ty:: Projection ( proj) = tcx. fn_sig ( trait_m. def_id ) . skip_binder ( ) . output ( ) . kind ( )
397
- && tcx. def_kind ( proj. item_def_id ) == DefKind :: ImplTraitPlaceholder
398
- {
399
- let return_span = tcx. hir ( ) . fn_decl_by_hir_id ( impl_m_hir_id) . unwrap ( ) . output . span ( ) ;
400
-
401
- for ( predicate, span) in tcx
402
- . bound_explicit_item_bounds ( proj. item_def_id )
403
- . transpose_iter ( )
404
- . map ( |pred| pred. map_bound ( |pred| * pred) . subst ( tcx, trait_to_placeholder_substs) )
405
- {
406
- ocx. register_obligation ( traits:: Obligation :: new (
407
- traits:: ObligationCause :: new (
408
- return_span,
409
- impl_m_hir_id,
410
- ObligationCauseCode :: BindingObligation ( proj. item_def_id , span) ,
411
- ) ,
412
- param_env,
413
- predicate,
414
- ) ) ;
415
- }
416
- }
417
-
418
405
// Check that all obligations are satisfied by the implementation's
419
406
// version.
420
407
let errors = ocx. select_all_or_error ( ) ;
@@ -435,10 +422,96 @@ fn compare_predicate_entailment<'tcx>(
435
422
& outlives_environment,
436
423
) ;
437
424
438
- Ok ( ( ) )
425
+ let mut collected_tys = FxHashMap :: default ( ) ;
426
+ for ( def_id, ty) in collector. types {
427
+ match infcx. fully_resolve ( ty) {
428
+ Ok ( ty) => {
429
+ collected_tys. insert ( def_id, ty) ;
430
+ }
431
+ Err ( err) => {
432
+ tcx. sess . delay_span_bug (
433
+ return_span,
434
+ format ! ( "could not fully resolve: {ty} => {err:?}" ) ,
435
+ ) ;
436
+ collected_tys. insert ( def_id, tcx. ty_error ( ) ) ;
437
+ }
438
+ }
439
+ }
440
+
441
+ Ok ( & * tcx. arena . alloc ( collected_tys) )
439
442
} )
440
443
}
441
444
445
+ struct ImplTraitInTraitCollector < ' a , ' tcx > {
446
+ ocx : & ' a ObligationCtxt < ' a , ' tcx > ,
447
+ types : FxHashMap < DefId , Ty < ' tcx > > ,
448
+ span : Span ,
449
+ param_env : ty:: ParamEnv < ' tcx > ,
450
+ body_id : hir:: HirId ,
451
+ }
452
+
453
+ impl < ' a , ' tcx > ImplTraitInTraitCollector < ' a , ' tcx > {
454
+ fn new (
455
+ ocx : & ' a ObligationCtxt < ' a , ' tcx > ,
456
+ span : Span ,
457
+ param_env : ty:: ParamEnv < ' tcx > ,
458
+ body_id : hir:: HirId ,
459
+ ) -> Self {
460
+ ImplTraitInTraitCollector { ocx, types : FxHashMap :: default ( ) , span, param_env, body_id }
461
+ }
462
+ }
463
+
464
+ impl < ' tcx > TypeFolder < ' tcx > for ImplTraitInTraitCollector < ' _ , ' tcx > {
465
+ fn tcx < ' a > ( & ' a self ) -> TyCtxt < ' tcx > {
466
+ self . ocx . infcx . tcx
467
+ }
468
+
469
+ fn fold_ty ( & mut self , ty : Ty < ' tcx > ) -> Ty < ' tcx > {
470
+ if let ty:: Projection ( proj) = ty. kind ( )
471
+ && self . tcx ( ) . def_kind ( proj. item_def_id ) == DefKind :: ImplTraitPlaceholder
472
+ {
473
+ if let Some ( ty) = self . types . get ( & proj. item_def_id ) {
474
+ return * ty;
475
+ }
476
+ //FIXME(RPITIT): Deny nested RPITIT in substs too
477
+ if proj. substs . has_escaping_bound_vars ( ) {
478
+ bug ! ( "FIXME(RPITIT): error here" ) ;
479
+ }
480
+ // Replace with infer var
481
+ let infer_ty = self . ocx . infcx . next_ty_var ( TypeVariableOrigin {
482
+ span : self . span ,
483
+ kind : TypeVariableOriginKind :: MiscVariable ,
484
+ } ) ;
485
+ self . types . insert ( proj. item_def_id , infer_ty) ;
486
+ // Recurse into bounds
487
+ for pred in self . tcx ( ) . bound_explicit_item_bounds ( proj. item_def_id ) . transpose_iter ( ) {
488
+ let pred_span = pred. 0 . 1 ;
489
+
490
+ let pred = pred. map_bound ( |( pred, _) | * pred) . subst ( self . tcx ( ) , proj. substs ) ;
491
+ let pred = pred. fold_with ( self ) ;
492
+ let pred = self . ocx . normalize (
493
+ ObligationCause :: misc ( self . span , self . body_id ) ,
494
+ self . param_env ,
495
+ pred,
496
+ ) ;
497
+
498
+ self . ocx . register_obligation ( traits:: Obligation :: new (
499
+ ObligationCause :: new (
500
+ self . span ,
501
+ self . body_id ,
502
+ ObligationCauseCode :: BindingObligation ( proj. item_def_id , pred_span) ,
503
+ ) ,
504
+ self . param_env ,
505
+ pred,
506
+ ) ) ;
507
+ }
508
+ infer_ty
509
+ } else {
510
+ ty. super_fold_with ( self )
511
+ }
512
+ }
513
+ }
514
+
442
515
fn check_region_bounds_on_impl_item < ' tcx > (
443
516
tcx : TyCtxt < ' tcx > ,
444
517
impl_m : & ty:: AssocItem ,
0 commit comments