diff --git a/tokio/src/io/util/empty.rs b/tokio/src/io/util/empty.rs index 06be4ff3073..289725ce49f 100644 --- a/tokio/src/io/util/empty.rs +++ b/tokio/src/io/util/empty.rs @@ -1,3 +1,4 @@ +use crate::io::util::poll_proceed_and_make_progress; use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use std::fmt; @@ -138,20 +139,6 @@ impl fmt::Debug for Empty { } } -cfg_coop! { - fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> { - let coop = ready!(crate::runtime::coop::poll_proceed(cx)); - coop.made_progress(); - Poll::Ready(()) - } -} - -cfg_not_coop! { - fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { - Poll::Ready(()) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index 21199d0be84..47b951f2b83 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -85,6 +85,20 @@ cfg_io_util! { // used by `BufReader` and `BufWriter` // https://github.com/rust-lang/rust/blob/master/library/std/src/sys_common/io.rs#L1 const DEFAULT_BUF_SIZE: usize = 8 * 1024; + + cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut std::task::Context<'_>) -> std::task::Poll<()> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + coop.made_progress(); + std::task::Poll::Ready(()) + } + } + + cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut std::task::Context<'_>) -> std::task::Poll<()> { + std::task::Poll::Ready(()) + } + } } cfg_not_io_util! { diff --git a/tokio/src/io/util/repeat.rs b/tokio/src/io/util/repeat.rs index 1142765df5c..4a3ac78e49e 100644 --- a/tokio/src/io/util/repeat.rs +++ b/tokio/src/io/util/repeat.rs @@ -1,3 +1,4 @@ +use crate::io::util::poll_proceed_and_make_progress; use crate::io::{AsyncRead, ReadBuf}; use std::io; @@ -50,9 +51,11 @@ impl AsyncRead for Repeat { #[inline] fn poll_read( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); // TODO: could be faster, but should we unsafe it? while buf.remaining() != 0 { buf.put_slice(&[self.byte]); diff --git a/tokio/src/io/util/sink.rs b/tokio/src/io/util/sink.rs index 05ee773fa38..1c0102d4b2f 100644 --- a/tokio/src/io/util/sink.rs +++ b/tokio/src/io/util/sink.rs @@ -1,3 +1,4 @@ +use crate::io::util::poll_proceed_and_make_progress; use crate::io::AsyncWrite; use std::fmt; @@ -53,19 +54,25 @@ impl AsyncWrite for Sink { #[inline] fn poll_write( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); Poll::Ready(Ok(buf.len())) } #[inline] - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); Poll::Ready(Ok(())) } #[inline] - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + ready!(poll_proceed_and_make_progress(cx)); Poll::Ready(Ok(())) } } diff --git a/tokio/tests/io_repeat.rs b/tokio/tests/io_repeat.rs new file mode 100644 index 00000000000..76d192cc257 --- /dev/null +++ b/tokio/tests/io_repeat.rs @@ -0,0 +1,18 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full"))] + +use tokio::io::AsyncReadExt; + +#[tokio::test] +async fn repeat_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let mut buf = [0u8; 4096]; + tokio::io::repeat(0b101).read_exact(&mut buf).await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} diff --git a/tokio/tests/io_sink.rs b/tokio/tests/io_sink.rs new file mode 100644 index 00000000000..9b4fb31f30f --- /dev/null +++ b/tokio/tests/io_sink.rs @@ -0,0 +1,44 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full"))] + +use tokio::io::AsyncWriteExt; + +#[tokio::test] +async fn sink_poll_write_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let buf = vec![1, 2, 3]; + tokio::io::sink().write_all(&buf).await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} + +#[tokio::test] +async fn sink_poll_flush_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + tokio::io::sink().flush().await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} + +#[tokio::test] +async fn sink_poll_shutdown_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + tokio::io::sink().shutdown().await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}