Skip to content

Commit 05116c5

Browse files
Only split by-ref/by-move futures for async closures
1 parent e760daa commit 05116c5

33 files changed

+119
-432
lines changed

compiler/rustc_borrowck/src/type_check/input_output.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
8787
self.tcx(),
8888
ty::CoroutineArgsParts {
8989
parent_args: args.parent_args(),
90-
kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
90+
kind_ty: Ty::from_coroutine_closure_kind(self.tcx(), args.kind()),
9191
return_ty: user_provided_sig.output(),
9292
tupled_upvars_ty,
9393
// For async closures, none of these can be annotated, so just fill

compiler/rustc_hir_typeck/src/callee.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
184184
kind: TypeVariableOriginKind::TypeInference,
185185
span: callee_expr.span,
186186
});
187+
// We may actually receive a coroutine back whose kind is different
188+
// from the closure that this dispatched from. This is because when
189+
// we have no captures, we automatically implement `FnOnce`. This
190+
// impl forces the closure kind to `FnOnce` i.e. `u8`.
191+
let kind_ty = self.next_ty_var(TypeVariableOrigin {
192+
kind: TypeVariableOriginKind::TypeInference,
193+
span: callee_expr.span,
194+
});
187195
let call_sig = self.tcx.mk_fn_sig(
188196
[coroutine_closure_sig.tupled_inputs_ty],
189197
coroutine_closure_sig.to_coroutine(
190198
self.tcx,
191199
closure_args.parent_args(),
192-
// Inherit the kind ty of the closure, since we're calling this
193-
// coroutine with the most relaxed `AsyncFn*` trait that we can.
194-
// We don't necessarily need to do this here, but it saves us
195-
// computing one more infer var that will get constrained later.
196-
closure_args.kind_ty(),
200+
kind_ty,
197201
self.tcx.coroutine_for_closure(def_id),
198202
tupled_upvars_ty,
199203
),

compiler/rustc_hir_typeck/src/closure.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
262262
},
263263
);
264264

265+
let coroutine_kind_ty = self.next_ty_var(TypeVariableOrigin {
266+
kind: TypeVariableOriginKind::ClosureSynthetic,
267+
span: expr_span,
268+
});
265269
let coroutine_upvars_ty = self.next_ty_var(TypeVariableOrigin {
266270
kind: TypeVariableOriginKind::ClosureSynthetic,
267271
span: expr_span,
@@ -279,7 +283,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
279283
sig.to_coroutine(
280284
tcx,
281285
parent_args,
282-
closure_kind_ty,
286+
coroutine_kind_ty,
283287
tcx.coroutine_for_closure(expr_def_id),
284288
coroutine_upvars_ty,
285289
)

compiler/rustc_hir_typeck/src/upvar.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
410410
self.demand_eqtype(
411411
span,
412412
coroutine_args.as_coroutine().kind_ty(),
413-
Ty::from_closure_kind(self.tcx, closure_kind),
413+
Ty::from_coroutine_closure_kind(self.tcx, closure_kind),
414414
);
415415
}
416416

compiler/rustc_middle/src/mir/mod.rs

-12
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,6 @@ pub struct CoroutineInfo<'tcx> {
278278
/// using `run_passes`.
279279
pub by_move_body: Option<Body<'tcx>>,
280280

281-
/// The body of the coroutine, modified to take its upvars by mutable ref rather than by
282-
/// immutable ref.
283-
///
284-
/// FIXME(async_closures): This is literally the same body as the parent body. Find a better
285-
/// way to represent the by-mut signature (or cap the closure-kind of the coroutine).
286-
pub by_mut_body: Option<Body<'tcx>>,
287-
288281
/// The layout of a coroutine. This field is populated after the state transform pass.
289282
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
290283

@@ -305,7 +298,6 @@ impl<'tcx> CoroutineInfo<'tcx> {
305298
yield_ty: Some(yield_ty),
306299
resume_ty: Some(resume_ty),
307300
by_move_body: None,
308-
by_mut_body: None,
309301
coroutine_drop: None,
310302
coroutine_layout: None,
311303
}
@@ -628,10 +620,6 @@ impl<'tcx> Body<'tcx> {
628620
self.coroutine.as_ref()?.by_move_body.as_ref()
629621
}
630622

631-
pub fn coroutine_by_mut_body(&self) -> Option<&Body<'tcx>> {
632-
self.coroutine.as_ref()?.by_mut_body.as_ref()
633-
}
634-
635623
#[inline]
636624
pub fn coroutine_kind(&self) -> Option<CoroutineKind> {
637625
self.coroutine.as_ref().map(|coroutine| coroutine.coroutine_kind)

compiler/rustc_middle/src/mir/visit.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,10 @@ macro_rules! make_mir_visitor {
345345
ty::InstanceDef::Virtual(_def_id, _) |
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348-
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
349-
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id, target_kind: _ } |
348+
ty::InstanceDef::ConstructCoroutineInClosureShim {
349+
coroutine_closure_def_id: _def_id,
350+
} |
351+
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } |
350352
ty::InstanceDef::DropGlue(_def_id, None) => {}
351353

