From 1eee420c64e6f7236e45ea92d4671886d704e23c Mon Sep 17 00:00:00 2001 From: mox692 Date: Fri, 29 Dec 2023 18:25:22 +0900 Subject: [PATCH] make `repeat`, `sink` cooperative --- tokio/src/io/util/repeat.rs | 18 +++++++++++++++++- tokio/src/io/util/sink.rs | 19 ++++++++++++++++++- tokio/tests/io_repeat.rs | 18 ++++++++++++++++++ tokio/tests/io_sink.rs | 18 ++++++++++++++++++ 4 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 tokio/tests/io_repeat.rs create mode 100644 tokio/tests/io_sink.rs diff --git a/tokio/src/io/util/repeat.rs b/tokio/src/io/util/repeat.rs index 1142765df5c..f651f9e8926 100644 --- a/tokio/src/io/util/repeat.rs +++ b/tokio/src/io/util/repeat.rs @@ -50,9 +50,11 @@ impl AsyncRead for Repeat { #[inline] fn poll_read( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); // TODO: could be faster, but should we unsafe it? while buf.remaining() != 0 { buf.put_slice(&[self.byte]); @@ -61,6 +63,20 @@ impl AsyncRead for Repeat { } } +cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + } +} + +cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tokio/src/io/util/sink.rs b/tokio/src/io/util/sink.rs index 05ee773fa38..8e799db0a76 100644 --- a/tokio/src/io/util/sink.rs +++ b/tokio/src/io/util/sink.rs @@ -53,9 +53,12 @@ impl AsyncWrite for Sink { #[inline] fn poll_write( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); + Poll::Ready(Ok(buf.len())) } @@ -76,6 +79,20 @@ impl fmt::Debug for Sink { } } +cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + } +} + +cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tokio/tests/io_repeat.rs b/tokio/tests/io_repeat.rs new file mode 100644 index 00000000000..76d192cc257 --- /dev/null +++ b/tokio/tests/io_repeat.rs @@ -0,0 +1,18 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full"))] + +use tokio::io::AsyncReadExt; + +#[tokio::test] +async fn repeat_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let mut buf = [0u8; 4096]; + tokio::io::repeat(0b101).read_exact(&mut buf).await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} diff --git a/tokio/tests/io_sink.rs b/tokio/tests/io_sink.rs new file mode 100644 index 00000000000..a4a8eb5f89c --- /dev/null +++ b/tokio/tests/io_sink.rs @@ -0,0 +1,18 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full"))] + +use tokio::io::AsyncWriteExt; + +#[tokio::test] +async fn sink_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let buf= vec![1, 2, 3]; + tokio::io::sink().write(&buf).await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}