Skip to content

Commit 990bb59

Browse files
Rework coroutine transform to be more flexible in preparation for async generators
1 parent 5facb42 commit 990bb59

File tree

1 file changed

+123
-80
lines changed

1 file changed

+123
-80
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 123 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ use rustc_index::{Idx, IndexVec};
6767
use rustc_middle::mir::dump_mir;
6868
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
6969
use rustc_middle::mir::*;
70+
use rustc_middle::ty::CoroutineArgs;
7071
use rustc_middle::ty::InstanceDef;
71-
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
72-
use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
72+
use rustc_middle::ty::{self, Ty, TyCtxt};
7373
use rustc_mir_dataflow::impls::{
7474
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
7575
};
@@ -226,8 +226,6 @@ struct SuspensionPoint<'tcx> {
226226
struct TransformVisitor<'tcx> {
227227
tcx: TyCtxt<'tcx>,
228228
coroutine_kind: hir::CoroutineKind,
229-
state_adt_ref: AdtDef<'tcx>,
230-
state_args: GenericArgsRef<'tcx>,
231229

232230
// The type of the discriminant in the coroutine struct
233231
discr_ty: Ty<'tcx>,
@@ -246,21 +244,34 @@ struct TransformVisitor<'tcx> {
246244
always_live_locals: BitSet<Local>,
247245

248246
// The original RETURN_PLACE local
249-
new_ret_local: Local,
247+
old_ret_local: Local,
248+
249+
old_yield_ty: Ty<'tcx>,
250+
251+
old_ret_ty: Ty<'tcx>,
250252
}
251253

