Skip to content

Commit b7d67ea

Browse files
Require coroutine kind type to be passed to TyCtxt::coroutine_layout
1 parent 847fd88 commit b7d67ea

File tree

8 files changed

+59
-17
lines changed

8 files changed

+59
-17
lines changed

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
683683
_ => unreachable!(),
684684
};
685685

686-
let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
686+
let coroutine_layout =
687+
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
687688

688689
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
689690
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
135135
unique_type_id: UniqueTypeId<'tcx>,
136136
) -> DINodeCreationResult<'ll> {
137137
let coroutine_type = unique_type_id.expect_ty();
138-
let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
138+
let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
139139
bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
140140
};
141141

@@ -158,7 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
158158
DIFlags::FlagZero,
159159
),
160160
|cx, coroutine_type_di_node| {
161-
let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
161+
let coroutine_layout = cx
162+
.tcx
163+
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
164+
.unwrap();
162165

163166
let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
164167
coroutine_type_and_layout.variants

compiler/rustc_const_eval/src/transform/validate.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ impl<'tcx> MirPass<'tcx> for Validator {
101101
}
102102

103103
// Enforce that coroutine-closure layouts are identical.
104-
if let Some(layout) = body.coroutine_layout()
104+
if let Some(layout) = body.coroutine_layout_raw()
105105
&& let Some(by_move_body) = body.coroutine_by_move_body()
106-
&& let Some(by_move_layout) = by_move_body.coroutine_layout()
106+
&& let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
107107
{
108108
if layout != by_move_layout {
109109
// If this turns out not to be true, please let compiler-errors know.
@@ -715,13 +715,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
715715
// args of the coroutine. Otherwise, we prefer to use this body
716716
// since we may be in the process of computing this MIR in the
717717
// first place.
718-
let gen_body = if def_id == self.caller_body.source.def_id() {
719-
self.caller_body
718+
let layout = if def_id == self.caller_body.source.def_id() {
719+
// FIXME: This is not right for async closures.
720+
self.caller_body.coroutine_layout_raw()
720721
} else {
721-
self.tcx.optimized_mir(def_id)
722+
self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
722723
};
723724

724-
let Some(layout) = gen_body.coroutine_layout() else {
725+
let Some(layout) = layout else {
725726
self.fail(
726727
location,
727728
format!("No coroutine layout for {parent_ty:?}"),

compiler/rustc_middle/src/mir/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
652652
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
653653
}
654654

655+
/// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
655656
#[inline]
656-
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
657+
pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
657658
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
658659
}
659660

compiler/rustc_middle/src/mir/pretty.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
126126
Some(promoted) => write!(file, "::{promoted:?}`")?,
127127
}
128128
writeln!(file, " {disambiguator} {pass_name}")?;
129-
if let Some(ref layout) = body.coroutine_layout() {
129+
if let Some(ref layout) = body.coroutine_layout_raw() {
130130
writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
131131
}
132132
writeln!(file)?;

compiler/rustc_middle/src/ty/mod.rs

+35-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
6060
pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
6161
pub use vtable::*;
6262

63+
use std::assert_matches::assert_matches;
6364
use std::fmt::Debug;
6465
use std::hash::{Hash, Hasher};
6566
use std::marker::PhantomData;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
18261827

18271828
/// Returns layout of a coroutine. Layout might be unavailable if the
18281829
/// coroutine is tainted by errors.
1829-
pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
1830-
self.optimized_mir(def_id).coroutine_layout()
1830+
///
1831+
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
1832+
/// e.g. `args.as_coroutine().kind_ty()`.
1833+
pub fn coroutine_layout(
1834+
self,
1835+
def_id: DefId,
1836+
coroutine_kind_ty: Ty<'tcx>,
1837+
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1838+
let mir = self.optimized_mir(def_id);
1839+
// Regular coroutine
1840+
if coroutine_kind_ty.is_unit() {
1841+
mir.coroutine_layout_raw()
1842+
} else {
1843+
// If we have a `Coroutine` that comes from an coroutine-closure,
1844+
// then it may be a by-move or by-ref body.
1845+
let ty::Coroutine(_, identity_args) =
1846+
*self.type_of(def_id).instantiate_identity().kind()
1847+
else {
1848+
unreachable!();
1849+
};
1850+
let identity_kind_ty = identity_args.as_coroutine().kind_ty();
1851+
// If the types differ, then we must be getting the by-move body of
1852+
// a by-ref coroutine.
1853+
if identity_kind_ty == coroutine_kind_ty {
1854+
mir.coroutine_layout_raw()
1855+
} else {
1856+
assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
1857+
assert_matches!(
1858+
identity_kind_ty.to_opt_closure_kind(),
1859+
Some(ClosureKind::Fn | ClosureKind::FnMut)
1860+
);
1861+
mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
1862+
}
1863+
}
18311864
}
18321865

18331866
/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.

compiler/rustc_middle/src/ty/sty.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,10 @@ impl<'tcx> CoroutineArgs<'tcx> {
694694
#[inline]
695695
pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
696696
// FIXME requires optimized MIR
697-
FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
697+
// FIXME(async_closures): We should assert all coroutine layouts have
698+
// the same number of variants.
699+
FIRST_VARIANT
700+
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
698701
}
699702

700703
/// The discriminant for the given variant. Panics if the `variant_index` is
@@ -754,7 +757,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
754757
def_id: DefId,
755758
tcx: TyCtxt<'tcx>,
756759
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
757-
let layout = tcx.coroutine_layout(def_id).unwrap();
760+
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
758761
layout.variant_fields.iter().map(move |variant| {
759762
variant.iter().map(move |field| {
760763
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)

compiler/rustc_ty_utils/src/layout.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
745745
let tcx = cx.tcx;
746746
let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
747747

748-
let Some(info) = tcx.coroutine_layout(def_id) else {
748+
let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
749749
return Err(error(cx, LayoutError::Unknown(ty)));
750750
};
751751
let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
10721072
return (vec![], None);
10731073
};
10741074

1075-
let coroutine = cx.tcx.coroutine_layout(def_id).unwrap();
1075+
let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
10761076
let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
10771077

10781078
let mut upvars_size = Size::ZERO;

0 commit comments

Comments
 (0)