Skip to content

Commit b24a957

Browse files
committed
Lower return types for gen fn to impl Iterator
1 parent 1f845ac commit b24a957

File tree

7 files changed

+167
-80
lines changed

7 files changed

+167
-80
lines changed

compiler/rustc_ast_lowering/src/item.rs

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::FnReturnTransformation;
2+
13
use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound};
24
use super::ResolverAstLoweringExt;
35
use super::{AstOwner, ImplTraitContext, ImplTraitPosition};
@@ -276,13 +278,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
276278
// only cares about the input argument patterns in the function
277279
// declaration (decl), not the return types.
278280
let asyncness = header.asyncness;
279-
let body_id =
280-
this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref());
281+
let genness = header.genness;
282+
let body_id = this.lower_maybe_coroutine_body(
283+
span,
284+
hir_id,
285+
decl,
286+
asyncness,
287+
genness,
288+
body.as_deref(),
289+
);
281290

282291
let itctx = ImplTraitContext::Universal;
283292
let (generics, decl) =
284293
this.lower_generics(generics, header.constness, id, &itctx, |this| {
285-
let ret_id = asyncness.opt_return_id();
294+
let ret_id = asyncness
295+
.opt_return_id()
296+
.map(|(node_id, span)| {
297+
crate::FnReturnTransformation::Async(node_id, span)
298+
})
299+
.or_else(|| match genness {
300+
Gen::Yes { span, closure_id: _, return_impl_trait_id } => {
301+
Some(crate::FnReturnTransformation::Iterator(
302+
return_impl_trait_id,
303+
span,
304+
))
305+
}
306+
_ => None,
307+
});
286308
this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id)
287309
});
288310
let sig = hir::FnSig {
@@ -765,20 +787,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
765787
sig,
766788
i.id,
767789
FnDeclKind::Trait,
768-
asyncness.opt_return_id(),
790+
asyncness
791+
.opt_return_id()
792+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
769793
);
770794
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false)
771795
}
772796
AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => {
773797
let asyncness = sig.header.asyncness;
774-
let body_id =
775-
self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body));
798+
let genness = sig.header.genness;
799+
let body_id = self.lower_maybe_coroutine_body(
800+
i.span,
801+
hir_id,
802+
&sig.decl,
803+
asyncness,
804+
genness,
805+
Some(body),
806+
);
776807
let (generics, sig) = self.lower_method_sig(
777808
generics,
778809
sig,
779810
i.id,
780811
FnDeclKind::Trait,
781-
asyncness.opt_return_id(),
812+
asyncness
813+
.opt_return_id()
814+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
782815
);
783816
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true)
784817
}
@@ -869,19 +902,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
869902
AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => {
870903
self.current_item = Some(i.span);
871904
let asyncness = sig.header.asyncness;
872-
let body_id = self.lower_maybe_async_body(
905+
let genness = sig.header.genness;
906+
let body_id = self.lower_maybe_coroutine_body(
873907
i.span,
874908
hir_id,
875909
&sig.decl,
876910
asyncness,
911+
genness,
877912
body.as_deref(),
878913
);
879914
let (generics, sig) = self.lower_method_sig(
880915
generics,
881916
sig,
882917
i.id,
883918
if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent },
884-
asyncness.opt_return_id(),
919+
asyncness
920+
.opt_return_id()
921+
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
885922
);
886923

887924
(generics, hir::ImplItemKind::Fn(sig, body_id))
@@ -1045,16 +1082,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
10451082
})
10461083
}
10471084

