@@ -67,9 +67,9 @@ use rustc_index::{Idx, IndexVec};
67
67
use rustc_middle:: mir:: dump_mir;
68
68
use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
69
69
use rustc_middle:: mir:: * ;
70
+ use rustc_middle:: ty:: CoroutineArgs ;
70
71
use rustc_middle:: ty:: InstanceDef ;
71
- use rustc_middle:: ty:: { self , AdtDef , Ty , TyCtxt } ;
72
- use rustc_middle:: ty:: { CoroutineArgs , GenericArgsRef } ;
72
+ use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
73
73
use rustc_mir_dataflow:: impls:: {
74
74
MaybeBorrowedLocals , MaybeLiveLocals , MaybeRequiresStorage , MaybeStorageLive ,
75
75
} ;
@@ -226,8 +226,6 @@ struct SuspensionPoint<'tcx> {
226
226
struct TransformVisitor < ' tcx > {
227
227
tcx : TyCtxt < ' tcx > ,
228
228
coroutine_kind : hir:: CoroutineKind ,
229
- state_adt_ref : AdtDef < ' tcx > ,
230
- state_args : GenericArgsRef < ' tcx > ,
231
229
232
230
// The type of the discriminant in the coroutine struct
233
231
discr_ty : Ty < ' tcx > ,
@@ -246,21 +244,34 @@ struct TransformVisitor<'tcx> {
246
244
always_live_locals : BitSet < Local > ,
247
245
248
246
// The original RETURN_PLACE local
249
- new_ret_local : Local ,
247
+ old_ret_local : Local ,
248
+
249
+ old_yield_ty : Ty < ' tcx > ,
250
+
251
+ old_ret_ty : Ty < ' tcx > ,
250
252
}
251
253
252
254
impl < ' tcx > TransformVisitor < ' tcx > {
253
255
fn insert_none_ret_block ( & self , body : & mut Body < ' tcx > ) -> BasicBlock {
254
- let block = BasicBlock :: new ( body . basic_blocks . len ( ) ) ;
256
+ assert ! ( matches! ( self . coroutine_kind , CoroutineKind :: Gen ( _ ) ) ) ;
255
257
258
+ let block = BasicBlock :: new ( body. basic_blocks . len ( ) ) ;
256
259
let source_info = SourceInfo :: outermost ( body. span ) ;
260
+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
257
261
258
- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( true ) ;
259
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
260
262
let statements = vec ! [ Statement {
261
263
kind: StatementKind :: Assign ( Box :: new( (
262
264
Place :: return_place( ) ,
263
- Rvalue :: Aggregate ( Box :: new( kind) , IndexVec :: new( ) ) ,
265
+ Rvalue :: Aggregate (
266
+ Box :: new( AggregateKind :: Adt (
267
+ option_def_id,
268
+ VariantIdx :: from_usize( 0 ) ,
269
+ self . tcx. mk_args( & [ self . old_yield_ty. into( ) ] ) ,
270
+ None ,
271
+ None ,
272
+ ) ) ,
273
+ IndexVec :: new( ) ,
274
+ ) ,
264
275
) ) ) ,
265
276
source_info,
266
277
} ] ;
@@ -274,23 +285,6 @@ impl<'tcx> TransformVisitor<'tcx> {
274
285
block
275
286
}
276
287
277
- fn coroutine_state_adt_and_variant_idx (
278
- & self ,
279
- is_return : bool ,
280
- ) -> ( AggregateKind < ' tcx > , VariantIdx ) {
281
- let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
282
- ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
283
- ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
284
- ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
285
- ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
286
- ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
287
- ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
288
- } ) ;
289
-
290
- let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
291
- ( kind, idx)
292
- }
293
-
294
288
// Make a `CoroutineState` or `Poll` variant assignment.
295
289
//
296
290
// `core::ops::CoroutineState` only has single element tuple variants,
@@ -303,51 +297,99 @@ impl<'tcx> TransformVisitor<'tcx> {
303
297
is_return : bool ,
304
298
statements : & mut Vec < Statement < ' tcx > > ,
305
299
) {
306
- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( is_return) ;
307
-
308
- match self . coroutine_kind {
309
- // `Poll::Pending`
300
+ let rvalue = match self . coroutine_kind {
310
301
CoroutineKind :: Async ( _) => {
311
- if !is_return {
312
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
313
-
314
- // FIXME(swatinem): assert that `val` is indeed unit?
315
- statements. push ( Statement {
316
- kind : StatementKind :: Assign ( Box :: new ( (
317
- Place :: return_place ( ) ,
318
- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
319
- ) ) ) ,
320
- source_info,
321
- } ) ;
322
- return ;
302
+ let poll_def_id = self . tcx . require_lang_item ( LangItem :: Poll , None ) ;
303
+ let args = self . tcx . mk_args ( & [ self . old_ret_ty . into ( ) ] ) ;
304
+ if is_return {
305
+ // Poll::Ready(val)
306
+ Rvalue :: Aggregate (
307
+ Box :: new ( AggregateKind :: Adt (
308
+ poll_def_id,
309
+ VariantIdx :: from_usize ( 0 ) ,
310
+ args,
311
+ None ,
312
+ None ,
313
+ ) ) ,
314
+ IndexVec :: from_raw ( vec ! [ val] ) ,
315
+ )
316
+ } else {
317
+ // Poll::Pending
318
+ Rvalue :: Aggregate (
319
+ Box :: new ( AggregateKind :: Adt (
320
+ poll_def_id,
321
+ VariantIdx :: from_usize ( 1 ) ,
322
+ args,
323
+ None ,
324
+ None ,
325
+ ) ) ,
326
+ IndexVec :: new ( ) ,
327
+ )
323
328
}
324
329
}
325
- // `Option::None`
326
330
CoroutineKind :: Gen ( _) => {
331
+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
332
+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) ] ) ;
327
333
if is_return {
328
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
329
-
330
- statements. push ( Statement {
331
- kind : StatementKind :: Assign ( Box :: new ( (
332
- Place :: return_place ( ) ,
333
- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
334
- ) ) ) ,
335
- source_info,
336
- } ) ;
337
- return ;
334
+ // None
335
+ Rvalue :: Aggregate (
336
+ Box :: new ( AggregateKind :: Adt (
337
+ option_def_id,
338
+ VariantIdx :: from_usize ( 0 ) ,
339
+ args,
340
+ None ,
341
+ None ,
342
+ ) ) ,
343
+ IndexVec :: new ( ) ,
344
+ )
345
+ } else {
346
+ // Some(val)
347
+ Rvalue :: Aggregate (
348
+ Box :: new ( AggregateKind :: Adt (
349
+ option_def_id,
350
+ VariantIdx :: from_usize ( 1 ) ,
351
+ args,
352
+ None ,
353
+ None ,
354
+ ) ) ,
355
+ IndexVec :: from_raw ( vec ! [ val] ) ,
356
+ )
338
357
}
339
358
}
340
- CoroutineKind :: Coroutine => { }
341
- }
342
-
343
- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
344
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
359
+ CoroutineKind :: Coroutine => {
360
+ let coroutine_state_def_id =
361
+ self . tcx . require_lang_item ( LangItem :: CoroutineState , None ) ;
362
+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) , self . old_ret_ty . into ( ) ] ) ;
363
+ if is_return {
364
+ // CoroutineState::Complete(val)
365
+ Rvalue :: Aggregate (
366
+ Box :: new ( AggregateKind :: Adt (
367
+ coroutine_state_def_id,
368
+ VariantIdx :: from_usize ( 1 ) ,
369
+ args,
370
+ None ,
371
+ None ,
372
+ ) ) ,
373
+ IndexVec :: from_raw ( vec ! [ val] ) ,
374
+ )
375
+ } else {
376
+ // CoroutineState::Yielded(val)
377
+ Rvalue :: Aggregate (
378
+ Box :: new ( AggregateKind :: Adt (
379
+ coroutine_state_def_id,
380
+ VariantIdx :: from_usize ( 0 ) ,
381
+ args,
382
+ None ,
383
+ None ,
384
+ ) ) ,
385
+ IndexVec :: from_raw ( vec ! [ val] ) ,
386
+ )
387
+ }
388
+ }
389
+ } ;
345
390
346
391
statements. push ( Statement {
347
- kind : StatementKind :: Assign ( Box :: new ( (
348
- Place :: return_place ( ) ,
349
- Rvalue :: Aggregate ( Box :: new ( kind) , [ val] . into ( ) ) ,
350
- ) ) ) ,
392
+ kind : StatementKind :: Assign ( Box :: new ( ( Place :: return_place ( ) , rvalue) ) ) ,
351
393
source_info,
352
394
} ) ;
353
395
}
@@ -421,7 +463,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
421
463
422
464
let ret_val = match data. terminator ( ) . kind {
423
465
TerminatorKind :: Return => {
424
- Some ( ( true , None , Operand :: Move ( Place :: from ( self . new_ret_local ) ) , None ) )
466
+ Some ( ( true , None , Operand :: Move ( Place :: from ( self . old_ret_local ) ) , None ) )
425
467
}
426
468
TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
427
469
Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
@@ -1503,10 +1545,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
1503
1545
1504
1546
impl < ' tcx > MirPass < ' tcx > for StateTransform {
1505
1547
fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
1506
- let Some ( yield_ty ) = body. yield_ty ( ) else {
1548
+ let Some ( old_yield_ty ) = body. yield_ty ( ) else {
1507
1549
// This only applies to coroutines
1508
1550
return ;
1509
1551
} ;
1552
+ let old_ret_ty = body. return_ty ( ) ;
1510
1553
1511
1554
assert ! ( body. coroutine_drop( ) . is_none( ) ) ;
1512
1555
@@ -1528,34 +1571,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1528
1571
1529
1572
let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
1530
1573
let is_gen_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Gen ( _) ) ) ;
1531
- let ( state_adt_ref , state_args ) = match body. coroutine_kind ( ) . unwrap ( ) {
1574
+ let new_ret_ty = match body. coroutine_kind ( ) . unwrap ( ) {
1532
1575
CoroutineKind :: Async ( _) => {
1533
1576
// Compute Poll<return_ty>
1534
1577
let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1535
1578
let poll_adt_ref = tcx. adt_def ( poll_did) ;
1536
- let poll_args = tcx. mk_args ( & [ body . return_ty ( ) . into ( ) ] ) ;
1537
- ( poll_adt_ref, poll_args)
1579
+ let poll_args = tcx. mk_args ( & [ old_ret_ty . into ( ) ] ) ;
1580
+ Ty :: new_adt ( tcx , poll_adt_ref, poll_args)
1538
1581
}
1539
1582
CoroutineKind :: Gen ( _) => {
1540
1583
// Compute Option<yield_ty>
1541
1584
let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
1542
1585
let option_adt_ref = tcx. adt_def ( option_did) ;
1543
- let option_args = tcx. mk_args ( & [ body . yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1544
- ( option_adt_ref, option_args)
1586
+ let option_args = tcx. mk_args ( & [ old_yield_ty . into ( ) ] ) ;
1587
+ Ty :: new_adt ( tcx , option_adt_ref, option_args)
1545
1588
}
1546
1589
CoroutineKind :: Coroutine => {
1547
1590
// Compute CoroutineState<yield_ty, return_ty>
1548
1591
let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1549
1592
let state_adt_ref = tcx. adt_def ( state_did) ;
1550
- let state_args = tcx. mk_args ( & [ yield_ty . into ( ) , body . return_ty ( ) . into ( ) ] ) ;
1551
- ( state_adt_ref, state_args)
1593
+ let state_args = tcx. mk_args ( & [ old_yield_ty . into ( ) , old_ret_ty . into ( ) ] ) ;
1594
+ Ty :: new_adt ( tcx , state_adt_ref, state_args)
1552
1595
}
1553
1596
} ;
1554
- let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
1555
1597
1556
- // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1598
+ // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1557
1599
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1558
- let new_ret_local = replace_local ( RETURN_PLACE , ret_ty , body, tcx) ;
1600
+ let old_ret_local = replace_local ( RETURN_PLACE , new_ret_ty , body, tcx) ;
1559
1601
1560
1602
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1561
1603
if is_async_kind {
@@ -1572,17 +1614,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1572
1614
} else {
1573
1615
body. local_decls [ resume_local] . ty
1574
1616
} ;
1575
- let new_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1617
+ let old_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1576
1618
1577
- // When first entering the coroutine, move the resume argument into its new local.
1619
+ // When first entering the coroutine, move the resume argument into its old local
1620
+ // (which is now a generator interior).
1578
1621
let source_info = SourceInfo :: outermost ( body. span ) ;
1579
1622
let stmts = & mut body. basic_blocks_mut ( ) [ START_BLOCK ] . statements ;
1580
1623
stmts. insert (
1581
1624
0 ,
1582
1625
Statement {
1583
1626
source_info,
1584
1627
kind : StatementKind :: Assign ( Box :: new ( (
1585
- new_resume_local . into ( ) ,
1628
+ old_resume_local . into ( ) ,
1586
1629
Rvalue :: Use ( Operand :: Move ( resume_local. into ( ) ) ) ,
1587
1630
) ) ) ,
1588
1631
} ,
@@ -1618,14 +1661,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1618
1661
let mut transform = TransformVisitor {
1619
1662
tcx,
1620
1663
coroutine_kind : body. coroutine_kind ( ) . unwrap ( ) ,
1621
- state_adt_ref,
1622
- state_args,
1623
1664
remap,
1624
1665
storage_liveness,
1625
1666
always_live_locals,
1626
1667
suspension_points : Vec :: new ( ) ,
1627
- new_ret_local ,
1668
+ old_ret_local ,
1628
1669
discr_ty,
1670
+ old_ret_ty,
1671
+ old_yield_ty,
1629
1672
} ;
1630
1673
transform. visit_body ( body) ;
1631
1674
0 commit comments