Skip to content

Commit 5e2b66f

Browse files
Don't populate yield and resume types after the fact
1 parent 9212108 commit 5e2b66f

File tree

5 files changed

+85
-78
lines changed

5 files changed

+85
-78
lines changed

compiler/rustc_const_eval/src/transform/promote_consts.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>(
969969
0,
970970
vec![],
971971
body.span,
972-
body.coroutine_kind(),
972+
None,
973973
body.tainted_by_errors,
974974
);
975975
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);

compiler/rustc_middle/src/mir/mod.rs

+19-10
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> {
263263
pub coroutine_kind: CoroutineKind,
264264
}
265265

266+
impl<'tcx> CoroutineInfo<'tcx> {
267+
// Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body.
268+
pub fn initial(
269+
coroutine_kind: CoroutineKind,
270+
yield_ty: Ty<'tcx>,
271+
resume_ty: Ty<'tcx>,
272+
) -> CoroutineInfo<'tcx> {
273+
CoroutineInfo {
274+
coroutine_kind,
275+
yield_ty: Some(yield_ty),
276+
resume_ty: Some(resume_ty),
277+
coroutine_drop: None,
278+
coroutine_layout: None,
279+
}
280+
}
281+
}
282+
266283
/// The lowered representation of a single function.
267284
#[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
268285
pub struct Body<'tcx> {
@@ -367,7 +384,7 @@ impl<'tcx> Body<'tcx> {
367384
arg_count: usize,
368385
var_debug_info: Vec<VarDebugInfo<'tcx>>,
369386
span: Span,
370-
coroutine_kind: Option<CoroutineKind>,
387+
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
371388
tainted_by_errors: Option<ErrorGuaranteed>,
372389
) -> Self {
373390
// We need `arg_count` locals, and one for the return place.
@@ -384,15 +401,7 @@ impl<'tcx> Body<'tcx> {
384401
source,
385402
basic_blocks: BasicBlocks::new(basic_blocks),
386403
source_scopes,
387-
coroutine: coroutine_kind.map(|coroutine_kind| {
388-
Box::new(CoroutineInfo {
389-
yield_ty: None,
390-
resume_ty: None,
391-
coroutine_drop: None,
392-
coroutine_layout: None,
393-
coroutine_kind,
394-
})
395-
}),
404+
coroutine,
396405
local_decls,
397406
user_type_annotations,
398407
arg_count,

compiler/rustc_middle/src/thir.rs

-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ macro_rules! thir_with_elements {
8686
}
8787
}
8888

89-
pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0);
90-
9189
thir_with_elements! {
9290
body_type: BodyTy<'tcx>,
9391

compiler/rustc_mir_build/src/build/mod.rs

+60-60
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
99
use rustc_hir as hir;
1010
use rustc_hir::def::DefKind;
1111
use rustc_hir::def_id::{DefId, LocalDefId};
12-
use rustc_hir::{CoroutineKind, Node};
12+
use rustc_hir::Node;
1313
use rustc_index::bit_set::GrowableBitSet;
1414
use rustc_index::{Idx, IndexSlice, IndexVec};
1515
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
@@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
177177
check_overflow: bool,
178178
fn_span: Span,
179179
arg_count: usize,
180-
coroutine_kind: Option<CoroutineKind>,
180+
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
181181

182182
/// The current set of scopes, updated as we traverse;
183183
/// see the `scope` module for more details.
@@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
458458
) -> Body<'tcx> {
459459
let span = tcx.def_span(fn_def);
460460
let fn_id = tcx.local_def_id_to_hir_id(fn_def);
461-
let coroutine_kind = tcx.coroutine_kind(fn_def);
462461

463462
// The representation of thir for `-Zunpretty=thir-tree` relies on
464463
// the entry expression being the last element of `thir.exprs`.
@@ -488,17 +487,15 @@ fn construct_fn<'tcx>(
488487

489488
let arguments = &thir.params;
490489

491-
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
492-
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
493-
let coroutine_sig = match coroutine_ty.kind() {
494-
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
495-
_ => {
496-
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
497-
}
498-
};
499-
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
500-
} else {
501-
(None, None, fn_sig.output())
490+
let return_ty = fn_sig.output();
491+
let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() {
492+
ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial(
493+
tcx.coroutine_kind(fn_def).unwrap(),
494+
args.as_coroutine().yield_ty(),
495+
args.as_coroutine().resume_ty(),
496+
))),
497+
ty::Closure(..) | ty::FnDef(..) => None,
498+
ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"),
502499
};
503500

504501
if let Some(custom_mir_attr) =
@@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
529526
safety,
530527
return_ty,
531528
return_ty_span,
532-
coroutine_kind,
529+
coroutine,
533530
);
534531

535532
let call_site_scope =
@@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
563560
None
564561
};
565562

566-
if coroutine_kind.is_some() {
567-
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
568-
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
569-
}
570-
571563
body
572564
}
573565