1048-
fn lower_maybe_async_body(
1085+
/// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or
1086+
/// `gen {}` block as appropriate.
1087+
fn lower_maybe_coroutine_body(
10491088
&mut self,
10501089
span: Span,
10511090
fn_id: hir::HirId,
10521091
decl: &FnDecl,
10531092
asyncness: Async,
1093+
genness: Gen,
10541094
body: Option<&Block>,
10551095
) -> hir::BodyId {
1056-
let (closure_id, body) = match (asyncness, body) {
1057-
(Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
1096+
let (closure_id, body) = match (asyncness, genness, body) {
1097+
// FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error
1098+
// for now since it's not supported.
1099+
(Async::Yes { closure_id, .. }, _, Some(body))
1100+
| (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
10581101
_ => return self.lower_fn_body_block(span, decl, body),
10591102
};
10601103

@@ -1197,44 +1240,55 @@ impl<'hir> LoweringContext<'_, 'hir> {
11971240
parameters.push(new_parameter);
11981241
}
11991242

1200-
let async_expr = this.make_async_expr(
1201-
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1202-
closure_id,
1203-
None,
1204-
body.span,
1205-
hir::CoroutineSource::Fn,
1206-
|this| {
1207-
// Create a block from the user's function body:
1208-
let user_body = this.lower_block_expr(body);
1243+
let mkbody = |this: &mut LoweringContext<'_, 'hir>| {
1244+
// Create a block from the user's function body:
1245+
let user_body = this.lower_block_expr(body);
12091246

1210-
// Transform into `drop-temps { <user-body> }`, an expression:
1211-
let desugared_span =
1212-
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1213-
let user_body =
1214-
this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
1247+
// Transform into `drop-temps { <user-body> }`, an expression:
1248+
let desugared_span =
1249+
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
1250+
let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
12151251

1216-
// As noted above, create the final block like
1217-
//
1218-
// ```
1219-
// {
1220-
// let $param_pattern = $raw_param;
1221-
// ...
1222-
// drop-temps { <user-body> }
1223-
// }
1224-
// ```
1225-
let body = this.block_all(
1226-
desugared_span,
1227-
this.arena.alloc_from_iter(statements),
1228-
Some(user_body),
1229-
);
1252+
// As noted above, create the final block like
1253+
//
1254+
// ```
1255+
// {
1256+
// let $param_pattern = $raw_param;
1257+
// ...
1258+
// drop-temps { <user-body> }
1259+
// }
1260+
// ```
1261+
let body = this.block_all(
1262+
desugared_span,
1263+
this.arena.alloc_from_iter(statements),
1264+
Some(user_body),
1265+
);
12301266

1231-
this.expr_block(body)
1232-
},
1233-
);
1267+
this.expr_block(body)
1268+
};
1269+
let coroutine_expr = match (asyncness, genness) {
1270+
(Async::Yes { .. }, _) => this.make_async_expr(
1271+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1272+
closure_id,
1273+
None,
1274+
body.span,
1275+
hir::CoroutineSource::Fn,
1276+
mkbody,
1277+
),
1278+
(_, Gen::Yes { .. }) => this.make_gen_expr(
1279+
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
1280+
closure_id,
1281+
None,
1282+
body.span,
1283+
hir::CoroutineSource::Fn,
1284+
mkbody,
1285+
),
1286+
_ => unreachable!("we must have either an async fn or a gen fn"),
1287+
};
12341288

12351289
let hir_id = this.lower_node_id(closure_id);
12361290
this.maybe_forward_track_caller(body.span, fn_id, hir_id);
1237-
let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) };
1291+
let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) };
12381292

12391293
(this.arena.alloc_from_iter(parameters), expr)
12401294
})
@@ -1246,13 +1300,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
12461300
sig: &FnSig,
12471301
id: NodeId,
12481302
kind: FnDeclKind,
1249-
is_async: Option<(NodeId, Span)>,
1303+
transform_return_type: Option<FnReturnTransformation>,
12501304
) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
12511305
let header = self.lower_fn_header(sig.header);
12521306
let itctx = ImplTraitContext::Universal;
12531307
let (generics, decl) =
12541308
self.lower_generics(generics, sig.header.constness, id, &itctx, |this| {
1255-
this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async)
1309+
this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type)
12561310
});
12571311
(generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
12581312
}

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,21 @@ enum ParenthesizedGenericArgs {
455455
Err,
456456
}
457457

458+
/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl`
459+
#[derive(Debug)]
460+
enum FnReturnTransformation {
461+
/// Replaces a return type `T` with `impl Future<Output = T>`.
462+
///
463+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
464+
/// `async` keyword.
465+
Async(NodeId, Span),
466+
/// Replaces a return type `T` with `impl Iterator<Item = T>`.
467+
///
468+
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
469+
/// `gen` keyword.
470+
Iterator(NodeId, Span),
471+
}
472+
458473
impl<'a, 'hir> LoweringContext<'a, 'hir> {
459474
fn create_def(
460475
&mut self,
@@ -1739,21 +1754,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
17391754
}))
17401755
}
17411756

1742-
// Lowers a function declaration.
1743-
//
1744-
// `decl`: the unlowered (AST) function declaration.
1745-
// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`.
1746-
// `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future<Output = T>` in the
1747-
// return type. This is used for `async fn` declarations. The `NodeId` is the ID of the
1748-
// return type `impl Trait` item, and the `Span` points to the `async` keyword.
1757+
/// Lowers a function declaration.
1758+
///
1759+
/// `decl`: the unlowered (AST) function declaration.
1760+
///
1761+
/// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given
1762+
/// `NodeId`.
1763+
///
1764+
/// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is
1765+
/// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details.
17491766
#[instrument(level = "debug", skip(self))]
17501767
fn lower_fn_decl(
17511768
&mut self,
17521769
decl: &FnDecl,
17531770
fn_node_id: NodeId,
17541771
fn_span: Span,
17551772
kind: FnDeclKind,
1756-
make_ret_async: Option<(NodeId, Span)>,
1773+
transform_return_type: Option<FnReturnTransformation>,
17571774
) -> &'hir hir::FnDecl<'hir> {
17581775
let c_variadic = decl.c_variadic();
17591776

@@ -1782,11 +1799,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
17821799
self.lower_ty_direct(&param.ty, &itctx)
17831800
}));
17841801