352354
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

+5-13
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,20 @@ pub enum InstanceDef<'tcx> {
9090
/// and dispatch to the `FnMut::call_mut` instance for the closure.
9191
ClosureOnceShim { call_once: DefId, track_caller: bool },
9292

93-
/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or
94-
/// `<[Fn coroutine-closure] as FnMut>::call_mut`.
93+
/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once`
9594
///
9695
/// The body generated here differs significantly from the `ClosureOnceShim`,
9796
/// since we need to generate a distinct coroutine type that will move the
9897
/// closure's upvars *out* of the closure.
99-
ConstructCoroutineInClosureShim {
100-
coroutine_closure_def_id: DefId,
101-
target_kind: ty::ClosureKind,
102-
},
98+
ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId },
10399

104100
/// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
105101
/// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
106102
/// similarly for `AsyncFnMut`.
107103
///
108104
/// This will select the body that is produced by the `ByMoveBody` transform, and thus
109105
/// take and use all of its upvars by-move rather than by-ref.
110-
CoroutineKindShim { coroutine_def_id: DefId, target_kind: ty::ClosureKind },
106+
CoroutineKindShim { coroutine_def_id: DefId },
111107

112108
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
113109
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
@@ -192,9 +188,8 @@ impl<'tcx> InstanceDef<'tcx> {
192188
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
193189
| ty::InstanceDef::ConstructCoroutineInClosureShim {
194190
coroutine_closure_def_id: def_id,
195-
target_kind: _,
196191
}
197-
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id, target_kind: _ }
192+
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id }
198193
| InstanceDef::DropGlue(def_id, _)
199194
| InstanceDef::CloneShim(def_id, _)
200195
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -651,10 +646,7 @@ impl<'tcx> Instance<'tcx> {
651646
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
652647
} else {
653648
Some(Instance {
654-
def: ty::InstanceDef::CoroutineKindShim {
655-
coroutine_def_id,
656-
target_kind: args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
657-
},
649+
def: ty::InstanceDef::CoroutineKindShim { coroutine_def_id },
658650
args,
659651
})
660652
}

compiler/rustc_middle/src/ty/sty.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
483483
self.to_coroutine(
484484
tcx,
485485
parent_args,
486-
Ty::from_closure_kind(tcx, goal_kind),
486+
Ty::from_coroutine_closure_kind(tcx, goal_kind),
487487
coroutine_def_id,
488488
tupled_upvars_ty,
489489
)
@@ -2456,6 +2456,21 @@ impl<'tcx> Ty<'tcx> {
24562456
}
24572457
}
24582458

2459+
/// Like [`Ty::to_opt_closure_kind`], but it caps the "maximum" closure kind
2460+
/// to `FnMut`. This is because although we have three capability states,
2461+
/// `AsyncFn`/`AsyncFnMut`/`AsyncFnOnce`, we only need to distinguish two coroutine
2462+
/// bodies: by-ref and by-value.
2463+
///
2464+
/// This method should be used when constructing a `Coroutine` out of a
2465+
/// `CoroutineClosure`, when the `Coroutine`'s `kind` field is being populated
2466+
/// directly from the `CoroutineClosure`'s `kind`.
2467+
pub fn from_coroutine_closure_kind(tcx: TyCtxt<'tcx>, kind: ty::ClosureKind) -> Ty<'tcx> {
2468+
match kind {
2469+
ty::ClosureKind::Fn | ty::ClosureKind::FnMut => tcx.types.i16,
2470+
ty::ClosureKind::FnOnce => tcx.types.i32,
2471+
}
2472+
}
2473+
24592474
/// Fast path helper for testing if a type is `Sized`.
24602475
///
24612476
/// Returning true means the type is known to be sized. Returning

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

-35
Original file line numberDiff line numberDiff line change
@@ -67,45 +67,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
6767
by_move_body.source = mir::MirSource {
6868
instance: InstanceDef::CoroutineKindShim {
6969
coroutine_def_id: coroutine_def_id.to_def_id(),
70-
target_kind: ty::ClosureKind::FnOnce,
7170
},
7271
promoted: None,
7372
};
7473
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
75-
76-
// If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
77-
// This is actually just a copy of the by-ref body, but with a different self type.
78-
// FIXME(async_closures): We could probably unify this with the by-ref body somehow.
79-
if coroutine_kind == ty::ClosureKind::Fn {
80-
let by_mut_coroutine_ty = Ty::new_coroutine(
81-
tcx,
82-
coroutine_def_id.to_def_id(),
83-
ty::CoroutineArgs::new(
84-
tcx,
85-
ty::CoroutineArgsParts {
86-
parent_args: args.as_coroutine().parent_args(),
87-
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
88-
resume_ty: args.as_coroutine().resume_ty(),
89-
yield_ty: args.as_coroutine().yield_ty(),
90-
return_ty: args.as_coroutine().return_ty(),
91-
witness: args.as_coroutine().witness(),
92-
tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
93-
},
94-
)
95-
.args,
96-
);
97-
let mut by_mut_body = body.clone();
98-
by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
99-
dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
100-
by_mut_body.source = mir::MirSource {
101-
instance: InstanceDef::CoroutineKindShim {
102-
coroutine_def_id: coroutine_def_id.to_def_id(),
103-
target_kind: ty::ClosureKind::FnMut,
104-
},
105-
promoted: None,
106-
};
107-
body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
108-
}
10974
}
11075
}
11176

compiler/rustc_mir_transform/src/pass_manager.rs

-3
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,6 @@ fn run_passes_inner<'tcx>(
186186
if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
187187
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
188188
}
189-
if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
190-
run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
191-
}
192189
}
193190
}
194191

compiler/rustc_mir_transform/src/shim.rs

+12-86
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
33
use rustc_hir::lang_items::LangItem;
44
use rustc_middle::mir::*;
55
use rustc_middle::query::Providers;
6+
use rustc_middle::ty::GenericArgs;
67
use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
7-
use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
88
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
99

1010
use rustc_index::{Idx, IndexVec};
@@ -70,39 +70,13 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
7070
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
7171
}
7272

73-
ty::InstanceDef::ConstructCoroutineInClosureShim {
74-
coroutine_closure_def_id,
75-
target_kind,
76-
} => match target_kind {
77-
ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
78-
ty::ClosureKind::FnMut => {
79-
// No need to optimize the body, it has already been optimized
80-
// since we steal it from the `AsyncFn::call` body and just fix
81-
// the return type.
82-
return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
83-
}
84-
ty::ClosureKind::FnOnce => {
85-
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
86-
}
87-
},
73+
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
74+
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
75+
}
8876

89-
ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind {
90-
ty::ClosureKind::Fn => unreachable!(),
91-
ty::ClosureKind::FnMut => {
92-
return tcx
93-
.optimized_mir(coroutine_def_id)
94-
.coroutine_by_mut_body()
95-
.unwrap()
96-
.clone();
97-
}
98-
ty::ClosureKind::FnOnce => {
99-
return tcx
100-
.optimized_mir(coroutine_def_id)
101-
.coroutine_by_move_body()
102-
.unwrap()
103-
.clone();
104-
}
105-
},
77+
ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
78+
return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
79+
}
10680

10781
ty::InstanceDef::DropGlue(def_id, ty) => {
10882
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
@@ -123,21 +97,11 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
12397
let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() {
12498
coroutine_body.coroutine_drop().unwrap()
12599
} else {
126-
match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() {
127-
ty::ClosureKind::Fn => {
128-
unreachable!()
129-
}
130-
ty::ClosureKind::FnMut => coroutine_body
131-
.coroutine_by_mut_body()
132-
.unwrap()
133-
.coroutine_drop()
134-
.unwrap(),
135-
ty::ClosureKind::FnOnce => coroutine_body
136-
.coroutine_by_move_body()
137-
.unwrap()
138-
.coroutine_drop()
139-
.unwrap(),
140-
}
100+
assert_eq!(
101+
args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
102+
ty::ClosureKind::FnOnce
103+
);
104+
coroutine_body.coroutine_by_move_body().unwrap().coroutine_drop().unwrap()
141105
};
142106

143107
let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
@@ -1112,7 +1076,6 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
11121076

11131077
let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
11141078
coroutine_closure_def_id,
1115-
target_kind: ty::ClosureKind::FnOnce,
11161079
});
11171080

11181081
let body =
@@ -1121,40 +1084,3 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
11211084

11221085
body
11231086
}
1124-
1125-
fn build_construct_coroutine_by_mut_shim<'tcx>(
1126-
tcx: TyCtxt<'tcx>,
1127-
coroutine_closure_def_id: DefId,
1128-
) -> Body<'tcx> {
1129-
let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
1130-
let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
1131-
let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
1132-
bug!();
1133-
};
1134-
let args = args.as_coroutine_closure();
1135-
1136-
body.local_decls[RETURN_PLACE].ty =
1137-
tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
1138-
sig.to_coroutine_given_kind_and_upvars(
1139-
tcx,
1140-
args.parent_args(),
1141-
tcx.coroutine_for_closure(coroutine_closure_def_id),
1142-
ty::ClosureKind::FnMut,
1143-
tcx.lifetimes.re_erased,
1144-
args.tupled_upvars_ty(),
1145-
args.coroutine_captures_by_ref_ty(),
1146-
)
1147-
}));
1148-
body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
1149-
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
1150-
1151-
body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
1152-
coroutine_closure_def_id,
1153-
target_kind: ty::ClosureKind::FnMut,
1154-
});
1155-
1156-
body.pass_count = 0;
1157-
dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));
1158-
1159-
body
1160-
}

compiler/rustc_span/src/symbol.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ symbols! {
166166
Break,
167167
C,
168168
CStr,
169-
CallFuture,
170-
CallMutFuture,
171169
CallOnceFuture,
170+
CallRefFuture,
172171
Capture,
173172
Center,
174173
Cleanup,

0 commit comments

Comments
 (0)