Skip to content

Commit

Permalink
Fix FnMut/Fn shim for coroutine-closures that capture references
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Jun 29, 2024
1 parent b66d708 commit edda2c7
Show file tree
Hide file tree
Showing 18 changed files with 214 additions and 51 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let expected_ty = self.monomorphize(self.mir.local_decls[local].ty);
if expected_ty != op.layout.ty {
warn!(
"Unexpected initial operand type: expected {expected_ty:?}, found {:?}.\
"Unexpected initial operand type:\nexpected {expected_ty:?},\nfound {:?}.\n\
See <https://github.com/rust-lang/rust/issues/114858>.",
op.layout.ty
);
Expand Down
49 changes: 31 additions & 18 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};

use rustc_index::{Idx, IndexVec};

use rustc_span::{source_map::Spanned, Span, DUMMY_SP};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_target::spec::abi::Abi;

use std::assert_matches::assert_matches;
use std::fmt;
use std::iter;

Expand Down Expand Up @@ -1020,21 +1019,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
receiver_by_ref: bool,
) -> Body<'tcx> {
let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let mut self_local: Place<'tcx> = Local::from_usize(1).into();
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
bug!();
};

// We use `&mut Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly (which might take `&self` instead).
// We use `&Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly (which might take `&mut self` instead).
//
// The self type here is a coroutine-closure, not a coroutine, and we never read from
// it because it never has any captures, because this is only true in the Fn/FnMut
// implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only
// if the coroutine-closure has no captures.
// We adjust the `self_local` to be a deref since we want to copy fields out of
// a reference to the closure.
if receiver_by_ref {
// Triple-check that there's no captures here.
assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit);
self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty);
self_local = tcx.mk_place_deref(self_local);
self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty);
}

let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
Expand Down Expand Up @@ -1067,11 +1064,27 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
}
for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
fields.push(Operand::Move(tcx.mk_place_field(
Local::from_usize(1).into(),
FieldIdx::from_usize(idx),
ty,
)));
if receiver_by_ref {
// The only situation where it's possible is when we capture immuatable references,
// since those don't need to be reborrowed with the closure's env lifetime. Since
// references are always `Copy`, just emit a copy.
assert_matches!(
ty.kind(),
ty::Ref(_, _, hir::Mutability::Not),
"field should be captured by immutable ref if we have an `Fn` instance"
);
fields.push(Operand::Copy(tcx.mk_place_field(
self_local,
FieldIdx::from_usize(idx),
ty,
)));
} else {
fields.push(Operand::Move(tcx.mk_place_field(
self_local,
FieldIdx::from_usize(idx),
ty,
)));
}
}

let source_info = SourceInfo::outermost(span);
Expand Down
10 changes: 7 additions & 3 deletions compiler/rustc_symbol_mangling/src/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,13 @@ pub(super) fn mangle<'tcx>(
}
// FIXME(async_closures): This shouldn't be needed when we fix
// `Instance::ty`/`Instance::def_id`.
ty::InstanceKind::ConstructCoroutineInClosureShim { .. }
| ty::InstanceKind::CoroutineKindShim { .. } => {
printer.write_str("{{fn-once-shim}}").unwrap();
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref, .. } => {
printer
.write_str(if receiver_by_ref { "{{by-move-shim}}" } else { "{{by-ref-shim}}" })
.unwrap();
}
ty::InstanceKind::CoroutineKindShim { .. } => {
printer.write_str("{{by-move-body-shim}}").unwrap();
}
_ => {}
}
Expand Down
11 changes: 9 additions & 2 deletions compiler/rustc_symbol_mangling/src/v0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@ pub(super) fn mangle<'tcx>(
ty::InstanceKind::ReifyShim(_, Some(ReifyReason::FnPtr)) => Some("reify_fnptr"),
ty::InstanceKind::ReifyShim(_, Some(ReifyReason::Vtable)) => Some("reify_vtable"),

ty::InstanceKind::ConstructCoroutineInClosureShim { .. }
| ty::InstanceKind::CoroutineKindShim { .. } => Some("fn_once"),
// FIXME(async_closures): This shouldn't be needed when we fix
// `Instance::ty`/`Instance::def_id`.
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref: true, .. } => {
Some("by_move")
}
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref: false, .. } => {
Some("by_ref")
}
ty::InstanceKind::CoroutineKindShim { .. } => Some("by_move_body"),

_ => None,
};
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ fn fn_sig_for_fn_abi<'tcx>(
coroutine_kind = ty::ClosureKind::FnOnce;

