Skip to content

Commit 5df00fc

Browse files
Fix ABI for FnMut/Fn impls for async closures
1 parent 0451ff4 commit 5df00fc

12 files changed

+81
-18
lines changed

compiler/rustc_middle/src/mir/visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ macro_rules! make_mir_visitor {
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348348
ty::InstanceDef::ConstructCoroutineInClosureShim {
349349
coroutine_closure_def_id: _def_id,
350+
receiver_by_ref: _,
350351
} |
351352
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } |
352353
ty::InstanceDef::DropGlue(_def_id, None) => {}

compiler/rustc_middle/src/ty/instance.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@ pub enum InstanceDef<'tcx> {
9595
/// The body generated here differs significantly from the `ClosureOnceShim`,
9696
/// since we need to generate a distinct coroutine type that will move the
9797
/// closure's upvars *out* of the closure.
98-
ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId },
98+
ConstructCoroutineInClosureShim {
99+
coroutine_closure_def_id: DefId,
100+
// Whether the generated MIR body takes the coroutine by-ref. This is
101+
// because the signature of `<{async fn} as FnMut>::call_mut` is:
102+
// `fn(&mut self, args: A) -> <Self as FnOnce>::Output`, that is to say
103+
// that it returns the `FnOnce`-flavored coroutine but takes the closure
104+
// by ref (and similarly for `Fn::call`).
105+
receiver_by_ref: bool,
106+
},
99107

100108
/// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
101109
/// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
@@ -188,6 +196,7 @@ impl<'tcx> InstanceDef<'tcx> {
188196
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
189197
| ty::InstanceDef::ConstructCoroutineInClosureShim {
190198
coroutine_closure_def_id: def_id,
199+
receiver_by_ref: _,
191200
}
192201
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id }
193202
| InstanceDef::DropGlue(def_id, _)

compiler/rustc_mir_transform/src/shim.rs

+19-5
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
7070
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
7171
}
7272

73-
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
74-
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
75-
}
73+
ty::InstanceDef::ConstructCoroutineInClosureShim {
74+
coroutine_closure_def_id,
75+
receiver_by_ref,
76+
} => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref),
7677

7778
ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
7879
return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
@@ -1015,12 +1016,17 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
10151016
fn build_construct_coroutine_by_move_shim<'tcx>(
10161017
tcx: TyCtxt<'tcx>,
10171018
coroutine_closure_def_id: DefId,
1019+
receiver_by_ref: bool,
10181020
) -> Body<'tcx> {
1019-
let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
1021+
let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
10201022
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
10211023
bug!();
10221024
};
10231025

1026+
if receiver_by_ref {
1027+
self_ty = Ty::new_mut_ptr(tcx, self_ty);
1028+
}
1029+
10241030
let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
10251031
tcx.mk_fn_sig(
10261032
[self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
@@ -1076,11 +1082,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
10761082

10771083
let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
10781084
coroutine_closure_def_id,
1085+
receiver_by_ref,
10791086
});
10801087

10811088
let body =
10821089
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span);
1083-
dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(()));
1090+
dump_mir(
1091+
tcx,
1092+
false,
1093+
if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" },
1094+
&0,
1095+
&body,
1096+
|_, _| Ok(()),
1097+
);
10841098

10851099
body
10861100
}

compiler/rustc_ty_utils/src/abi.rs

+11-4
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,18 @@ fn fn_sig_for_fn_abi<'tcx>(
118118
// a separate def-id for these bodies.
119119
let mut coroutine_kind = args.as_coroutine_closure().kind();
120120

121-
if let InstanceDef::ConstructCoroutineInClosureShim { .. } = instance.def {
122-
coroutine_kind = ty::ClosureKind::FnOnce;
123-
}
121+
let env_ty =
122+
if let InstanceDef::ConstructCoroutineInClosureShim { receiver_by_ref, .. } =
123+
instance.def
124+
{
125+
coroutine_kind = ty::ClosureKind::FnOnce;
124126

125-
let env_ty = tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region);
127+
// Implementations of `FnMut` and `Fn` for coroutine-closures
128+
// still take their receiver by ref.
129+
if receiver_by_ref { Ty::new_mut_ptr(tcx, coroutine_ty) } else { coroutine_ty }
130+
} else {
131+
tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region)
132+
};
126133

127134
let sig = sig.skip_binder();
128135
ty::Binder::bind_with_vars(

compiler/rustc_ty_utils/src/instance.rs

+2
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ fn resolve_associated_item<'tcx>(
283283
Some(Instance {
284284
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
285285
coroutine_closure_def_id,
286+
receiver_by_ref: target_kind != ty::ClosureKind::FnOnce,
286287
},
287288
args,
288289
})
@@ -310,6 +311,7 @@ fn resolve_associated_item<'tcx>(
310311
Some(Instance {
311312
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
312313
coroutine_closure_def_id,
314+
receiver_by_ref: false,
313315
},
314316
args,
315317
})

tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move
22

3-
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
3+
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
44
yields ()
55
{
66
debug _task_context => _2;

tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move
22

3-
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
3+
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
44
yields ()
55
{
66
debug _task_context => _2;
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move
22

3-
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
4-
let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
3+
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
4+
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
55

66
bb0: {
7-
_0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
7+
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
88
return;
99
}
1010
}
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move
22

3-
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
4-
let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
3+
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
4+
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
55

66
bb0: {
7-
_0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
7+
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
88
return;
99
}
1010
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref
2+
3+
fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
4+
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
5+
6+
bb0: {
7+
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
8+
return;
9+
}
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref
2+
3+
fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
4+
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
5+
6+
bb0: {
7+
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
8+
return;
9+
}
10+
}

tests/mir-opt/async_closure_shims.rs

+10
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ async fn call_once(f: impl AsyncFnOnce(i32)) {
2929
f(1).await;
3030
}
3131

32+
async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
33+
f(1).await;
34+
}
35+
3236
// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.mir
3337
// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.mir
38+
// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.mir
3439
pub fn main() {
3540
block_on(async {
3641
let b = 2i32;
@@ -40,5 +45,10 @@ pub fn main() {
4045
};
4146
call_mut(&mut async_closure).await;
4247
call_once(async_closure).await;
48+
49+
let async_closure = async move |a: i32| {
50+
let a = &a;
51+
};
52+
call_normal(&async_closure).await;
4353
});
4454
}

0 commit comments

Comments
 (0)