252254
impl<'tcx> TransformVisitor<'tcx> {
253255
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
254-
let block = BasicBlock::new(body.basic_blocks.len());
256+
assert!(matches!(self.coroutine_kind, CoroutineKind::Gen(_)));
255257

258+
let block = BasicBlock::new(body.basic_blocks.len());
256259
let source_info = SourceInfo::outermost(body.span);
260+
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
257261

258-
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
259-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
260262
let statements = vec![Statement {
261263
kind: StatementKind::Assign(Box::new((
262264
Place::return_place(),
263-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
265+
Rvalue::Aggregate(
266+
Box::new(AggregateKind::Adt(
267+
option_def_id,
268+
VariantIdx::from_usize(0),
269+
self.tcx.mk_args(&[self.old_yield_ty.into()]),
270+
None,
271+
None,
272+
)),
273+
IndexVec::new(),
274+
),
264275
))),
265276
source_info,
266277
}];
@@ -274,23 +285,6 @@ impl<'tcx> TransformVisitor<'tcx> {
274285
block
275286
}
276287

277-
fn coroutine_state_adt_and_variant_idx(
278-
&self,
279-
is_return: bool,
280-
) -> (AggregateKind<'tcx>, VariantIdx) {
281-
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
282-
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
283-
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
284-
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
285-
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
286-
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
287-
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
288-
});
289-
290-
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
291-
(kind, idx)
292-
}
293-
294288
// Make a `CoroutineState` or `Poll` variant assignment.
295289
//
296290
// `core::ops::CoroutineState` only has single element tuple variants,
@@ -303,51 +297,99 @@ impl<'tcx> TransformVisitor<'tcx> {
303297
is_return: bool,
304298
statements: &mut Vec<Statement<'tcx>>,
305299
) {
306-
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
307-
308-
match self.coroutine_kind {
309-
// `Poll::Pending`
300+
let rvalue = match self.coroutine_kind {
310301
CoroutineKind::Async(_) => {
311-
if !is_return {
312-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
313-
314-
// FIXME(swatinem): assert that `val` is indeed unit?
315-
statements.push(Statement {
316-
kind: StatementKind::Assign(Box::new((
317-
Place::return_place(),
318-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
319-
))),
320-
source_info,
321-
});
322-
return;
302+
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
303+
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
304+
if is_return {
305+
// Poll::Ready(val)
306+
Rvalue::Aggregate(
307+
Box::new(AggregateKind::Adt(
308+
poll_def_id,
309+
VariantIdx::from_usize(0),
310+
args,
311+
None,
312+
None,
313+
)),
314+
IndexVec::from_raw(vec![val]),
315+
)
316+
} else {
317+
// Poll::Pending
318+
Rvalue::Aggregate(
319+
Box::new(AggregateKind::Adt(
320+
poll_def_id,
321+
VariantIdx::from_usize(1),
322+
args,
323+
None,
324+
None,
325+
)),
326+
IndexVec::new(),
327+
)
323328
}
324329
}
325-
// `Option::None`
326330
CoroutineKind::Gen(_) => {
331+
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
332+
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
327333
if is_return {
328-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
329-
330-
statements.push(Statement {
331-
kind: StatementKind::Assign(Box::new((
332-
Place::return_place(),
333-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
334-
))),
335-
source_info,
336-
});
337-
return;
334+
// None
335+
Rvalue::Aggregate(
336+
Box::new(AggregateKind::Adt(
337+
option_def_id,
338+
VariantIdx::from_usize(0),
339+
args,
340+
None,
341+
None,
342+
)),
343+
IndexVec::new(),
344+
)
345+
} else {
346+
// Some(val)
347+
Rvalue::Aggregate(
348+
Box::new(AggregateKind::Adt(
349+
option_def_id,
350+
VariantIdx::from_usize(1),
351+
args,
352+
None,
353+
None,
354+
)),
355+
IndexVec::from_raw(vec![val]),
356+
)
338357
}
339358
}
340-
CoroutineKind::Coroutine => {}
341-
}
342-
343-
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
344-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
359+
CoroutineKind::Coroutine => {
360+
let coroutine_state_def_id =
361+
self.tcx.require_lang_item(LangItem::CoroutineState, None);
362+
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
363+
if is_return {
364+
// CoroutineState::Complete(val)
365+
Rvalue::Aggregate(
366+
Box::new(AggregateKind::Adt(
367+
coroutine_state_def_id,
368+
VariantIdx::from_usize(1),
369+
args,
370+
None,
371+
None,
372+
)),
373+
IndexVec::from_raw(vec![val]),
374+
)
375+
} else {
376+
// CoroutineState::Yielded(val)
377+
Rvalue::Aggregate(
378+
Box::new(AggregateKind::Adt(
379+
coroutine_state_def_id,
380+
VariantIdx::from_usize(0),
381+
args,
382+
None,
383+
None,
384+
)),
385+
IndexVec::from_raw(vec![val]),
386+
)
387+
}
388+
}
389+
};
345390

346391
statements.push(Statement {
347-
kind: StatementKind::Assign(Box::new((
348-
Place::return_place(),
349-
Rvalue::Aggregate(Box::new(kind), [val].into()),
350-
))),
392+
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
351393
source_info,
352394
});
353395
}
@@ -421,7 +463,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
421463

422464
let ret_val = match data.terminator().kind {
423465
TerminatorKind::Return => {
424-
Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
466+
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
425467
}
426468
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
427469
Some((false, Some((resume, resume_arg)), value.clone(), drop))
@@ -1503,10 +1545,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
15031545

15041546
impl<'tcx> MirPass<'tcx> for StateTransform {
15051547
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1506-
let Some(yield_ty) = body.yield_ty() else {
1548+
let Some(old_yield_ty) = body.yield_ty() else {
15071549
// This only applies to coroutines
15081550
return;
15091551
};
1552+
let old_ret_ty = body.return_ty();
15101553

15111554
assert!(body.coroutine_drop().is_none());
15121555

@@ -1528,34 +1571,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15281571

15291572
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
15301573
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
1531-
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
1574+
let new_ret_ty = match body.coroutine_kind().unwrap() {
15321575
CoroutineKind::Async(_) => {
15331576
// Compute Poll<return_ty>
15341577
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
15351578
let poll_adt_ref = tcx.adt_def(poll_did);
1536-
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
1537-
(poll_adt_ref, poll_args)
1579+
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
1580+
Ty::new_adt(tcx, poll_adt_ref, poll_args)
15381581
}
15391582
CoroutineKind::Gen(_) => {
15401583
// Compute Option<yield_ty>
15411584
let option_did = tcx.require_lang_item(LangItem::Option, None);
15421585
let option_adt_ref = tcx.adt_def(option_did);
1543-
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
1544-
(option_adt_ref, option_args)
1586+
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
1587+
Ty::new_adt(tcx, option_adt_ref, option_args)
15451588
}
15461589
CoroutineKind::Coroutine => {
15471590
// Compute CoroutineState<yield_ty, return_ty>
15481591
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
15491592
let state_adt_ref = tcx.adt_def(state_did);
1550-
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
1551-
(state_adt_ref, state_args)
1593+
let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
1594+
Ty::new_adt(tcx, state_adt_ref, state_args)
15521595
}
15531596
};
1554-
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
15551597

1556-
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1598+
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
15571599
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1558-
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
1600+
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
15591601

15601602
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
15611603
if is_async_kind {
@@ -1572,17 +1614,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15721614
} else {
15731615
body.local_decls[resume_local].ty
15741616
};
1575-
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
1617+
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
15761618

1577-
// When first entering the coroutine, move the resume argument into its new local.
1619+
// When first entering the coroutine, move the resume argument into its old local
1620+
// (which is now a generator interior).
15781621
let source_info = SourceInfo::outermost(body.span);
15791622
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
15801623
stmts.insert(
15811624
0,
15821625
Statement {
15831626
source_info,
15841627
kind: StatementKind::Assign(Box::new((
1585-
new_resume_local.into(),
1628+
old_resume_local.into(),
15861629
Rvalue::Use(Operand::Move(resume_local.into())),
15871630
))),
15881631
},
@@ -1618,14 +1661,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16181661
let mut transform = TransformVisitor {
16191662
tcx,
16201663
coroutine_kind: body.coroutine_kind().unwrap(),
1621-
state_adt_ref,
1622-
state_args,
16231664
remap,
16241665
storage_liveness,
16251666
always_live_locals,
16261667
suspension_points: Vec::new(),
1627-
new_ret_local,
1668+
old_ret_local,
16281669
discr_ty,
1670+
old_ret_ty,
1671+
old_yield_ty,
16291672
};
16301673
transform.visit_body(body);
16311674

0 commit comments

Comments
 (0)