@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
60
60
pub use rustc_type_ir:: { DebugWithInfcx , InferCtxtLike , WithInfcx } ;
61
61
pub use vtable:: * ;
62
62
63
+ use std:: assert_matches:: assert_matches;
63
64
use std:: fmt:: Debug ;
64
65
use std:: hash:: { Hash , Hasher } ;
65
66
use std:: marker:: PhantomData ;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
1826
1827
1827
1828
/// Returns layout of a coroutine. Layout might be unavailable if the
1828
1829
/// coroutine is tainted by errors.
1829
- pub fn coroutine_layout ( self , def_id : DefId ) -> Option < & ' tcx CoroutineLayout < ' tcx > > {
1830
- self . optimized_mir ( def_id) . coroutine_layout ( )
1830
+ ///
1831
+ /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
1832
+ /// e.g. `args.as_coroutine().kind_ty()`.
1833
+ pub fn coroutine_layout (
1834
+ self ,
1835
+ def_id : DefId ,
1836
+ coroutine_kind_ty : Ty < ' tcx > ,
1837
+ ) -> Option < & ' tcx CoroutineLayout < ' tcx > > {
1838
+ let mir = self . optimized_mir ( def_id) ;
1839
+ // Regular coroutine
1840
+ if coroutine_kind_ty. is_unit ( ) {
1841
+ mir. coroutine_layout_raw ( )
1842
+ } else {
1843
+ // If we have a `Coroutine` that comes from an coroutine-closure,
1844
+ // then it may be a by-move or by-ref body.
1845
+ let ty:: Coroutine ( _, identity_args) =
1846
+ * self . type_of ( def_id) . instantiate_identity ( ) . kind ( )
1847
+ else {
1848
+ unreachable ! ( ) ;
1849
+ } ;
1850
+ let identity_kind_ty = identity_args. as_coroutine ( ) . kind_ty ( ) ;
1851
+ // If the types differ, then we must be getting the by-move body of
1852
+ // a by-ref coroutine.
1853
+ if identity_kind_ty == coroutine_kind_ty {
1854
+ mir. coroutine_layout_raw ( )
1855
+ } else {
1856
+ assert_matches ! ( coroutine_kind_ty. to_opt_closure_kind( ) , Some ( ClosureKind :: FnOnce ) ) ;
1857
+ assert_matches ! (
1858
+ identity_kind_ty. to_opt_closure_kind( ) ,
1859
+ Some ( ClosureKind :: Fn | ClosureKind :: FnMut )
1860
+ ) ;
1861
+ mir. coroutine_by_move_body ( ) . unwrap ( ) . coroutine_layout_raw ( )
1862
+ }
1863
+ }
1831
1864
}
1832
1865
1833
1866
/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
0 commit comments