Skip to content

Commit 998a816

Browse files
committed
Make gen blocks implement the Iterator trait
1 parent 6214943 commit 998a816

File tree

17 files changed

+286
-7
lines changed

17 files changed

+286
-7
lines changed

compiler/rustc_hir_typeck/src/closure.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
652652
},
653653
)
654654
}
655+
Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => {
656+
todo!("gen closures do not exist yet")
657+
}
655658

656659
_ => astconv.ty_infer(None, decl.output.span()),
657660
},

compiler/rustc_middle/src/traits/select.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ pub enum SelectionCandidate<'tcx> {
144144
/// generated for an async construct.
145145
FutureCandidate,
146146

147+
/// Implementation of an `Iterator` trait by one of the generator types
148+
/// generated for a gen construct.
149+
IteratorCandidate,
150+
147151
/// Implementation of a `Fn`-family trait by one of the anonymous
148152
/// types generated for a fn pointer type (e.g., `fn(int) -> int`)
149153
FnPointerCandidate {

compiler/rustc_middle/src/ty/context.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,17 @@ impl<'tcx> TyCtxt<'tcx> {
782782
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Async(_)))
783783
}
784784

785+
/// Returns `true` if the node pointed to by `def_id` is a general coroutine that implements `Coroutine`.
786+
/// This means it is neither an `async` or `gen` construct.
787+
pub fn is_general_coroutine(self, def_id: DefId) -> bool {
788+
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Coroutine))
789+
}
790+
791+
/// Returns `true` if the node pointed to by `def_id` is a coroutine for a gen construct.
792+
pub fn coroutine_is_gen(self, def_id: DefId) -> bool {
793+
matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_)))
794+
}
795+
785796
pub fn stability(self) -> &'tcx stability::Index {
786797
self.stability_index(())
787798
}

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ symbols! {
226226
IpAddr,
227227
IrTyKind,
228228
Is,
229+
Item,
229230
ItemContext,
230231
IterEmpty,
231232
IterOnce,

compiler/rustc_trait_selection/src/solve/assembly/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,15 @@ pub(super) trait GoalKind<'tcx>:
199199
goal: Goal<'tcx, Self>,
200200
) -> QueryResult<'tcx>;
201201

202-
/// A coroutine (that doesn't come from an `async` desugaring) is known to
202+
/// A coroutine (that comes from a `gen` desugaring) is known to implement
203+
/// `Iterator<Item = O>`, where `O` is given by the generator's yield type
204+
/// that was computed during type-checking.
205+
fn consider_builtin_iterator_candidate(
206+
ecx: &mut EvalCtxt<'_, 'tcx>,
207+
goal: Goal<'tcx, Self>,
208+
) -> QueryResult<'tcx>;
209+
210+
/// A coroutine (that doesn't come from an `async` or `gen` desugaring) is known to
203211
/// implement `Coroutine<R, Yield = Y, Return = O>`, given the resume, yield,
204212
/// and return types of the coroutine computed during type-checking.
205213
fn consider_builtin_coroutine_candidate(
@@ -552,6 +560,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
552560
G::consider_builtin_pointee_candidate(self, goal)
553561
} else if lang_items.future_trait() == Some(trait_def_id) {
554562
G::consider_builtin_future_candidate(self, goal)
563+
} else if lang_items.iterator_trait() == Some(trait_def_id) {
564+
G::consider_builtin_iterator_candidate(self, goal)
555565
} else if lang_items.gen_trait() == Some(trait_def_id) {
556566
G::consider_builtin_coroutine_candidate(self, goal)
557567
} else if lang_items.discriminant_kind_trait() == Some(trait_def_id) {

compiler/rustc_trait_selection/src/solve/project_goals/mod.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,37 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
485485
)
486486
}
487487