1785-
let output = if let Some((ret_id, _span)) = make_ret_async {
1786-
let fn_def_id = self.local_def_id(fn_node_id);
1787-
self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span)
1788-
} else {
1789-
match &decl.output {
1802+
let output = match transform_return_type {
1803+
Some(transform) => {
1804+
let fn_def_id = self.local_def_id(fn_node_id);
1805+
self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span)
1806+
}
1807+
None => match &decl.output {
17901808
FnRetTy::Ty(ty) => {
17911809
let context = if kind.return_impl_trait_allowed() {
17921810
let fn_def_id = self.local_def_id(fn_node_id);
@@ -1810,7 +1828,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18101828
hir::FnRetTy::Return(self.lower_ty(ty, &context))
18111829
}
18121830
FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)),
1813-
}
1831+
},
18141832
};
18151833

18161834
self.arena.alloc(hir::FnDecl {
@@ -1849,17 +1867,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18491867
// `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition)
18501868
// `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created
18511869
#[instrument(level = "debug", skip(self))]
1852-
fn lower_async_fn_ret_ty(
1870+
fn lower_coroutine_fn_ret_ty(
18531871
&mut self,
18541872
output: &FnRetTy,
18551873
fn_def_id: LocalDefId,
1856-
opaque_ty_node_id: NodeId,
1874+
transform: FnReturnTransformation,
18571875
fn_kind: FnDeclKind,
18581876
fn_span: Span,
18591877
) -> hir::FnRetTy<'hir> {
18601878
let span = self.lower_span(fn_span);
18611879
let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
18621880

1881+
let opaque_ty_node_id = match transform {
1882+
FnReturnTransformation::Async(opaque_ty_node_id, _)
1883+
| FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id,
1884+
};
1885+
18631886
let captured_lifetimes: Vec<_> = self
18641887
.resolver
18651888
.take_extra_lifetime_params(opaque_ty_node_id)
@@ -1875,8 +1898,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18751898
span,
18761899
opaque_ty_span,
18771900
|this| {
1878-
let future_bound = this.lower_async_fn_output_type_to_future_bound(
1901+
let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
18791902
output,
1903+
transform,
18801904
span,
18811905
ImplTraitContext::ReturnPositionOpaqueTy {
18821906
origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
@@ -1892,9 +1916,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
18921916
}
18931917

18941918
/// Transforms `-> T` into `Future<Output = T>`.
1895-
fn lower_async_fn_output_type_to_future_bound(
1919+
fn lower_coroutine_fn_output_type_to_future_bound(
18961920
&mut self,
18971921
output: &FnRetTy,
1922+
transform: FnReturnTransformation,
18981923
span: Span,
18991924
nested_impl_trait_context: ImplTraitContext,
19001925
) -> hir::GenericBound<'hir> {
@@ -1909,17 +1934,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
19091934
FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
19101935
};
19111936

1912-
// "<Output = T>"
1937+
// "<Output|Item = T>"
1938+
let (symbol, lang_item) = match transform {
1939+
FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
1940+
FnReturnTransformation::Iterator(..) => {
1941+
(hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator)
1942+
}
1943+
};
1944+
19131945
let future_args = self.arena.alloc(hir::GenericArgs {
19141946
args: &[],
1915-
bindings: arena_vec![self; self.output_ty_binding(span, output_ty)],
1947+
bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
19161948
parenthesized: hir::GenericArgsParentheses::No,
19171949
span_ext: DUMMY_SP,
19181950
});
19191951

19201952
hir::GenericBound::LangItemTrait(
1921-
// ::std::future::Future<future_params>
1922-
hir::LangItem::Future,
1953+
lang_item,
19231954
self.lower_span(span),
19241955
self.next_id(),
19251956
future_args,

0 commit comments

Comments
 (0)