@@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
632624
fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> {
633625
let span = tcx.def_span(def_id);
634626
let hir_id = tcx.local_def_id_to_hir_id(def_id);
635-
let coroutine_kind = tcx.coroutine_kind(def_id);
636627

637-
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
628+
let (inputs, output, coroutine) = match tcx.def_kind(def_id) {
638629
DefKind::Const
639630
| DefKind::AssocConst
640631
| DefKind::AnonConst
641632
| DefKind::InlineConst
642-
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
633+
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
643634
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
644635
let sig = tcx.liberate_late_bound_regions(
645636
def_id.to_def_id(),
646637
tcx.fn_sig(def_id).instantiate_identity(),
647638
);
648-
(sig.inputs().to_vec(), sig.output(), None, None)
649-
}
650-
DefKind::Closure if coroutine_kind.is_some() => {
651-
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
652-
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
653-
bug!("expected type of coroutine-like closure to be a coroutine")
654-
};
655-
let args = args.as_coroutine();
656-
let resume_ty = args.resume_ty();
657-
let yield_ty = args.yield_ty();
658-
let return_ty = args.return_ty();
659-
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
639+
(sig.inputs().to_vec(), sig.output(), None)
660640
}
661641
DefKind::Closure => {
662642
let closure_ty = tcx.type_of(def_id).instantiate_identity();
663-
let ty::Closure(_, args) = closure_ty.kind() else {
664-
bug!("expected type of closure to be a closure")
665-
};
666-
let args = args.as_closure();
667-
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
668-
let self_ty = match args.kind() {
669-
ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
670-
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
671-
ty::ClosureKind::FnOnce => closure_ty,
672-
};
673-
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
643+
match closure_ty.kind() {
644+
ty::Closure(_, args) => {
645+
let args = args.as_closure();
646+
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
647+
let self_ty = match args.kind() {
648+
ty::ClosureKind::Fn => {
649+
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
650+
}
651+
ty::ClosureKind::FnMut => {
652+
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
653+
}
654+
ty::ClosureKind::FnOnce => closure_ty,
655+
};
656+
(
657+
[self_ty].into_iter().chain(sig.inputs().to_vec()).collect(),
658+
sig.output(),
659+
None,
660+
)
661+
}
662+
ty::Coroutine(_, args) => {
663+
let args = args.as_coroutine();
664+
let resume_ty = args.resume_ty();
665+
let yield_ty = args.yield_ty();
666+
let return_ty = args.return_ty();
667+
(
668+
vec![closure_ty, args.resume_ty()],
669+
return_ty,
670+
Some(Box::new(CoroutineInfo::initial(
671+
tcx.coroutine_kind(def_id).unwrap(),
672+
yield_ty,
673+
resume_ty,
674+
))),
675+
)
676+
}
677+
_ => {
678+
span_bug!(span, "expected type of closure body to be a closure or coroutine");
679+
}
680+
}
674681
}
675-
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
682+
dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk),
676683
};
677684

678685
let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE };
@@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
696703

697704
cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable);
698705

699-
let mut body = Body::new(
706+
Body::new(
700707
MirSource::item(def_id.to_def_id()),
701708
cfg.basic_blocks,
702709
source_scopes,
@@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
705712
inputs.len(),
706713
vec![],
707714
span,
708-
coroutine_kind,
715+
coroutine,
709716
Some(guar),
710-
);
711-
712-
body.coroutine.as_mut().map(|gen| {
713-
gen.yield_ty = yield_ty;
714-
gen.resume_ty = resume_ty;
715-
});
716-
717-
body
717+
)
718718
}
719719

720720
impl<'a, 'tcx> Builder<'a, 'tcx> {
@@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
728728
safety: Safety,
729729
return_ty: Ty<'tcx>,
730730
return_span: Span,
731-
coroutine_kind: Option<CoroutineKind>,
731+
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
732732
) -> Builder<'a, 'tcx> {
733733
let tcx = infcx.tcx;
734734
let attrs = tcx.hir().attrs(hir_id);
@@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
759759
cfg: CFG { basic_blocks: IndexVec::new() },
760760
fn_span: span,
761761
arg_count,
762-
coroutine_kind,
762+
coroutine,
763763
scopes: scope::Scopes::new(),
764764
block_context: BlockContext::new(),
765765
source_scopes: IndexVec::new(),
@@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
803803
self.arg_count,
804804
self.var_debug_info,
805805
self.fn_span,
806-
self.coroutine_kind,
806+
self.coroutine,
807807
None,
808808
)
809809
}

compiler/rustc_mir_build/src/build/scope.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
706706
// If we are emitting a `drop` statement, we need to have the cached
707707
// diverge cleanup pads ready in case that drop panics.
708708
let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup());
709-
let is_coroutine = self.coroutine_kind.is_some();
709+
let is_coroutine = self.coroutine.is_some();
710710
let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX };
711711

712712
let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes");
@@ -960,7 +960,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
960960
// path, we only need to invalidate the cache for drops that happen on
961961
// the unwind or coroutine drop paths. This means that for
962962
// non-coroutines we don't need to invalidate caches for `DropKind::Storage`.
963-
let invalidate_caches = needs_drop || self.coroutine_kind.is_some();
963+
let invalidate_caches = needs_drop || self.coroutine.is_some();
964964
for scope in self.scopes.scopes.iter_mut().rev() {
965965
if invalidate_caches {
966966
scope.invalidate_cache();
@@ -1073,7 +1073,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
10731073
return cached_drop;
10741074
}
10751075

1076-
let is_coroutine = self.coroutine_kind.is_some();
1076+
let is_coroutine = self.coroutine.is_some();
10771077
for scope in &mut self.scopes.scopes[uncached_scope..=target] {
10781078
for drop in &scope.drops {
10791079
if is_coroutine || drop.kind == DropKind::Value {
@@ -1318,7 +1318,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
13181318
blocks[ROOT_NODE] = continue_block;
13191319

13201320
drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks);
1321-
let is_coroutine = self.coroutine_kind.is_some();
1321+
let is_coroutine = self.coroutine.is_some();
13221322

13231323
// Link the exit drop tree to unwind drop tree.
13241324
if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) {
@@ -1355,7 +1355,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
13551355

13561356
/// Build the unwind and coroutine drop trees.
13571357
pub(crate) fn build_drop_trees(&mut self) {
1358-
if self.coroutine_kind.is_some() {
1358+
if self.coroutine.is_some() {
13591359
self.build_coroutine_drop_trees();
13601360
} else {
13611361
Self::build_unwind_tree(

0 commit comments

Comments
 (0)