Skip to content

Commit bffeb05

Browse files
Rollup merge of #123021 - compiler-errors:coroutine-layout-lol, r=oli-obk
Make `TyCtxt::coroutine_layout` take coroutine's kind parameter For coroutines that come from coroutine-closures (i.e. async closures), we may have two kinds of bodies stored in the coroutine; one that takes the closure's captures by reference, and one that takes the captures by move. These currently have identical layouts, but if we do any optimization for these layouts that are related to the upvars, then they will diverge -- e.g. #120168 (comment). This PR relaxes the assertion I added in #121122, and instead make the `TyCtxt::coroutine_layout` method take the `coroutine_kind_ty` argument from the coroutine, which will allow us to differentiate these by-move and by-ref bodies.
2 parents 8a7f285 + 9bda9ac commit bffeb05

File tree

8 files changed

+61
-23
lines changed

8 files changed

+61
-23
lines changed

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
683683
_ => unreachable!(),
684684
};
685685

686-
let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
686+
let coroutine_layout =
687+
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
687688

688689
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
689690
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
135135
unique_type_id: UniqueTypeId<'tcx>,
136136
) -> DINodeCreationResult<'ll> {
137137
let coroutine_type = unique_type_id.expect_ty();
138-
let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
138+
let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
139139
bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
140140
};
141141

@@ -158,8 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
158158
DIFlags::FlagZero,
159159
),
160160
|cx, coroutine_type_di_node| {
161-
let coroutine_layout =
162-
cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
161+
let coroutine_layout = cx
162+
.tcx
163+
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
164+
.unwrap();
163165

164166
let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
165167
coroutine_type_and_layout.variants

compiler/rustc_const_eval/src/transform/validate.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,17 @@ impl<'tcx> MirPass<'tcx> for Validator {
101101
}
102102

103103
// Enforce that coroutine-closure layouts are identical.
104-
if let Some(layout) = body.coroutine_layout()
104+
if let Some(layout) = body.coroutine_layout_raw()
105105
&& let Some(by_move_body) = body.coroutine_by_move_body()
106-
&& let Some(by_move_layout) = by_move_body.coroutine_layout()
106+
&& let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
107107
{
108-
if layout != by_move_layout {
109-
// If this turns out not to be true, please let compiler-errors know.
110-
// It is possible to support, but requires some changes to the layout
111-
// computation code.
108+
// FIXME(async_closures): We could do other validation here?
109+
if layout.variant_fields.len() != by_move_layout.variant_fields.len() {
112110
cfg_checker.fail(
113111
Location::START,
114112
format!(
115-
"Coroutine layout differs from by-move coroutine layout:\n\
113+
"Coroutine layout has different number of variant fields from \
114+
by-move coroutine layout:\n\
116115
layout: {layout:#?}\n\
117116
by_move_layout: {by_move_layout:#?}",
118117
),
@@ -715,13 +714,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
715714
// args of the coroutine. Otherwise, we prefer to use this body
716715
// since we may be in the process of computing this MIR in the
717716
// first place.
718-
let gen_body = if def_id == self.caller_body.source.def_id() {
719-
self.caller_body
717+
let layout = if def_id == self.caller_body.source.def_id() {
718+
// FIXME: This is not right for async closures.
719+
self.caller_body.coroutine_layout_raw()
720720
} else {
721-
self.tcx.optimized_mir(def_id)
721+
self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
722722
};
723723

724-
let Some(layout) = gen_body.coroutine_layout() else {
724+
let Some(layout) = layout else {
725725
self.fail(
726726
location,
727727
format!("No coroutine layout for {parent_ty:?}"),

compiler/rustc_middle/src/mir/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
652652
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
653653
}
654654

655+
/// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
655656
#[inline]
656-
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
657+
pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
657658
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
658659
}
659660

compiler/rustc_middle/src/mir/pretty.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
126126
Some(promoted) => write!(file, "::{promoted:?}`")?,
127127
}
128128
writeln!(file, " {disambiguator} {pass_name}")?;
129-
if let Some(ref layout) = body.coroutine_layout() {
129+
if let Some(ref layout) = body.coroutine_layout_raw() {
130130
writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
131131
}
132132
writeln!(file)?;

compiler/rustc_middle/src/ty/mod.rs

+35-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
6060
pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
6161
pub use vtable::*;
6262

63+
use std::assert_matches::assert_matches;
6364
use std::fmt::Debug;
6465
use std::hash::{Hash, Hasher};
6566
use std::marker::PhantomData;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
18261827

18271828
/// Returns layout of a coroutine. Layout might be unavailable if the
18281829
/// coroutine is tainted by errors.
1829-
pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
1830-
self.optimized_mir(def_id).coroutine_layout()
1830+
///
1831+
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
1832+
/// e.g. `args.as_coroutine().kind_ty()`.
1833+
pub fn coroutine_layout(
1834+
self,
1835+
def_id: DefId,
1836+
coroutine_kind_ty: Ty<'tcx>,
1837+
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1838+
let mir = self.optimized_mir(def_id);
1839+
// Regular coroutine
1840+
if coroutine_kind_ty.is_unit() {
1841+
mir.coroutine_layout_raw()
1842+
} else {
1843+
// If we have a `Coroutine` that comes from an coroutine-closure,
1844+
// then it may be a by-move or by-ref body.
1845+
let ty::Coroutine(_, identity_args) =
1846+
*self.type_of(def_id).instantiate_identity().kind()
1847+
else {
1848+
unreachable!();
1849+
};
1850+
let identity_kind_ty = identity_args.as_coroutine().kind_ty();
1851+
// If the types differ, then we must be getting the by-move body of
1852+
// a by-ref coroutine.
1853+
if identity_kind_ty == coroutine_kind_ty {
1854+
mir.coroutine_layout_raw()
1855+
} else {
1856+
assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
1857+
assert_matches!(
1858+
identity_kind_ty.to_opt_closure_kind(),
1859+
Some(ClosureKind::Fn | ClosureKind::FnMut)
1860+
);
1861+
mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
1862+
}
1863+
}
18311864
}
18321865

18331866
/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.

compiler/rustc_middle/src/ty/sty.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,8 @@ impl<'tcx> CoroutineArgs<'tcx> {
694694
#[inline]
695695
pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
696696
// FIXME requires optimized MIR
697-
FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
697+
FIRST_VARIANT
698+
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
698699
}
699700

700701
/// The discriminant for the given variant. Panics if the `variant_index` is
@@ -754,7 +755,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
754755
def_id: DefId,
755756
tcx: TyCtxt<'tcx>,
756757
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
757-
let layout = tcx.coroutine_layout(def_id).unwrap();
758+
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
758759
layout.variant_fields.iter().map(move |variant| {
759760
variant.iter().map(move |field| {
760761
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)

compiler/rustc_ty_utils/src/layout.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
745745
let tcx = cx.tcx;
746746
let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
747747

748-
let Some(info) = tcx.coroutine_layout(def_id) else {
748+
let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
749749
return Err(error(cx, LayoutError::Unknown(ty)));
750750
};
751751
let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
10721072
return (vec![], None);
10731073
};
10741074

1075-
let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap();
1075+
let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
10761076
let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
10771077

10781078
let mut upvars_size = Size::ZERO;

0 commit comments

Comments
 (0)