488+
fn consider_builtin_iterator_candidate(
489+
ecx: &mut EvalCtxt<'_, 'tcx>,
490+
goal: Goal<'tcx, Self>,
491+
) -> QueryResult<'tcx> {
492+
let self_ty = goal.predicate.self_ty();
493+
let ty::Coroutine(def_id, args, _) = *self_ty.kind() else {
494+
return Err(NoSolution);
495+
};
496+
497+
// Coroutines are not Iterators unless they come from `gen` desugaring
498+
let tcx = ecx.tcx();
499+
if !tcx.coroutine_is_gen(def_id) {
500+
return Err(NoSolution);
501+
}
502+
503+
let term = args.as_coroutine().yield_ty().into();
504+
505+
Self::consider_implied_clause(
506+
ecx,
507+
goal,
508+
ty::ProjectionPredicate {
509+
projection_ty: ty::AliasTy::new(ecx.tcx(), goal.predicate.def_id(), [self_ty]),
510+
term,
511+
}
512+
.to_predicate(tcx),
513+
// Technically, we need to check that the iterator type is Sized,
514+
// but that's already proven by the generator being WF.
515+
[],
516+
)
517+
}
518+
488519
fn consider_builtin_coroutine_candidate(
489520
ecx: &mut EvalCtxt<'_, 'tcx>,
490521
goal: Goal<'tcx, Self>,
@@ -496,7 +527,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
496527

497528
// `async`-desugared coroutines do not implement the coroutine trait
498529
let tcx = ecx.tcx();
499-
if tcx.coroutine_is_async(def_id) {
530+
if !tcx.is_general_coroutine(def_id) {
500531
return Err(NoSolution);
501532
}
502533

@@ -523,7 +554,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
523554
term,
524555
}
525556
.to_predicate(tcx),
526-
// Technically, we need to check that the future type is Sized,
557+
// Technically, we need to check that the coroutine type is Sized,
527558
// but that's already proven by the coroutine being WF.
528559
[],
529560
)

compiler/rustc_trait_selection/src/solve/trait_goals.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,30 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
335335
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
336336
}
337337

