From 5781b2faa2b7eab3f63408542fb45f44e8b1fa1a Mon Sep 17 00:00:00 2001 From: mox692 Date: Mon, 26 Feb 2024 21:44:21 +0900 Subject: [PATCH] init work --- tokio/src/net/tcp/listener.rs | 14 ++++- tokio/src/net/tcp/stream.rs | 44 +++++++++++++-- tokio/tests/tcp_stream.rs | 101 ++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 8 deletions(-) diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 3f6592abe19..a737e55d32f 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -158,13 +158,23 @@ impl TcpListener { /// } /// ``` pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + self.accept_with_interest(Interest::READABLE | Interest::WRITABLE) + .await + } + + /// comment + pub async fn accept_with_interest( + &self, + interest: Interest, + ) -> io::Result<(TcpStream, SocketAddr)> { let (mio, addr) = self .io .registration() - .async_io(Interest::READABLE, || self.io.accept()) + .async_io(interest, || self.io.accept()) .await?; - let stream = TcpStream::new(mio)?; + // TODO: clear here + let stream = TcpStream::new_with_interest(mio, interest)?; Ok((stream, addr)) } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index e20473e5cc3..113d142f873 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -117,7 +117,28 @@ impl TcpStream { let mut last_err = None; for addr in addrs { - match TcpStream::connect_addr(addr).await { + match TcpStream::connect_addr_with_interest(addr, Interest::READABLE | Interest::WRITABLE).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) + } + + /// foooo + pub async fn connect_with_interest(addr: A, interest: Interest) -> io::Result { + let addrs = to_socket_addrs(addr).await?; + + let mut last_err = None; + + for addr in addrs { + match TcpStream::connect_addr_with_interest(addr, interest).await { Ok(stream) => return Ok(stream), Err(e) => last_err = Some(e), } @@ -132,13 +153,17 @@ impl TcpStream { } /// Establishes a connection to the specified `addr`. - async fn connect_addr(addr: SocketAddr) -> io::Result { + async fn connect_addr_with_interest(addr: SocketAddr, interest: Interest) -> io::Result { let sys = mio::net::TcpStream::connect(addr)?; - TcpStream::connect_mio(sys).await + TcpStream::connect_mio_with_interest(sys, interest).await } pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result { - let stream = TcpStream::new(sys)?; + Self::connect_mio_with_interest(sys, Interest::READABLE | Interest::WRITABLE).await + } + + pub(crate) async fn connect_mio_with_interest(sys: mio::net::TcpStream, interest: Interest) -> io::Result { + let stream = TcpStream::new_with_interest(sys, interest)?; // Once we've connected, wait for the stream to be writable as // that's when the actual connection has been initiated. Once we're @@ -157,8 +182,7 @@ impl TcpStream { } pub(crate) fn new(connected: mio::net::TcpStream) -> io::Result { - let io = PollEvented::new(connected)?; - Ok(TcpStream { io }) + Self::new_with_interest(connected, Interest::READABLE | Interest::WRITABLE) } /// Creates new `TcpStream` from a `std::net::TcpStream`. @@ -205,6 +229,14 @@ impl TcpStream { Ok(TcpStream { io }) } + pub(crate) fn new_with_interest( + connected: mio::net::TcpStream, + interest: Interest, + ) -> io::Result { + let io = PollEvented::new_with_interest(connected, interest)?; + Ok(TcpStream { io }) + } + /// Turns a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. /// /// The returned [`std::net::TcpStream`] will have nonblocking mode set as `true`. diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 725a60169ea..16474e8459c 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -398,3 +398,104 @@ async fn write_closed() { assert!(!ready_event.is_write_closed()); } + +#[cfg(any(target_os = "linux", target_os = "android"))] +#[tokio::test] +async fn priority_interest() { + use std::os::fd::AsRawFd; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let stream = TcpStream::connect(listener.local_addr().unwrap()) + .await + .unwrap(); + + tokio::spawn(async move { + let (socket, _) = listener + .accept_with_interest(Interest::PRIORITY) + .await + .unwrap(); + let ready = socket.ready(Interest::PRIORITY).await.unwrap(); + assert!(ready.is_priority()); + }); + + let ready = stream + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .unwrap(); + if ready.is_writable() { + fn send_oob_data(stream: &TcpStream, data: &[u8]) -> io::Result { + unsafe { + let res = libc::send( + stream.as_raw_fd(), + data.as_ptr().cast(), + data.len(), + libc::MSG_OOB, + ); + if res == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(res as usize) + } + } + } + send_oob_data(&stream, b"hello").unwrap(); + } +} + +#[cfg(any(target_os = "linux", target_os = "android"))] +#[tokio::test] +async fn connect_with_interest() { + use std::os::fd::AsRawFd; + // TODO: should be minimized. + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let stream = TcpStream::connect_with_interest( + listener.local_addr().unwrap(), + Interest::PRIORITY | Interest::READABLE | Interest::WRITABLE, + ) + .await + .unwrap(); + + tokio::spawn(async move { + let (socket, _) = listener + .accept_with_interest(Interest::READABLE | Interest::WRITABLE) + .await + .unwrap(); + + loop { + let ready = socket + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .unwrap(); + if ready.is_writable() { + fn send_oob_data(stream: &TcpStream, data: &[u8]) -> io::Result { + unsafe { + let res = libc::send( + stream.as_raw_fd(), + data.as_ptr().cast(), + data.len(), + libc::MSG_OOB, + ); + if res == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(res as usize) + } + } + } + send_oob_data(&socket, b"hello").unwrap(); + break; + } + if ready.is_readable() { + continue; + } + } + }); + + let ready = stream.ready(Interest::WRITABLE).await.unwrap(); + if ready.is_writable() { + stream.try_write(&[1, 2, 3]).unwrap(); + } + let ready = stream.ready(Interest::PRIORITY).await.unwrap(); + assert!(ready.is_priority()); +}