// Implementations of `FnMut` and `Fn` for coroutine-closures
// still take their receiver by (mut) ref.
// still take their receiver by ref.
if receiver_by_ref {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty)
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty)
} else {
coroutine_ty
}
Expand Down
21 changes: 14 additions & 7 deletions src/tools/miri/tests/pass/async-closure.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![feature(async_closure, noop_waker, async_fn_traits)]
#![allow(unused)]

use std::future::Future;
use std::ops::{AsyncFnMut, AsyncFnOnce};
use std::ops::{AsyncFn, AsyncFnMut, AsyncFnOnce};
use std::pin::pin;
use std::task::*;

Expand All @@ -17,6 +18,10 @@ pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
}
}

async fn call(f: &mut impl AsyncFn(i32)) {
f(0).await;
}

async fn call_mut(f: &mut impl AsyncFnMut(i32)) {
f(0).await;
}
Expand All @@ -26,10 +31,10 @@ async fn call_once(f: impl AsyncFnOnce(i32)) {
}

async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
f(0).await;
f(1).await;
}

async fn call_normal_once<F: Future<Output = ()>>(f: impl FnOnce(i32) -> F) {
async fn call_normal_mut<F: Future<Output = ()>>(f: &mut impl FnMut(i32) -> F) {
f(1).await;
}

Expand All @@ -39,14 +44,16 @@ pub fn main() {
let mut async_closure = async move |a: i32| {
println!("{a} {b}");
};
call(&mut async_closure).await;
call_mut(&mut async_closure).await;
call_once(async_closure).await;

// No-capture closures implement `Fn`.
let async_closure = async move |a: i32| {
println!("{a}");
let b = 2i32;
let mut async_closure = async |a: i32| {
println!("{a} {b}");
};
call_normal(&async_closure).await;
call_normal_once(async_closure).await;
call_normal_mut(&mut async_closure).await;
call_once(async_closure).await;
});
}
6 changes: 4 additions & 2 deletions src/tools/miri/tests/pass/async-closure.stdout
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
0 2
0 2
1 2
1 2
1 2
1 2
0
1
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:53:33: 53:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:53:53: 56:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:53:33: 53:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:53:53: 56:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// MIR for `main::{closure#0}::{closure#1}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#1}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
debug a => (_1.0: i32);
debug b => (*(_1.1: &i32));
let mut _0: ();
let _3: i32;
scope 1 {
debug a => _3;
let _4: &i32;
scope 2 {
debug a => _4;
let _5: &i32;
scope 3 {
debug b => _5;
}
}
}

bb0: {
StorageLive(_3);
_3 = (_1.0: i32);
FakeRead(ForLet(None), _3);
StorageLive(_4);
_4 = &_3;
FakeRead(ForLet(None), _4);
StorageLive(_5);
_5 = &(*(_1.1: &i32));
FakeRead(ForLet(None), _5);
_0 = const ();
StorageDead(_5);
StorageDead(_4);
StorageDead(_3);
drop(_1) -> [return: bb1, unwind: bb2];
}

bb1: {
return;
}

bb2 (cleanup): {
resume;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// MIR for `main::{closure#0}::{closure#1}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#1}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
debug a => (_1.0: i32);
debug b => (*(_1.1: &i32));
let mut _0: ();
let _3: i32;
scope 1 {
debug a => _3;
let _4: &i32;
scope 2 {
debug a => _4;
let _5: &i32;
scope 3 {
debug b => _5;
}
}
}

bb0: {
StorageLive(_3);
_3 = (_1.0: i32);
FakeRead(ForLet(None), _3);
StorageLive(_4);
_4 = &_3;
FakeRead(ForLet(None), _4);
StorageLive(_5);
_5 = &(*(_1.1: &i32));
FakeRead(ForLet(None), _5);
_0 = const ();
StorageDead(_5);
StorageDead(_4);
StorageDead(_3);
drop(_1) -> [return: bb1, unwind: bb2];
}

bb1: {
return;
}

bb2 (cleanup): {
resume;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#1}(_1: {async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: move (_1.0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#1}(_1: {async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: move (_1.0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: &mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
fn main::{closure#0}::{closure#1}(_1: &{async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: ((*_1).0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: &mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
fn main::{closure#0}::{closure#1}(_1: &{async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: ((*_1).0: &i32) };
return;
}
}
Loading

0 comments on commit edda2c7

Please sign in to comment.