338+
fn consider_builtin_iterator_candidate(
339+
ecx: &mut EvalCtxt<'_, 'tcx>,
340+
goal: Goal<'tcx, Self>,
341+
) -> QueryResult<'tcx> {
342+
if goal.predicate.polarity != ty::ImplPolarity::Positive {
343+
return Err(NoSolution);
344+
}
345+
346+
let ty::Coroutine(def_id, _, _) = *goal.predicate.self_ty().kind() else {
347+
return Err(NoSolution);
348+
};
349+
350+
// Coroutines are not iterators unless they come from `gen` desugaring
351+
let tcx = ecx.tcx();
352+
if !tcx.coroutine_is_gen(def_id) {
353+
return Err(NoSolution);
354+
}
355+
356+
// Gen coroutines unconditionally implement `Iterator`
357+
// Technically, we need to check that the iterator output type is Sized,
358+
// but that's already proven by the coroutines being WF.
359+
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
360+
}
361+
338362
fn consider_builtin_coroutine_candidate(
339363
ecx: &mut EvalCtxt<'_, 'tcx>,
340364
goal: Goal<'tcx, Self>,
@@ -350,7 +374,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
350374

351375
// `async`-desugared coroutines do not implement the coroutine trait
352376
let tcx = ecx.tcx();
353-
if tcx.coroutine_is_async(def_id) {
377+
if !tcx.is_general_coroutine(def_id) {
354378
return Err(NoSolution);
355379
}
356380

compiler/rustc_trait_selection/src/traits/project.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1798,7 +1798,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
17981798
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
17991799

18001800
let lang_items = selcx.tcx().lang_items();
1801-
if [lang_items.gen_trait(), lang_items.future_trait()].contains(&Some(trait_ref.def_id))
1801+
if [lang_items.gen_trait(), lang_items.future_trait(), lang_items.iterator_trait()].contains(&Some(trait_ref.def_id))
18021802
|| selcx.tcx().fn_trait_kind_from_def_id(trait_ref.def_id).is_some()
18031803
{
18041804
true
@@ -2015,6 +2015,8 @@ fn confirm_select_candidate<'cx, 'tcx>(
20152015
confirm_coroutine_candidate(selcx, obligation, data)
20162016
} else if lang_items.future_trait() == Some(trait_def_id) {
20172017
confirm_future_candidate(selcx, obligation, data)
2018+
} else if lang_items.iterator_trait() == Some(trait_def_id) {
2019+
confirm_iterator_candidate(selcx, obligation, data)
20182020
} else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
20192021
if obligation.predicate.self_ty().is_closure() {
20202022
confirm_closure_candidate(selcx, obligation, data)
@@ -2135,6 +2137,50 @@ fn confirm_future_candidate<'cx, 'tcx>(
21352137
.with_addl_obligations(obligations)
21362138
}
21372139

2140+
fn confirm_iterator_candidate<'cx, 'tcx>(
2141+
selcx: &mut SelectionContext<'cx, 'tcx>,
2142+
obligation: &ProjectionTyObligation<'tcx>,
2143+
nested: Vec<PredicateObligation<'tcx>>,
2144+
) -> Progress<'tcx> {
2145+
let ty::Coroutine(_, args, _) =
2146+
selcx.infcx.shallow_resolve(obligation.predicate.self_ty()).kind()
2147+
else {
2148+
unreachable!()
2149+
};
2150+
let gen_sig = args.as_coroutine().poly_sig();
2151+
let Normalized { value: gen_sig, obligations } = normalize_with_depth(
2152+
selcx,
2153+
obligation.param_env,
2154+
obligation.cause.clone(),
2155+
obligation.recursion_depth + 1,
2156+
gen_sig,
2157+
);
2158+
2159+
debug!(?obligation, ?gen_sig, ?obligations, "confirm_future_candidate");
2160+
2161+
let tcx = selcx.tcx();
2162+
let iter_def_id = tcx.require_lang_item(LangItem::Iterator, None);
2163+
2164+
let predicate = super::util::iterator_trait_ref_and_outputs(
2165+
tcx,
2166+
iter_def_id,
2167+
obligation.predicate.self_ty(),
2168+
gen_sig,
2169+
)
2170+
.map_bound(|(trait_ref, yield_ty)| {
2171+
debug_assert_eq!(tcx.associated_item(obligation.predicate.def_id).name, sym::Item);
2172+
2173+
ty::ProjectionPredicate {
2174+
projection_ty: ty::AliasTy::new(tcx, obligation.predicate.def_id, trait_ref.args),
2175+
term: yield_ty.into(),
2176+
}
2177+
});
2178+
2179+
confirm_param_env_candidate(selcx, obligation, predicate, false)
2180+
.with_addl_obligations(nested)
2181+
.with_addl_obligations(obligations)
2182+
}
2183+
21382184
fn confirm_builtin_candidate<'cx, 'tcx>(
21392185
selcx: &mut SelectionContext<'cx, 'tcx>,
21402186
obligation: &ProjectionTyObligation<'tcx>,

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
114114
self.assemble_coroutine_candidates(obligation, &mut candidates);
115115
} else if lang_items.future_trait() == Some(def_id) {
116116
self.assemble_future_candidates(obligation, &mut candidates);
117+
} else if lang_items.iterator_trait() == Some(def_id) {
118+
self.assemble_iterator_candidates(obligation, &mut candidates);
117119
}
118120

119121
self.assemble_closure_candidates(obligation, &mut candidates);
@@ -211,9 +213,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
211213
// type/region parameters.
212214
let self_ty = obligation.self_ty().skip_binder();
213215
match self_ty.kind() {
214-
// async constructs get lowered to a special kind of coroutine that
216+
// `async`/`gen` constructs get lowered to a special kind of coroutine that
215217
// should *not* `impl Coroutine`.
216-
ty::Coroutine(did, ..) if !self.tcx().coroutine_is_async(*did) => {
218+
ty::Coroutine(did, ..) if self.tcx().is_general_coroutine(*did) => {
217219
debug!(?self_ty, ?obligation, "assemble_coroutine_candidates",);
218220

219221
candidates.vec.push(CoroutineCandidate);
@@ -243,6 +245,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
243245
}
244246
}
245247

248+
fn assemble_iterator_candidates(
249+
&mut self,
250+
obligation: &PolyTraitObligation<'tcx>,
251+
candidates: &mut SelectionCandidateSet<'tcx>,
252+
) {
253+
let self_ty = obligation.self_ty().skip_binder();
254+
if let ty::Coroutine(did, ..) = self_ty.kind() {
255+
// gen constructs get lowered to a special kind of coroutine that
256+
// should directly `impl Iterator`.
257+
if self.tcx().coroutine_is_gen(*did) {
258+
debug!(?self_ty, ?obligation, "assemble_iterator_candidates",);
259+
260+
candidates.vec.push(IteratorCandidate);
261+
}
262+
}
263+
}
264+
246265
/// Checks for the artificial impl that the compiler will create for an obligation like `X :
247266
/// FnMut<..>` where `X` is a closure type.
248267
///

compiler/rustc_trait_selection/src/traits/select/confirmation.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
9393
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_future)
9494
}
9595

96+
IteratorCandidate => {
97+
let vtable_iterator = self.confirm_iterator_candidate(obligation)?;
98+
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
99+
}
100+
96101
FnPointerCandidate { is_const } => {
97102
let data = self.confirm_fn_pointer_candidate(obligation, is_const)?;
98103
ImplSource::Builtin(BuiltinImplSource::Misc, data)
@@ -780,6 +785,36 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
780785
Ok(nested)
781786
}
782787

788+
fn confirm_iterator_candidate(
789+
&mut self,
790+
obligation: &PolyTraitObligation<'tcx>,
791+
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
792+
// Okay to skip binder because the args on coroutine types never
793+
// touch bound regions, they just capture the in-scope
794+
// type/region parameters.
795+
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
796+
let ty::Coroutine(coroutine_def_id, args, _) = *self_ty.kind() else {
797+
bug!("closure candidate for non-closure {:?}", obligation);
798+
};
799+
800+
debug!(?obligation, ?coroutine_def_id, ?args, "confirm_iterator_candidate");
801+
802+
let gen_sig = args.as_coroutine().poly_sig();
803+
804+
let trait_ref = super::util::iterator_trait_ref_and_outputs(
805+
self.tcx(),
806+
obligation.predicate.def_id(),
807+
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
808+
gen_sig,
809+
)
810+
.map_bound(|(trait_ref, ..)| trait_ref);
811+
812+
let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
813+
debug!(?trait_ref, ?nested, "iterator candidate obligations");
814+
815+
Ok(nested)
816+
}
817+
783818
#[instrument(skip(self), level = "debug")]
784819
fn confirm_closure_candidate(
785820
&mut self,

compiler/rustc_trait_selection/src/traits/select/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
18881888
| ClosureCandidate { .. }
18891889
| CoroutineCandidate
18901890
| FutureCandidate
1891+
| IteratorCandidate
18911892
| FnPointerCandidate { .. }
18921893
| BuiltinObjectCandidate
18931894
| BuiltinUnsizeCandidate
@@ -1916,6 +1917,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
19161917
| ClosureCandidate { .. }
19171918
| CoroutineCandidate
19181919
| FutureCandidate
1920+
| IteratorCandidate
19191921
| FnPointerCandidate { .. }
19201922
| BuiltinObjectCandidate
19211923
| BuiltinUnsizeCandidate
@@ -1950,6 +1952,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
19501952
| ClosureCandidate { .. }
19511953
| CoroutineCandidate
19521954
| FutureCandidate
1955+
| IteratorCandidate
19531956
| FnPointerCandidate { .. }
19541957
| BuiltinObjectCandidate
19551958
| BuiltinUnsizeCandidate
@@ -1964,6 +1967,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
19641967
| ClosureCandidate { .. }
19651968
| CoroutineCandidate
19661969
| FutureCandidate
1970+
| IteratorCandidate
19671971
| FnPointerCandidate { .. }
19681972
| BuiltinObjectCandidate
19691973
| BuiltinUnsizeCandidate
@@ -2070,6 +2074,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
20702074
| ClosureCandidate { .. }
20712075
| CoroutineCandidate
20722076
| FutureCandidate
2077+
| IteratorCandidate
20732078
| FnPointerCandidate { .. }
20742079
| BuiltinObjectCandidate
20752080
| BuiltinUnsizeCandidate
@@ -2080,6 +2085,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
20802085
| ClosureCandidate { .. }
20812086
| CoroutineCandidate
20822087
| FutureCandidate
2088+
| IteratorCandidate
20832089
| FnPointerCandidate { .. }
20842090
| BuiltinObjectCandidate
20852091
| BuiltinUnsizeCandidate

compiler/rustc_trait_selection/src/traits/util.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,17 @@ pub fn future_trait_ref_and_outputs<'tcx>(
297297
sig.map_bound(|sig| (trait_ref, sig.return_ty))
298298
}
299299

300+
pub fn iterator_trait_ref_and_outputs<'tcx>(
301+
tcx: TyCtxt<'tcx>,
302+
iterator_def_id: DefId,
303+
self_ty: Ty<'tcx>,
304+
sig: ty::PolyGenSig<'tcx>,
305+
) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> {
306+
assert!(!self_ty.has_escaping_bound_vars());
307+
let trait_ref = ty::TraitRef::new(tcx, iterator_def_id, [self_ty]);
308+
sig.map_bound(|sig| (trait_ref, sig.yield_ty))
309+
}
310+
300311
pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool {
301312
assoc_item.defaultness(tcx).is_final()
302313
&& tcx.defaultness(assoc_item.container_id(tcx)).is_final()

0 commit comments

Comments
 (0)