From bddcaae578e361f7e6383a1b1c85f312a8e0b134 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 15 Aug 2022 06:17:27 -0700 Subject: [PATCH 1/3] add first pass implementation of windows uds modify src/net for windows compatibility fix tests add docs back in cleanup remove log statements clean up selector clean up stream and listener sys logic fix re-registration add test for serial calls to listener.accept fix serial calls to accept remove tempfile dependency and fix doc tests revert change in draining behavior re-organize stdnet files to mirror std::os::unix::net use single syscall vectored approach from rust-lang/socket2 lint improve support across feature matrix fix doc tests use bcrypt instead of rand add -_ to random char set to avoid rejection sampling optimize rng syscall logic fix lint and fmt remove unused functions fmt simplify windows mod clean up tests fix indentation, imports, address other comments fmt remove unrelated code changes fix lint remove explicit SetHandleInformation calls abstract socketaddr behind common API in net fix lint add comment clarifying inheritance during calls to accept --- src/net/mod.rs | 11 +- src/net/tcp/stream.rs | 20 +-- src/net/uds/addr.rs | 97 +++++++++++ src/net/uds/datagram.rs | 12 +- src/net/uds/listener.rs | 53 +++++- src/net/uds/mod.rs | 7 +- src/net/uds/stream.rs | 155 +++++++++++++++-- src/sys/mod.rs | 14 +- src/sys/shell/mod.rs | 1 - src/sys/shell/uds.rs | 15 +- src/sys/unix/mod.rs | 4 +- src/sys/unix/pipe.rs | 20 +-- src/sys/unix/uds/listener.rs | 3 +- src/sys/unix/uds/mod.rs | 4 +- src/sys/unix/uds/socketaddr.rs | 23 +-- src/sys/windows/mod.rs | 257 ++++++++++++++++------------- src/sys/windows/net.rs | 1 - src/sys/windows/stdnet/addr.rs | 124 ++++++++++++++ src/sys/windows/stdnet/listener.rs | 83 ++++++++++ src/sys/windows/stdnet/mod.rs | 26 +++ src/sys/windows/stdnet/socket.rs | 186 +++++++++++++++++++++ src/sys/windows/stdnet/stream.rs | 151 +++++++++++++++++ src/sys/windows/uds/listener.rs | 23 +++ src/sys/windows/uds/mod.rs | 29 ++++ src/sys/windows/uds/stream.rs | 19 +++ tests/unix_listener.rs | 8 +- tests/unix_pipe.rs | 4 +- tests/unix_stream.rs | 139 +++++++++++----- 28 files changed, 1245 insertions(+), 244 deletions(-) create mode 100644 src/net/uds/addr.rs create mode 100644 src/sys/windows/stdnet/addr.rs create mode 100644 src/sys/windows/stdnet/listener.rs create mode 100644 src/sys/windows/stdnet/mod.rs create mode 100644 src/sys/windows/stdnet/socket.rs create mode 100644 src/sys/windows/stdnet/stream.rs create mode 100644 src/sys/windows/uds/listener.rs create mode 100644 src/sys/windows/uds/mod.rs create mode 100644 src/sys/windows/uds/stream.rs diff --git a/src/net/mod.rs b/src/net/mod.rs index 7d714ca00..41d81a2d4 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -32,8 +32,13 @@ pub use self::tcp::{TcpListener, TcpStream}; mod udp; #[cfg(not(target_os = "wasi"))] pub use self::udp::UdpSocket; - -#[cfg(unix)] +#[cfg(not(target_os = "wasi"))] mod uds; +#[cfg(not(target_os = "wasi"))] +pub use self::uds::{SocketAddr, UnixListener, UnixStream}; + +#[cfg(not(target_os = "wasi"))] +pub(crate) use self::uds::AddressKind; + #[cfg(unix)] -pub use self::uds::{SocketAddr, UnixDatagram, UnixListener, UnixStream}; +pub use self::uds::UnixDatagram; diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 8a3f6a2f2..3264904f5 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -269,49 +269,49 @@ impl TcpStream { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } diff --git a/src/net/uds/addr.rs b/src/net/uds/addr.rs new file mode 100644 index 000000000..9fb4c9c88 --- /dev/null +++ b/src/net/uds/addr.rs @@ -0,0 +1,97 @@ +use crate::sys; +use std::path::Path; +use std::{ascii, fmt}; + +/// An address associated with a `mio` specific Unix socket. +/// +/// This is implemented instead of imported from [`net::SocketAddr`] because +/// there is no way to create a [`net::SocketAddr`]. One must be returned by +/// [`accept`], so this is returned instead. +/// +/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr +/// [`accept`]: #method.accept +pub struct SocketAddr { + inner: sys::SocketAddr, +} + +struct AsciiEscaped<'a>(&'a [u8]); + +pub(crate) enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a [u8]), +} + +impl SocketAddr { + pub(crate) fn new(inner: sys::SocketAddr) -> Self { + SocketAddr { inner } + } + + fn address(&self) -> AddressKind<'_> { + self.inner.address() + } +} + +cfg_os_poll! { + impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn is_unnamed(&self) -> bool { + matches!(self.address(), AddressKind::Unnamed) + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + + /// Returns the contents of this address if it is an abstract namespace + /// without the leading null byte. + // Link to std::os::unix::net::SocketAddr pending + // https://github.com/rust-lang/rust/issues/85410. + pub fn as_abstract_namespace(&self) -> Option<&[u8]> { + if let AddressKind::Abstract(path) = self.address() { + Some(path) + } else { + None + } + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.address()) + } +} + +impl fmt::Debug for AddressKind<'_> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "\"")?; + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + write!(fmt, "{}", byte as char)?; + } + write!(fmt, "\"") + } +} diff --git a/src/net/uds/datagram.rs b/src/net/uds/datagram.rs index e963d6e2f..7bc1b7b1f 100644 --- a/src/net/uds/datagram.rs +++ b/src/net/uds/datagram.rs @@ -1,4 +1,5 @@ use crate::io_source::IoSource; +use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; use std::net::Shutdown; @@ -54,24 +55,25 @@ impl UnixDatagram { } /// Returns the address of this socket. - pub fn local_addr(&self) -> io::Result { - sys::uds::datagram::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::datagram::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the address of this socket's peer. /// /// The `connect` method will connect the socket to a peer. - pub fn peer_addr(&self) -> io::Result { - sys::uds::datagram::peer_addr(&self.inner) + pub fn peer_addr(&self) -> io::Result { + sys::uds::datagram::peer_addr(&self.inner).map(SocketAddr::new) } /// Receives data from the socket. /// /// On success, returns the number of bytes read and the address from /// whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, sys::SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.inner .do_io(|inner| sys::uds::datagram::recv_from(inner, buf)) + .map(|(nread, addr)| (nread, SocketAddr::new(addr))) } /// Receives data from the socket. diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index eeffe042e..fba8ccf37 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -2,8 +2,14 @@ use crate::io_source::IoSource; use crate::net::{SocketAddr, UnixStream}; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::stdnet as net; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::Path; use std::{fmt, io}; @@ -29,23 +35,34 @@ impl UnixListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), } } + #[cfg(windows)] + pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener { + UnixListener { + inner: IoSource::new(listener), + } + } + /// Accepts a new incoming connection to this listener. /// /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - sys::uds::listener::accept(&self.inner) + self.inner + .do_io(sys::uds::listener::accept) + .map(|(stream, addr)| (stream, SocketAddr::new(addr))) } /// Returns the local socket address of this listener. - pub fn local_addr(&self) -> io::Result { - sys::uds::listener::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::listener::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. @@ -84,18 +101,24 @@ impl fmt::Debug for UnixListener { } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixListener { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixListener { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixListener { /// Converts a `RawFd` to a `UnixListener`. /// @@ -107,3 +130,27 @@ impl FromRawFd for UnixListener { UnixListener::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener::from_std(FromRawSocket::from_raw_socket(sock)) + } +} diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index 6b4ffdc43..2a12f965e 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -1,4 +1,7 @@ +#[cfg(unix)] mod datagram; +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] pub use self::datagram::UnixDatagram; mod listener; @@ -7,4 +10,6 @@ pub use self::listener::UnixListener; mod stream; pub use self::stream::UnixStream; -pub use crate::sys::SocketAddr; +mod addr; +pub(crate) use self::addr::AddressKind; +pub use self::addr::SocketAddr; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 1c17d84a1..7b6b5f728 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -2,11 +2,17 @@ use crate::io_source::IoSource; use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::stdnet as net; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::Path; /// A non-blocking Unix stream socket. @@ -43,15 +49,26 @@ impl UnixStream { /// The Unix stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), } } + #[cfg(windows)] + pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { + UnixStream { + inner: IoSource::new(stream), + } + } + /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn pair() -> io::Result<(UnixStream, UnixStream)> { sys::uds::stream::pair().map(|(stream1, stream2)| { (UnixStream::from_std(stream1), UnixStream::from_std(stream2)) @@ -59,13 +76,13 @@ impl UnixStream { } /// Returns the socket address of the local half of this connection. - pub fn local_addr(&self) -> io::Result { - sys::uds::stream::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::stream::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the socket address of the remote half of this connection. - pub fn peer_addr(&self) -> io::Result { - sys::uds::stream::peer_addr(&self.inner) + pub fn peer_addr(&self) -> io::Result { + sys::uds::stream::peer_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. @@ -95,7 +112,8 @@ impl UnixStream { /// /// # Examples /// - /// ``` + #[cfg_attr(unix, doc = "```")] + #[cfg_attr(windows, doc = "```ignore")] /// # use std::error::Error; /// # /// # fn main() -> Result<(), Box> { @@ -143,6 +161,83 @@ impl UnixStream { /// # Ok(()) /// # } /// ``` + /// + #[cfg_attr(windows, doc = "```")] + #[cfg_attr(unix, doc = "```ignore")] + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use std::io; + /// use std::os::windows::io::AsRawSocket; + /// use std::os::raw::c_int; + /// use mio::net::{UnixStream, UnixListener}; + /// use windows_sys::Win32::Networking::WinSock; + /// use std::convert::TryInto; + /// + /// let file_path = std::env::temp_dir().join("server.sock"); + /// # let _ = std::fs::remove_file(&file_path); + /// let server = UnixListener::bind(&file_path).unwrap(); + /// + /// let handle = std::thread::spawn(move || { + /// if let Ok((stream2, _)) = server.accept() { + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// }).unwrap(); + /// eprintln!("read {} bytes", n); + /// } + /// }); + /// + /// let stream1 = UnixStream::connect(&file_path).unwrap(); + /// + /// // Wait until the stream is writable... + /// + /// // Write to the stream using a direct WinSock call, of course the + /// // `io::Write` implementation would be easier to use. + /// let buf = b"hello"; + /// let n = stream1.try_io(|| { + /// let res = unsafe { + /// WinSock::send( + /// stream1.as_raw_socket().try_into().unwrap(), + /// &buf as *const _ as *const _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::send, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::from_raw_os_error(unsafe { + /// WinSock::WSAGetLastError() + /// })) + /// } + /// })?; + /// eprintln!("write {} bytes", n); + /// + /// # handle.join().unwrap(); + /// # Ok(()) + /// # } + /// ``` pub fn try_io(&self, f: F) -> io::Result where F: FnOnce() -> io::Result, @@ -153,49 +248,49 @@ impl UnixStream { impl Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl<'a> Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } impl<'a> Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } @@ -229,18 +324,24 @@ impl fmt::Debug for UnixStream { } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixStream { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixStream { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixStream { /// Converts a `RawFd` to a `UnixStream`. /// @@ -252,3 +353,27 @@ impl FromRawFd for UnixStream { UnixStream::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream::from_std(FromRawSocket::from_raw_socket(sock)) + } +} diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 2a968b265..13b180c4c 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -59,7 +59,7 @@ cfg_os_poll! { #[cfg(windows)] cfg_os_poll! { - mod windows; + pub(crate) mod windows; pub use self::windows::*; } @@ -81,6 +81,16 @@ cfg_not_os_poll! { #[cfg(unix)] cfg_net! { - pub use self::unix::SocketAddr; + pub(crate) use self::unix::SocketAddr; + } + + #[cfg(windows)] + cfg_any_os_ext! { + pub(crate) mod windows; + } + + #[cfg(windows)] + cfg_net! { + pub(crate) use self::windows::SocketAddr; } } diff --git a/src/sys/shell/mod.rs b/src/sys/shell/mod.rs index 76085b8a2..979aeb44a 100644 --- a/src/sys/shell/mod.rs +++ b/src/sys/shell/mod.rs @@ -15,7 +15,6 @@ pub(crate) use self::waker::Waker; cfg_net! { pub(crate) mod tcp; pub(crate) mod udp; - #[cfg(unix)] pub(crate) mod uds; } diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index bac547b03..e04b262d1 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -1,5 +1,6 @@ +#[cfg(unix)] pub(crate) mod datagram { - use crate::net::SocketAddr; + use crate::sys::SocketAddr; use std::io; use std::os::unix::net; use std::path::Path; @@ -33,8 +34,12 @@ pub(crate) mod datagram { } pub(crate) mod listener { - use crate::net::{SocketAddr, UnixStream}; + use crate::net::UnixStream; + #[cfg(windows)] + use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -56,8 +61,11 @@ pub(crate) mod listener { } pub(crate) mod stream { - use crate::net::SocketAddr; + #[cfg(windows)] + use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -69,6 +77,7 @@ pub(crate) mod stream { os_required!() } + #[cfg(unix)] pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { os_required!() } diff --git a/src/sys/unix/mod.rs b/src/sys/unix/mod.rs index eb268b9f4..0e3fcc6f3 100644 --- a/src/sys/unix/mod.rs +++ b/src/sys/unix/mod.rs @@ -29,7 +29,7 @@ cfg_os_poll! { pub(crate) mod tcp; pub(crate) mod udp; pub(crate) mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_io_source! { @@ -112,7 +112,7 @@ cfg_os_poll! { cfg_not_os_poll! { cfg_net! { mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_any_os_ext! { diff --git a/src/sys/unix/pipe.rs b/src/sys/unix/pipe.rs index 8e92dd37d..c71855bc4 100644 --- a/src/sys/unix/pipe.rs +++ b/src/sys/unix/pipe.rs @@ -333,29 +333,29 @@ impl event::Source for Sender { impl Write for Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.write(buf)) + self.inner.do_io(|sender| (&*sender).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.write_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut sender| sender.flush()) + self.inner.do_io(|sender| (&*sender).flush()) } } impl Write for &Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.write(buf)) + self.inner.do_io(|sender| (&*sender).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.write_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut sender| sender.flush()) + self.inner.do_io(|sender| (&*sender).flush()) } } @@ -498,21 +498,21 @@ impl event::Source for Receiver { impl Read for Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.read(buf)) + self.inner.do_io(|sender| (&*sender).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.read_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) } } impl Read for &Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.read(buf)) + self.inner.do_io(|sender| (&*sender).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.read_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) } } diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index ff77c53bd..794a9f7bb 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,5 +1,6 @@ use super::socket_addr; -use crate::net::{SocketAddr, UnixStream}; +use super::SocketAddr; +use crate::net::UnixStream; use crate::sys::unix::net::new_socket; use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsRawFd, FromRawFd}; diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 20b668f80..e75412f65 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -1,5 +1,5 @@ mod socketaddr; -pub use self::socketaddr::SocketAddr; +pub(crate) use self::socketaddr::SocketAddr; /// Get the `sun_path` field offset of `sockaddr_un` for the target OS. /// @@ -61,7 +61,7 @@ cfg_os_poll! { let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.first() { + match bytes.get(0) { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} diff --git a/src/sys/unix/uds/socketaddr.rs b/src/sys/unix/uds/socketaddr.rs index 8e0ef53a4..978f6c00b 100644 --- a/src/sys/unix/uds/socketaddr.rs +++ b/src/sys/unix/uds/socketaddr.rs @@ -1,32 +1,15 @@ use super::path_offset; +use crate::net::AddressKind; use std::ffi::OsStr; use std::os::unix::ffi::OsStrExt; -use std::path::Path; -use std::{ascii, fmt}; -/// An address associated with a `mio` specific Unix socket. -/// -/// This is implemented instead of imported from [`net::SocketAddr`] because -/// there is no way to create a [`net::SocketAddr`]. One must be returned by -/// [`accept`], so this is returned instead. -/// -/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr -/// [`accept`]: #method.accept -pub struct SocketAddr { +pub(crate) struct SocketAddr { sockaddr: libc::sockaddr_un, socklen: libc::socklen_t, } -struct AsciiEscaped<'a>(&'a [u8]); - -enum AddressKind<'a> { - Unnamed, - Pathname(&'a Path), - Abstract(&'a [u8]), -} - impl SocketAddr { - fn address(&self) -> AddressKind<'_> { + pub(crate) fn address(&self) -> AddressKind<'_> { let offset = path_offset(&self.sockaddr); // Don't underflow in `len` below. if (self.socklen as usize) < offset { diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index f8b72fc49..07f7dda6c 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,151 +1,174 @@ -mod afd; +// Macro must be defined before any modules that uses them. +/// Helper macro to execute a system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} -pub mod event; -pub use event::{Event, Events}; +/// Helper macro to execute a WinSock system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! wsa_syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_value: expr) => {{ + let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; + if PartialEq::eq(&res, &$err_value) { + Err(std::io::Error::from_raw_os_error(unsafe { + windows_sys::Win32::Networking::WinSock::WSAGetLastError() + })) + } else { + Ok(res) + } + }}; +} -mod handle; -use handle::Handle; +cfg_net! { + pub(crate) mod stdnet; + pub(crate) mod uds; + pub(crate) use self::uds::SocketAddr; +} -mod io_status_block; -mod iocp; +cfg_os_poll! { + mod afd; -mod overlapped; -use overlapped::Overlapped; + pub mod event; + pub use event::{Event, Events}; -mod selector; -pub use selector::{Selector, SelectorInner, SockState}; + mod handle; + use handle::Handle; -// Macros must be defined before the modules that use them -cfg_net! { - /// Helper macro to execute a system call that returns an `io::Result`. - // - // Macro must be defined before any modules that uses them. - macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; - } + mod io_status_block; + mod iocp; - mod net; + mod overlapped; + use overlapped::Overlapped; - pub(crate) mod tcp; - pub(crate) mod udp; -} + mod selector; + pub use selector::{Selector, SelectorInner, SockState}; -cfg_os_ext! { - pub(crate) mod named_pipe; -} + // Macros must be defined before the modules that use them + cfg_net! { + mod net; -mod waker; -pub(crate) use waker::Waker; + pub(crate) mod tcp; + pub(crate) mod udp; + } -cfg_io_source! { - use std::io; - use std::os::windows::io::RawSocket; - use std::pin::Pin; - use std::sync::{Arc, Mutex}; + cfg_os_ext! { + pub(crate) mod named_pipe; + } - use crate::{Interest, Registry, Token}; + mod waker; + pub(crate) use waker::Waker; - struct InternalState { - selector: Arc, - token: Token, - interests: Interest, - sock_state: Pin>>, - } + cfg_io_source! { + use std::io; + use std::os::windows::io::RawSocket; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; - impl Drop for InternalState { - fn drop(&mut self) { - let mut sock_state = self.sock_state.lock().unwrap(); - sock_state.mark_delete(); + use crate::{Interest, Registry, Token}; + + struct InternalState { + selector: Arc, + token: Token, + interests: Interest, + sock_state: Pin>>, } - } - pub struct IoSourceState { - // This is `None` if the socket has not yet been registered. - // - // We box the internal state to not increase the size on the stack as the - // type might move around a lot. - inner: Option>, - } + impl Drop for InternalState { + fn drop(&mut self) { + let mut sock_state = self.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + } - impl IoSourceState { - pub fn new() -> IoSourceState { - IoSourceState { inner: None } + pub struct IoSourceState { + // This is `None` if the socket has not yet been registered. + // + // We box the internal state to not increase the size on the stack as the + // type might move around a lot. + inner: Option>, } - pub fn do_io(&self, f: F, io: &T) -> io::Result - where - F: FnOnce(&T) -> io::Result, - { - let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState { inner: None } } - result - } - pub fn register( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - socket: RawSocket, - ) -> io::Result<()> { - if self.inner.is_some() { - Err(io::ErrorKind::AlreadyExists.into()) - } else { - registry - .selector() - .register(socket, token, interests) - .map(|state| { - self.inner = Some(Box::new(state)); - }) + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + let result = f(io); + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } + } + result } - } - pub fn reregister( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - ) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + socket: RawSocket, + ) -> io::Result<()> { + if self.inner.is_some() { + Err(io::ErrorKind::AlreadyExists.into()) + } else { registry .selector() - .reregister(state.sock_state.clone(), token, interests) - .map(|()| { - state.token = token; - state.interests = interests; + .register(socket, token, interests) + .map(|state| { + self.inner = Some(Box::new(state)); }) } - None => Err(io::ErrorKind::NotFound.into()), } - } - pub fn deregister(&mut self) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - { - let mut sock_state = state.sock_state.lock().unwrap(); - sock_state.mark_delete(); + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + ) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + registry + .selector() + .reregister(state.sock_state.clone(), token, interests) + .map(|()| { + state.token = token; + state.interests = interests; + }) + } + None => Err(io::ErrorKind::NotFound.into()), + } + } + + pub fn deregister(&mut self) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + { + let mut sock_state = state.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + self.inner = None; + Ok(()) } - self.inner = None; - Ok(()) + None => Err(io::ErrorKind::NotFound.into()), } - None => Err(io::ErrorKind::NotFound.into()), } } } diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 5cc235335..38b17492c 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -1,7 +1,6 @@ use std::io; use std::mem; use std::net::SocketAddr; -use std::sync::Once; use windows_sys::Win32::Networking::WinSock::{ closesocket, ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs new file mode 100644 index 000000000..26b1fddde --- /dev/null +++ b/src/sys/windows/stdnet/addr.rs @@ -0,0 +1,124 @@ +use crate::net::AddressKind; +use std::os::raw::c_int; +use std::path::Path; +use std::{fmt, io, mem}; + +use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; + +fn path_offset(addr: &sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} + +cfg_os_poll! { + use windows_sys::Win32::Networking::WinSock::AF_UNIX; + pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ) + })?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + + sockaddr.sun_path[..bytes.len()].copy_from_slice(bytes); + + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); + + match bytes.first() { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } + + Ok((sockaddr, socklen as c_int)) + } +} + +pub(crate) struct SocketAddr { + addr: sockaddr_un, + len: c_int, +} + +impl SocketAddr { + pub(crate) fn init(f: F) -> io::Result<(T, SocketAddr)> + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + let mut sockaddr = { + let sockaddr = mem::MaybeUninit::::zeroed(); + unsafe { sockaddr.assume_init() } + }; + + let mut len = mem::size_of::() as c_int; + let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; + Ok(( + result, + SocketAddr { + addr: sockaddr, + len, + }, + )) + } + + pub(crate) fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + SocketAddr::init(f).map(|(_, addr)| addr) + } + + pub(crate) fn address(&self) -> AddressKind<'_> { + let len = self.len as usize - path_offset(&self.addr); + // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path + + // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses + if len == 0 { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(&self.addr.sun_path[1..len]) + } else { + use std::ffi::CStr; + let pathname = + unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; + AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.address()) + } +} diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs new file mode 100644 index 000000000..214167276 --- /dev/null +++ b/src/sys/windows/stdnet/listener.rs @@ -0,0 +1,83 @@ +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::{fmt, io, mem}; + +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + +use super::{socket::Socket, SocketAddr}; + +pub(crate) struct UnixListener(Socket); + +impl UnixListener { + pub(crate) fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn take_error(&self) -> io::Result> { + self.0.take_error() + } +} + +cfg_os_poll! { + use std::path::Path; + + use super::{socket_addr, UnixStream}; + + impl UnixListener { + pub(crate) fn bind>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + wsa_syscall!( + bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), + SOCKET_ERROR + )?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; + Ok(UnixListener(inner)) + } + + pub(crate) fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + SocketAddr::init(|addr, len| self.0.accept(addr, len)) + .map(|(sock, addr)| (UnixStream(sock), addr)) + } + + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + } +} + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs new file mode 100644 index 000000000..0eb5130d4 --- /dev/null +++ b/src/sys/windows/stdnet/mod.rs @@ -0,0 +1,26 @@ +//! Implementation of blocking UDS types for windows, mirrors std::os::unix::net. +mod addr; +mod listener; +mod socket; +mod stream; + +pub(crate) use self::addr::SocketAddr; +pub(crate) use self::listener::UnixListener; +pub(crate) use self::stream::UnixStream; + +cfg_os_poll! { + pub(self) use self::addr::socket_addr; + + use std::sync::Once; + + /// Initialise the network stack for Windows. + pub(crate) fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + // Let standard library call `WSAStartup` for us, we can't do it + // ourselves because otherwise using any type in `std::net` would panic + // when it tries to call `WSAStartup` a second time. + drop(std::net::UdpSocket::bind("127.0.0.1:0")); + }); + } +} diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs new file mode 100644 index 000000000..9212c1e04 --- /dev/null +++ b/src/sys/windows/stdnet/socket.rs @@ -0,0 +1,186 @@ +use std::cmp::min; +use std::convert::TryInto; +use std::io::{self, IoSlice, IoSliceMut}; +use std::mem; +use std::net::Shutdown; +use std::os::raw::c_int; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::ptr; + +use windows_sys::Win32::Networking::WinSock::{self, closesocket, SOCKET, SOCKET_ERROR, WSABUF}; + +/// Maximum size of a buffer passed to system call like `recv` and `send`. +const MAX_BUF_LEN: usize = c_int::MAX as usize; + +#[derive(Debug)] +pub(crate) struct Socket(SOCKET); + +impl Socket { + pub fn recv(&self, buf: &mut [u8]) -> io::Result { + let ret = wsa_syscall!( + recv(self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, 0,), + SOCKET_ERROR + )?; + Ok(ret as usize) + } + + pub fn recv_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let mut total = 0; + let mut flags: u32 = 0; + let bufs = unsafe { &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [WSABUF]) }; + let res = wsa_syscall!( + WSARecv( + self.0, + bufs.as_mut_ptr().cast(), + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + &mut flags, + ptr::null_mut(), + None, + ), + SOCKET_ERROR + ); + match res { + Ok(_) => Ok(total as usize), + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN as i32) => Ok(0), + Err(err) => Err(err), + } + } + + pub fn send(&self, buf: &[u8]) -> io::Result { + wsa_syscall!( + send( + self.0, + buf.as_ptr().cast(), + min(buf.len(), MAX_BUF_LEN) as c_int, + 0, + ), + SOCKET_ERROR + ) + .map(|n| n as usize) + } + + pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let mut total = 0; + wsa_syscall!( + WSASend( + self.0, + // FIXME: From the `WSASend` docs [1]: + // > For a Winsock application, once the WSASend function is called, + // > the system owns these buffers and the application may not + // > access them. + // + // So what we're doing is actually UB as `bufs` needs to be `&mut + // [IoSlice<'_>]`. + // + // See: https://github.com/rust-lang/socket2-rs/issues/129. + // + // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend + bufs.as_ptr() as *mut _, + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + 0, + std::ptr::null_mut(), + None, + ), + SOCKET_ERROR + ) + .map(|_| total as usize) + } + + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Write => WinSock::SD_SEND, + Shutdown::Read => WinSock::SD_RECEIVE, + Shutdown::Both => WinSock::SD_BOTH, + }; + wsa_syscall!(shutdown(self.0, how.try_into().unwrap()), SOCKET_ERROR)?; + Ok(()) + } + + pub fn take_error(&self) -> io::Result> { + let mut val: mem::MaybeUninit = mem::MaybeUninit::uninit(); + let mut len = mem::size_of::() as i32; + wsa_syscall!( + getsockopt( + self.0 as _, + WinSock::SOL_SOCKET.try_into().unwrap(), + WinSock::SO_ERROR.try_into().unwrap(), + &mut val as *mut _ as *mut _, + &mut len, + ), + SOCKET_ERROR + )?; + assert_eq!(len as usize, mem::size_of::()); + let val = unsafe { val.assume_init() }; + if val == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(val as i32))) + } + } +} + +cfg_os_poll! { + use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; + use super::init; + + impl Socket { + pub fn new() -> io::Result { + init(); + wsa_syscall!( + WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM.into(), + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, + ), + INVALID_SOCKET + ).map(Socket) + } + + pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + // WinSock's accept returns a socket with the same properties as the listener. it is + // called on. In particular, the WSA_FLAG_NO_HANDLE_INHERIT will be inherited from the + // listener. + wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = if nonblocking { 1 } else { 0 }; + wsa_syscall!( + ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), + SOCKET_ERROR + )?; + Ok(()) + } + } +} + +impl Drop for Socket { + fn drop(&mut self) { + let _ = unsafe { closesocket(self.0) }; + } +} + +impl AsRawSocket for Socket { + fn as_raw_socket(&self) -> RawSocket { + self.0 as RawSocket + } +} + +impl FromRawSocket for Socket { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + Socket(sock as SOCKET) + } +} + +impl IntoRawSocket for Socket { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0 as RawSocket; + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs new file mode 100644 index 000000000..ce1da2f54 --- /dev/null +++ b/src/sys/windows/stdnet/stream.rs @@ -0,0 +1,151 @@ +use std::io::{self, IoSlice, IoSliceMut}; +use std::net::Shutdown; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::{fmt, mem}; + +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + +use super::{socket::Socket, SocketAddr}; + +pub(crate) struct UnixStream(pub(super) Socket); + +impl UnixStream { + pub(crate) fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getpeername(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } +} + +cfg_os_poll! { + use std::path::Path; + use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; + use super::socket_addr; + + impl UnixStream { + pub(crate) fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } + + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + } +} + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + io::Read::read_vectored(&mut &*self, bufs) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.recv(buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.0.recv_vectored(bufs) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut &*self, bufs) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.send(buf) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.0.send_vectored(bufs) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs new file mode 100644 index 000000000..4ba4395e5 --- /dev/null +++ b/src/sys/windows/uds/listener.rs @@ -0,0 +1,23 @@ +use std::io; +use std::os::windows::io::AsRawSocket; +use std::path::Path; + +use super::SocketAddr; +use crate::net::UnixStream; +use crate::sys::windows::stdnet as net; + +pub(crate) fn bind(path: &Path) -> io::Result { + let listener = net::UnixListener::bind(path)?; + listener.set_nonblocking(true)?; + Ok(listener) +} + +pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { + listener + .accept() + .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) +} + +pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { + super::local_addr(listener.as_raw_socket()) +} diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs new file mode 100644 index 000000000..b99c01e42 --- /dev/null +++ b/src/sys/windows/uds/mod.rs @@ -0,0 +1,29 @@ +pub(crate) use super::stdnet::SocketAddr; + +cfg_os_poll! { + use std::convert::TryInto; + use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + use std::os::windows::io::RawSocket; + use std::io; + + pub(crate) mod listener; + pub(crate) mod stream; + + pub(crate) fn local_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + getsockname(socket.try_into().unwrap(), sockaddr, socklen), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn peer_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + getpeername(socket.try_into().unwrap(), sockaddr, socklen), + SOCKET_ERROR + ) + }) + } +} diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs new file mode 100644 index 000000000..b02f32e8f --- /dev/null +++ b/src/sys/windows/uds/stream.rs @@ -0,0 +1,19 @@ +use super::SocketAddr; +use crate::sys::windows::stdnet as net; +use std::io; +use std::os::windows::io::AsRawSocket; +use std::path::Path; + +pub(crate) fn connect(path: &Path) -> io::Result { + let socket = net::UnixStream::connect(path)?; + socket.set_nonblocking(true)?; + Ok(socket) +} + +pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { + super::local_addr(socket.as_raw_socket()) +} + +pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { + super::peer_addr(socket.as_raw_socket()) +} diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index d1d9cf07d..c131497cc 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,11 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net; use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; +#[cfg(unix)] use std::os::unix::net; use std::path::{Path, PathBuf}; use std::sync::{Arc, Barrier}; @@ -30,6 +33,7 @@ fn unix_listener_smoke() { smoke_test(|path| UnixListener::bind(path), "unix_listener_smoke"); } +#[cfg(unix)] #[test] fn unix_listener_from_std() { smoke_test( @@ -135,7 +139,7 @@ fn unix_listener_deregister() { #[cfg(target_os = "linux")] #[test] -fn unix_listener_abstract_namespace() { +fn unix_listener_abstract_namesapce() { use rand::Rng; let num: u64 = rand::thread_rng().gen(); let name = format!("\u{0000}-mio-abstract-uds-{}", num); diff --git a/tests/unix_pipe.rs b/tests/unix_pipe.rs index f8e6464c9..a83e3833b 100644 --- a/tests/unix_pipe.rs +++ b/tests/unix_pipe.rs @@ -49,7 +49,7 @@ fn smoke() { ); let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], DATA1); + assert_eq!(&buf[..n], &*DATA1); } #[test] @@ -162,7 +162,7 @@ fn from_child_process_io() { let mut buf = [0; 20]; let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], DATA1); + assert_eq!(&buf[..n], &*DATA1); drop(sender); diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 9bb9d52fd..93da5b597 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,14 +1,19 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net; use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::net; use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; +#[cfg(windows)] +use std::time::Duration; #[macro_use] mod util; @@ -24,6 +29,7 @@ const DATA1_LEN: usize = 16; const DATA2_LEN: usize = 14; const DEFAULT_BUF_SIZE: usize = 64; const TOKEN_1: Token = Token(0); +#[cfg(unix)] const TOKEN_2: Token = Token(1); #[test] @@ -77,6 +83,7 @@ fn unix_stream_connect() { handle.join().unwrap(); } +#[cfg(unix)] #[test] fn unix_stream_connect_addr() { let (mut poll, mut events) = init_with_poll(); @@ -133,6 +140,7 @@ fn unix_stream_from_std() { ) } +#[cfg(unix)] #[test] fn unix_stream_pair() { let (mut poll, mut events) = init_with_poll(); @@ -287,7 +295,13 @@ fn unix_stream_shutdown_write() { ); let err = stream.write(DATA2).unwrap_err(); + #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); + #[cfg(windows)] + { + use windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN; + assert_eq!(err.raw_os_error(), Some(WSAESHUTDOWN)); + } // Read should be ok let mut buf = [0; DEFAULT_BUF_SIZE]; @@ -352,8 +366,8 @@ fn unix_stream_shutdown_both() { let err = stream.write(DATA2).unwrap_err(); #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - #[cfg(window)] - assert_eq!(err.kind(), io::ErrorKind::ConnectionAbroted); + #[cfg(windows)] + assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted); // Close the connection to allow the remote to shutdown drop(stream); @@ -520,70 +534,107 @@ where handle.join().unwrap(); } -fn new_echo_listener( +#[cfg(windows)] +fn new_listener( connections: usize, test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { + handle_stream: F, +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(net::UnixStream) + std::marker::Send + 'static, +{ let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - let listener = net::UnixListener::bind(path).unwrap(); + // We use mio's non-blocking listener here for windows, since there is no listener in std + // yet. We must be sure to poll before listener I/O. + let mut listener = net::UnixListener::bind(path).unwrap(); + let (mut poll, mut events) = init_with_poll(); + poll.registry() + .register(&mut listener, TOKEN_1, Interest::READABLE) + .unwrap(); + let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { - let (mut stream, _) = listener.accept().unwrap(); - - // On Linux based system it will cause a connection reset - // error when the reading side of the peer connection is - // shutdown, we don't consider it an actual here. - let (mut read, mut written) = (0, 0); - let mut buf = [0; DEFAULT_BUF_SIZE]; - loop { - let n = match stream.read(&mut buf) { - Ok(amount) => { - read += amount; - amount - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => panic!("{}", err), - }; - if n == 0 { - break; - } - match stream.write(&buf[..n]) { - Ok(amount) => written += amount, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => panic!("{}", err), - }; - } - assert_eq!(read, written, "unequal reads and writes"); + poll.poll(&mut events, Some(Duration::from_millis(500))) + .unwrap(); + let (stream, _) = listener.accept().unwrap(); + assert_would_block(listener.accept()); + handle_stream(stream); } }); (handle, addr_receiver.recv().unwrap()) } -fn new_noop_listener( +#[cfg(unix)] +fn new_listener( connections: usize, - barrier: Arc, test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - let (sender, receiver) = channel(); + handle_stream: F, +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(net::UnixStream) + std::marker::Send + 'static, +{ + let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); - sender.send(local_addr).unwrap(); + addr_sender.send(local_addr).unwrap(); for _ in 0..connections { let (stream, _) = listener.accept().unwrap(); - barrier.wait(); - stream.shutdown(Shutdown::Write).unwrap(); - barrier.wait(); - drop(stream); + handle_stream(stream); } }); - (handle, receiver.recv().unwrap()) + (handle, addr_receiver.recv().unwrap()) +} + +fn new_echo_listener( + connections: usize, + test_name: &'static str, +) -> (thread::JoinHandle<()>, net::SocketAddr) { + new_listener(connections, test_name, |mut stream| { + // On Linux based system it will cause a connection reset + // error when the reading side of the peer connection is + // shutdown, we don't consider it an actual here. + let (mut read, mut written) = (0, 0); + let mut buf = [0; DEFAULT_BUF_SIZE]; + loop { + let n = match stream.read(&mut buf) { + Ok(amount) => { + read += amount; + amount + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => panic!("{}", err), + }; + if n == 0 { + break; + } + match stream.write(&buf[..n]) { + Ok(amount) => written += amount, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => panic!("{}", err), + }; + } + assert_eq!(read, written, "unequal reads and writes"); + }) +} + +fn new_noop_listener( + connections: usize, + barrier: Arc, + test_name: &'static str, +) -> (thread::JoinHandle<()>, net::SocketAddr) { + new_listener(connections, test_name, move |stream| { + barrier.wait(); + stream.shutdown(Shutdown::Write).unwrap(); + barrier.wait(); + drop(stream); + }) } From 48db06fdd4e23ebc4d53c3d83f40c18395fa6eac Mon Sep 17 00:00:00 2001 From: Kolby ML <31669092+KolbyML@users.noreply.github.com> Date: Mon, 1 May 2023 12:36:45 -0600 Subject: [PATCH 2/3] Make Windows UDS work with tests and clean implementation --- src/lib.rs | 4 + src/net/tcp/stream.rs | 20 ++--- src/net/uds/listener.rs | 9 -- src/net/uds/stream.rs | 84 ++++++++---------- src/sys/unix/pipe.rs | 20 ++--- src/sys/unix/uds/mod.rs | 2 +- src/sys/windows/iocp.rs | 4 +- src/sys/windows/mod.rs | 2 +- src/sys/windows/net.rs | 5 +- src/sys/windows/stdnet/addr.rs | 55 +++++++++--- src/sys/windows/stdnet/listener.rs | 64 +++++++++----- src/sys/windows/stdnet/mod.rs | 6 +- src/sys/windows/stdnet/socket.rs | 91 +++++++++++--------- src/sys/windows/stdnet/stream.rs | 34 +++++--- src/sys/windows/udp.rs | 4 +- src/sys/windows/uds/listener.rs | 6 +- tests/unix_listener.rs | 6 +- tests/unix_pipe.rs | 4 +- tests/unix_stream.rs | 134 ++++++++++------------------- tests/util/mod.rs | 4 +- 20 files changed, 289 insertions(+), 269 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 56a7160be..aabd716b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,6 +91,10 @@ pub mod windows { //! Windows only extensions. pub use crate::sys::named_pipe::NamedPipe; + // blocking windows uds which mimick std implementation used for tests + cfg_net! { + pub use crate::sys::windows::stdnet; + } } pub mod features { diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 3264904f5..8a3f6a2f2 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -269,49 +269,49 @@ impl TcpStream { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index fba8ccf37..ad27810cf 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -35,21 +35,12 @@ impl UnixListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. - #[cfg(unix)] - #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), } } - #[cfg(windows)] - pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener { - UnixListener { - inner: IoSource::new(listener), - } - } - /// Accepts a new incoming connection to this listener. /// /// The call is responsible for ensuring that the listening socket is in diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 7b6b5f728..bfcfe3a9b 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -49,21 +49,12 @@ impl UnixStream { /// The Unix stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). - #[cfg(unix)] - #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), } } - #[cfg(windows)] - pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { - UnixStream { - inner: IoSource::new(stream), - } - } - /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. @@ -178,34 +169,6 @@ impl UnixStream { /// # let _ = std::fs::remove_file(&file_path); /// let server = UnixListener::bind(&file_path).unwrap(); /// - /// let handle = std::thread::spawn(move || { - /// if let Ok((stream2, _)) = server.accept() { - /// // Wait until the stream is readable... - /// - /// // Read from the stream using a direct WinSock call, of course the - /// // `io::Read` implementation would be easier to use. - /// let mut buf = [0; 512]; - /// let n = stream2.try_io(|| { - /// let res = unsafe { - /// WinSock::recv( - /// stream2.as_raw_socket().try_into().unwrap(), - /// &mut buf as *mut _ as *mut _, - /// buf.len() as c_int, - /// 0 - /// ) - /// }; - /// if res != WinSock::SOCKET_ERROR { - /// Ok(res as usize) - /// } else { - /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure - /// // should return `WouldBlock` error. - /// Err(io::Error::last_os_error()) - /// } - /// }).unwrap(); - /// eprintln!("read {} bytes", n); - /// } - /// }); - /// /// let stream1 = UnixStream::connect(&file_path).unwrap(); /// /// // Wait until the stream is writable... @@ -234,6 +197,33 @@ impl UnixStream { /// })?; /// eprintln!("write {} bytes", n); /// + /// let handle = std::thread::spawn(move || { + /// if let Ok((stream2, _)) = server.accept() { + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// }).unwrap(); + /// eprintln!("read {} bytes", n); + /// } + /// }); /// # handle.join().unwrap(); /// # Ok(()) /// # } @@ -248,49 +238,49 @@ impl UnixStream { impl Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/sys/unix/pipe.rs b/src/sys/unix/pipe.rs index c71855bc4..8e92dd37d 100644 --- a/src/sys/unix/pipe.rs +++ b/src/sys/unix/pipe.rs @@ -333,29 +333,29 @@ impl event::Source for Sender { impl Write for Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } impl Write for &Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } @@ -498,21 +498,21 @@ impl event::Source for Receiver { impl Read for Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } impl Read for &Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index e75412f65..96f4e938f 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -61,7 +61,7 @@ cfg_os_poll! { let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.get(0) { + match bytes.first() { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} diff --git a/src/sys/windows/iocp.rs b/src/sys/windows/iocp.rs index c71b695d4..01aeb9fb9 100644 --- a/src/sys/windows/iocp.rs +++ b/src/sys/windows/iocp.rs @@ -206,7 +206,7 @@ impl CompletionStatus { /// A completion key is a per-handle key that is specified when it is added /// to an I/O completion port via `add_handle` or `add_socket`. pub fn token(&self) -> usize { - self.0.lpCompletionKey as usize + self.0.lpCompletionKey } /// Returns a pointer to the `Overlapped` structure that was specified when @@ -268,6 +268,6 @@ mod tests { } assert_eq!(s[2].bytes_transferred(), 0); assert_eq!(s[2].token(), 0); - assert_eq!(s[2].overlapped(), 0 as *mut _); + assert_eq!(s[2].overlapped(), std::ptr::null_mut()); } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 07f7dda6c..889278ed7 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -28,7 +28,7 @@ macro_rules! wsa_syscall { } cfg_net! { - pub(crate) mod stdnet; + pub mod stdnet; pub(crate) mod uds; pub(crate) use self::uds::SocketAddr; } diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 38b17492c..c2b540560 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -1,6 +1,7 @@ use std::io; use std::mem; use std::net::SocketAddr; +use std::sync::Once; use windows_sys::Win32::Networking::WinSock::{ closesocket, ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, @@ -74,7 +75,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { }; let sockaddr_in = SOCKADDR_IN { - sin_family: AF_INET as u16, // 1 + sin_family: AF_INET, // 1 sin_port: addr.port().to_be(), sin_addr, sin_zero: [0; 8], @@ -96,7 +97,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { }; let sockaddr_in6 = SOCKADDR_IN6 { - sin6_family: AF_INET6 as u16, // 23 + sin6_family: AF_INET6, // 23 sin6_port: addr.port().to_be(), sin6_addr, sin6_flowinfo: addr.flowinfo(), diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 26b1fddde..88a96a024 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -3,9 +3,9 @@ use std::os::raw::c_int; use std::path::Path; use std::{fmt, io, mem}; -use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; +use windows_sys::Win32::Networking::WinSock::{SOCKADDR, SOCKADDR_UN}; -fn path_offset(addr: &sockaddr_un) -> usize { +fn path_offset(addr: &SOCKADDR_UN) -> usize { // Work with an actual instance of the type since using a null pointer is UB let base = addr as *const _ as usize; let path = &addr.sun_path as *const _ as usize; @@ -14,16 +14,16 @@ fn path_offset(addr: &sockaddr_un) -> usize { cfg_os_poll! { use windows_sys::Win32::Networking::WinSock::AF_UNIX; - pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { - let sockaddr = mem::MaybeUninit::::zeroed(); + pub(super) fn socket_addr(path: &Path) -> io::Result<(SOCKADDR_UN, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); - // This is safe to assume because a `sockaddr_un` filled with `0` + // This is safe to assume because a `SOCKADDR_UN` filled with `0` // bytes is properly initialized. // - // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `0` is a valid value for `SOCKADDR_UN::sun_family`; it is // `WinSock::AF_UNSPEC`. // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // `[0; 108]` is a valid value for `SOCKADDR_UN::sun_path`; it begins an // abstract path. let mut sockaddr = unsafe { sockaddr.assume_init() }; sockaddr.sun_family = AF_UNIX; @@ -66,8 +66,9 @@ cfg_os_poll! { } } -pub(crate) struct SocketAddr { - addr: sockaddr_un, +/// An address associated with a Unix socket. +pub struct SocketAddr { + addr: SOCKADDR_UN, len: c_int, } @@ -77,11 +78,11 @@ impl SocketAddr { F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, { let mut sockaddr = { - let sockaddr = mem::MaybeUninit::::zeroed(); + let sockaddr = mem::MaybeUninit::::zeroed(); unsafe { sockaddr.assume_init() } }; - let mut len = mem::size_of::() as c_int; + let mut len = mem::size_of::() as c_int; let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; Ok(( result, @@ -99,6 +100,38 @@ impl SocketAddr { SocketAddr::init(f).map(|(_, addr)| addr) } + cfg_os_poll! { + pub(crate) fn from_parts(sockaddr: SOCKADDR_UN, mut len: c_int) -> io::Result { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = path_offset(&sockaddr) as c_int; // i.e. zero-length address + } else if sockaddr.sun_family != windows_sys::Win32::Networking::WinSock::AF_UNIX as _ { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "file descriptor did not correspond to a Unix socket: {}", + sockaddr.sun_family + ), + )); + } + + Ok(SocketAddr { + addr: sockaddr, + len, + }) + } + } + + /// Returns the contents of this address if it is a `pathname` address. + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + pub(crate) fn address(&self) -> AddressKind<'_> { let len = self.len as usize - path_offset(&self.addr); // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 214167276..487aac9eb 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -1,34 +1,39 @@ +use super::{socket::Socket, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::{fmt, io, mem}; - use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; -use super::{socket::Socket, SocketAddr}; - -pub(crate) struct UnixListener(Socket); +/// A structure representing a Unix domain socket server. +pub struct UnixListener { + inner: Socket, +} impl UnixListener { - pub(crate) fn local_addr(&self) -> io::Result { + /// Returns the local socket address of this listener. + pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( - getsockname(self.0.as_raw_socket() as _, addr, len), + getsockname(self.inner.as_raw_socket() as _, addr, len), SOCKET_ERROR ) }) } - pub(crate) fn take_error(&self) -> io::Result> { - self.0.take_error() + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.inner.take_error() } } cfg_os_poll! { + use std::os::raw::c_int; use std::path::Path; - use super::{socket_addr, UnixStream}; + use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN; impl UnixListener { - pub(crate) fn bind>(path: P) -> io::Result { + /// Creates a new `UnixListener` bound to the specified socket. + pub fn bind>(path: P) -> io::Result { let inner = Socket::new()?; let (addr, len) = socket_addr(path.as_ref())?; @@ -37,16 +42,33 @@ cfg_os_poll! { SOCKET_ERROR )?; wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; - Ok(UnixListener(inner)) + Ok(UnixListener { + inner + }) } - pub(crate) fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - SocketAddr::init(|addr, len| self.0.accept(addr, len)) - .map(|(sock, addr)| (UnixStream(sock), addr)) + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corresponding [`UnixStream`] and + /// the remote peer's address will be returned. + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage: SOCKADDR_UN = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as c_int; + let sock = self.inner.accept(&mut storage as *mut _ as *mut _, &mut len)?; + let addr = SocketAddr::from_parts(storage, len)?; + Ok((UnixStream(sock), addr)) } - pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) + /// Moves the socket into or out of nonblocking mode. + /// + /// This will result in the `accept` operation becoming nonblocking, + /// i.e., immediately returning from their calls. If the IO operation is + /// successful, `Ok` is returned and no further action is required. If the + /// IO operation could not be completed and needs to be retried, an error + /// with kind [`io::ErrorKind::WouldBlock`] is returned. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) } } } @@ -54,7 +76,7 @@ cfg_os_poll! { impl fmt::Debug for UnixListener { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = fmt.debug_struct("UnixListener"); - builder.field("socket", &self.0.as_raw_socket()); + builder.field("socket", &self.inner.as_raw_socket()); if let Ok(addr) = self.local_addr() { builder.field("local", &addr); } @@ -64,19 +86,21 @@ impl fmt::Debug for UnixListener { impl AsRawSocket for UnixListener { fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() + self.inner.as_raw_socket() } } impl FromRawSocket for UnixListener { unsafe fn from_raw_socket(sock: RawSocket) -> Self { - UnixListener(Socket::from_raw_socket(sock)) + UnixListener { + inner: Socket::from_raw_socket(sock), + } } } impl IntoRawSocket for UnixListener { fn into_raw_socket(self) -> RawSocket { - let ret = self.0.as_raw_socket(); + let ret = self.inner.as_raw_socket(); mem::forget(self); ret } diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index 0eb5130d4..c1e1cc748 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -4,9 +4,9 @@ mod listener; mod socket; mod stream; -pub(crate) use self::addr::SocketAddr; -pub(crate) use self::listener::UnixListener; -pub(crate) use self::stream::UnixStream; +pub use self::addr::SocketAddr; +pub use self::listener::UnixListener; +pub use self::stream::UnixStream; cfg_os_poll! { pub(self) use self::addr::socket_addr; diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index 9212c1e04..d34d04037 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -1,12 +1,10 @@ use std::cmp::min; -use std::convert::TryInto; use std::io::{self, IoSlice, IoSliceMut}; use std::mem; use std::net::Shutdown; use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::ptr; - use windows_sys::Win32::Networking::WinSock::{self, closesocket, SOCKET, SOCKET_ERROR, WSABUF}; /// Maximum size of a buffer passed to system call like `recv` and `send`. @@ -42,42 +40,38 @@ impl Socket { ); match res { Ok(_) => Ok(total as usize), - Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN as i32) => Ok(0), + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN) => Ok(0), Err(err) => Err(err), } } - pub fn send(&self, buf: &[u8]) -> io::Result { - wsa_syscall!( - send( + pub fn write(&self, buf: &[u8]) -> io::Result { + let response = unsafe { + windows_sys::Win32::Networking::WinSock::send( self.0, buf.as_ptr().cast(), min(buf.len(), MAX_BUF_LEN) as c_int, 0, - ), - SOCKET_ERROR - ) - .map(|n| n as usize) + ) + }; + if response == SOCKET_ERROR { + return match unsafe { windows_sys::Win32::Networking::WinSock::WSAGetLastError() } { + windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, "brokenpipe")) + } + e => Err(std::io::Error::from_raw_os_error(e)), + }; + } + Ok(response as usize) } - pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { let mut total = 0; wsa_syscall!( WSASend( self.0, - // FIXME: From the `WSASend` docs [1]: - // > For a Winsock application, once the WSASend function is called, - // > the system owns these buffers and the application may not - // > access them. - // - // So what we're doing is actually UB as `bufs` needs to be `&mut - // [IoSlice<'_>]`. - // - // See: https://github.com/rust-lang/socket2-rs/issues/129. - // - // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend - bufs.as_ptr() as *mut _, - min(bufs.len(), u32::MAX as usize) as u32, + bufs.as_ptr() as *mut WSABUF, + bufs.len().min(u32::MAX as usize) as u32, &mut total, 0, std::ptr::null_mut(), @@ -94,7 +88,7 @@ impl Socket { Shutdown::Read => WinSock::SD_RECEIVE, Shutdown::Both => WinSock::SD_BOTH, }; - wsa_syscall!(shutdown(self.0, how.try_into().unwrap()), SOCKET_ERROR)?; + wsa_syscall!(shutdown(self.0, how), SOCKET_ERROR)?; Ok(()) } @@ -104,8 +98,8 @@ impl Socket { wsa_syscall!( getsockopt( self.0 as _, - WinSock::SOL_SOCKET.try_into().unwrap(), - WinSock::SO_ERROR.try_into().unwrap(), + WinSock::SOL_SOCKET, + WinSock::SO_ERROR, &mut val as *mut _ as *mut _, &mut len, ), @@ -116,36 +110,48 @@ impl Socket { if val == 0 { Ok(None) } else { - Ok(Some(io::Error::from_raw_os_error(val as i32))) + Ok(Some(io::Error::from_raw_os_error(val))) } } } cfg_os_poll! { + use windows_sys::Win32::Foundation::{HANDLE, HANDLE_FLAG_INHERIT, SetHandleInformation}; use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; use super::init; impl Socket { pub fn new() -> io::Result { init(); - wsa_syscall!( - WSASocketW( - WinSock::AF_UNIX.into(), - WinSock::SOCK_STREAM.into(), - 0, - ptr::null_mut(), - 0, - WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, - ), - INVALID_SOCKET - ).map(Socket) + match wsa_syscall!(WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM, + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED, + ), INVALID_SOCKET) { + Ok(res) => { + let socket = Socket(res); + socket.set_no_inherit()?; + Ok(socket) + }, + Err(e) => Err(e), + } } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { // WinSock's accept returns a socket with the same properties as the listener. it is // called on. In particular, the WSA_FLAG_NO_HANDLE_INHERIT will be inherited from the // listener. - wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) + match wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET) { + Ok(res) => { + let socket = Socket(res); + socket.set_no_inherit()?; + Ok(socket) + }, + Err(e) => Err(e), + } } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { @@ -156,6 +162,11 @@ cfg_os_poll! { )?; Ok(()) } + + pub fn set_no_inherit(&self) -> io::Result<()> { + syscall!(SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), PartialEq::eq, -1)?; + Ok(()) + } } } diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index ce1da2f54..9d78f5383 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -7,10 +7,12 @@ use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; use super::{socket::Socket, SocketAddr}; -pub(crate) struct UnixStream(pub(super) Socket); +/// A Unix stream socket. +pub struct UnixStream(pub(super) Socket); impl UnixStream { - pub(crate) fn local_addr(&self) -> io::Result { + /// Connects to the socket specified by [`address`]. + pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), @@ -19,7 +21,8 @@ impl UnixStream { }) } - pub(crate) fn peer_addr(&self) -> io::Result { + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getpeername(self.0.as_raw_socket() as _, addr, len), @@ -28,22 +31,28 @@ impl UnixStream { }) } - pub(crate) fn take_error(&self) -> io::Result> { + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { self.0.take_error() } - pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> { + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of [`Shutdown`]). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.0.shutdown(how) } } cfg_os_poll! { use std::path::Path; - use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; use super::socket_addr; impl UnixStream { - pub(crate) fn connect>(path: P) -> io::Result { + /// Connects to the socket named by `path`. + pub fn connect>(path: P) -> io::Result { let inner = Socket::new()?; let (addr, len) = socket_addr(path.as_ref())?; @@ -51,18 +60,19 @@ cfg_os_poll! { connect( inner.as_raw_socket() as _, &addr as *const _ as *const _, - len as i32, + len, ), SOCKET_ERROR ) { Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(ref err) if err.kind() == std::io::ErrorKind::Other => {} Err(e) => return Err(e), } Ok(UnixStream(inner)) } - pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) } } @@ -118,11 +128,11 @@ impl io::Write for UnixStream { impl<'a> io::Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.send(buf) + self.0.write(buf) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.0.send_vectored(bufs) + self.0.write_vectored(bufs) } fn flush(&mut self) -> io::Result<()> { diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index 87e269fa3..3e975973f 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -30,8 +30,8 @@ pub(crate) fn only_v6(socket: &net::UdpSocket) -> io::Result { syscall!( getsockopt( socket.as_raw_socket() as usize, - IPPROTO_IPV6 as i32, - IPV6_V6ONLY as i32, + IPPROTO_IPV6, + IPV6_V6ONLY, optval.as_mut_ptr().cast(), &mut optlen, ), diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index 4ba4395e5..f33312d50 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -13,9 +13,9 @@ pub(crate) fn bind(path: &Path) -> io::Result { } pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { - listener - .accept() - .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) + listener.set_nonblocking(true)?; + let (stream, addr) = listener.accept()?; + Ok((UnixStream::from_std(stream), addr)) } pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index c131497cc..2303ca530 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net"))] -#[cfg(windows)] -use mio::net; use mio::net::UnixListener; +#[cfg(windows)] +use mio::windows::stdnet as net; use mio::{Interest, Token}; use std::io::{self, Read}; #[cfg(unix)] @@ -139,7 +139,7 @@ fn unix_listener_deregister() { #[cfg(target_os = "linux")] #[test] -fn unix_listener_abstract_namesapce() { +fn unix_listener_abstract_namespace() { use rand::Rng; let num: u64 = rand::thread_rng().gen(); let name = format!("\u{0000}-mio-abstract-uds-{}", num); diff --git a/tests/unix_pipe.rs b/tests/unix_pipe.rs index a83e3833b..f8e6464c9 100644 --- a/tests/unix_pipe.rs +++ b/tests/unix_pipe.rs @@ -49,7 +49,7 @@ fn smoke() { ); let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); } #[test] @@ -162,7 +162,7 @@ fn from_child_process_io() { let mut buf = [0; 20]; let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); drop(sender); diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 93da5b597..c198c09e5 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net"))] -#[cfg(windows)] -use mio::net; use mio::net::UnixStream; +#[cfg(windows)] +use mio::windows::stdnet as net; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; @@ -12,8 +12,6 @@ use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; -#[cfg(windows)] -use std::time::Duration; #[macro_use] mod util; @@ -83,7 +81,6 @@ fn unix_stream_connect() { handle.join().unwrap(); } -#[cfg(unix)] #[test] fn unix_stream_connect_addr() { let (mut poll, mut events) = init_with_poll(); @@ -295,13 +292,7 @@ fn unix_stream_shutdown_write() { ); let err = stream.write(DATA2).unwrap_err(); - #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - #[cfg(windows)] - { - use windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN; - assert_eq!(err.raw_os_error(), Some(WSAESHUTDOWN)); - } // Read should be ok let mut buf = [0; DEFAULT_BUF_SIZE]; @@ -507,6 +498,8 @@ where assert!(stream.take_error().unwrap().is_none()); + // To comply with draining behavior on windows we have to check assert_would_block() + // https://github.com/tokio-rs/mio/issues/1611 assert_would_block(stream.read(&mut buf)); let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; @@ -534,107 +527,70 @@ where handle.join().unwrap(); } -#[cfg(windows)] -fn new_listener( +fn new_echo_listener( connections: usize, test_name: &'static str, - handle_stream: F, -) -> (thread::JoinHandle<()>, net::SocketAddr) -where - F: Fn(net::UnixStream) + std::marker::Send + 'static, -{ +) -> (thread::JoinHandle<()>, net::SocketAddr) { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - // We use mio's non-blocking listener here for windows, since there is no listener in std - // yet. We must be sure to poll before listener I/O. - let mut listener = net::UnixListener::bind(path).unwrap(); - let (mut poll, mut events) = init_with_poll(); - poll.registry() - .register(&mut listener, TOKEN_1, Interest::READABLE) - .unwrap(); - + let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { - poll.poll(&mut events, Some(Duration::from_millis(500))) - .unwrap(); - let (stream, _) = listener.accept().unwrap(); - assert_would_block(listener.accept()); - handle_stream(stream); + let (mut stream, _) = listener.accept().unwrap(); + + // On Linux based system it will cause a connection reset + // error when the reading side of the peer connection is + // shutdown, we don't consider it an actual here. + let (mut read, mut written) = (0, 0); + let mut buf = [0; DEFAULT_BUF_SIZE]; + loop { + let n = match stream.read(&mut buf) { + Ok(amount) => { + read += amount; + amount + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => panic!("{}", err), + }; + if n == 0 { + break; + } + match stream.write(&buf[..n]) { + Ok(amount) => written += amount, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => panic!("{}", err), + }; + } + assert_eq!(read, written, "unequal reads and writes"); } }); (handle, addr_receiver.recv().unwrap()) } -#[cfg(unix)] -fn new_listener( +fn new_noop_listener( connections: usize, + barrier: Arc, test_name: &'static str, - handle_stream: F, -) -> (thread::JoinHandle<()>, net::SocketAddr) -where - F: Fn(net::UnixStream) + std::marker::Send + 'static, -{ - let (addr_sender, addr_receiver) = channel(); +) -> (thread::JoinHandle<()>, net::SocketAddr) { + let (sender, receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); - addr_sender.send(local_addr).unwrap(); + sender.send(local_addr).unwrap(); for _ in 0..connections { let (stream, _) = listener.accept().unwrap(); - handle_stream(stream); + barrier.wait(); + stream.shutdown(Shutdown::Write).unwrap(); + barrier.wait(); + drop(stream); } }); - (handle, addr_receiver.recv().unwrap()) -} - -fn new_echo_listener( - connections: usize, - test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener(connections, test_name, |mut stream| { - // On Linux based system it will cause a connection reset - // error when the reading side of the peer connection is - // shutdown, we don't consider it an actual here. - let (mut read, mut written) = (0, 0); - let mut buf = [0; DEFAULT_BUF_SIZE]; - loop { - let n = match stream.read(&mut buf) { - Ok(amount) => { - read += amount; - amount - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => panic!("{}", err), - }; - if n == 0 { - break; - } - match stream.write(&buf[..n]) { - Ok(amount) => written += amount, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => panic!("{}", err), - }; - } - assert_eq!(read, written, "unequal reads and writes"); - }) -} - -fn new_noop_listener( - connections: usize, - barrier: Arc, - test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener(connections, test_name, move |stream| { - barrier.wait(); - stream.shutdown(Shutdown::Write).unwrap(); - barrier.wait(); - drop(stream); - }) + (handle, receiver.recv().unwrap()) } diff --git a/tests/util/mod.rs b/tests/util/mod.rs index 7a192d9b0..7fcb9fe5e 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -285,8 +285,8 @@ pub fn set_linger_zero(socket: &TcpStream) { let res = unsafe { setsockopt( socket.as_raw_socket() as _, - SOL_SOCKET as i32, - SO_LINGER as i32, + SOL_SOCKET, + SO_LINGER, &mut val as *mut _ as *mut _, size_of::() as _, ) From 149095f42f1b9f25e7057fc8011452a2f818e53a Mon Sep 17 00:00:00 2001 From: Kolby ML <31669092+KolbyML@users.noreply.github.com> Date: Wed, 20 Sep 2023 14:41:45 -0600 Subject: [PATCH 3/3] wip changes to fix ci --- src/sys/unix/uds/listener.rs | 3 +-- src/sys/windows/stdnet/addr.rs | 10 +++++++++- src/sys/windows/stdnet/listener.rs | 14 ++++++++++++++ src/sys/windows/stdnet/mod.rs | 2 +- src/sys/windows/stdnet/stream.rs | 19 +++++++++++++++++++ src/sys/windows/uds/listener.rs | 9 +++++++-- src/sys/windows/uds/stream.rs | 6 ++++++ 7 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index 794a9f7bb..ff77c53bd 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,6 +1,5 @@ use super::socket_addr; -use super::SocketAddr; -use crate::net::UnixStream; +use crate::net::{SocketAddr, UnixStream}; use crate::sys::unix::net::new_socket; use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsRawFd, FromRawFd}; diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 88a96a024..3df201d93 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -67,7 +67,7 @@ cfg_os_poll! { } /// An address associated with a Unix socket. -pub struct SocketAddr { +pub(crate) struct SocketAddr { addr: SOCKADDR_UN, len: c_int, } @@ -121,6 +121,14 @@ impl SocketAddr { len, }) } + + pub(crate) fn raw_sockaddr(&self) -> &SOCKADDR_UN { + &self.addr + } + + pub(crate) fn raw_socklen(&self) -> c_int { + self.len + } } /// Returns the contents of this address if it is a `pathname` address. diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 487aac9eb..a971cd2cb 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -47,6 +47,20 @@ cfg_os_poll! { }) } + /// Creates a new `UnixListener` bound to the specified address. + pub fn bind_addr(socket_addr: &SocketAddr) -> io::Result { + let inner = Socket::new()?; + + wsa_syscall!( + bind(inner.as_raw_socket() as _, &socket_addr.raw_sockaddr() as *const _ as *const _, socket_addr.raw_socklen() as _), + SOCKET_ERROR + )?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; + Ok(UnixListener { + inner + }) + } + /// Accepts a new incoming connection to this listener. /// /// This function will block the calling thread until a new Unix connection diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index c1e1cc748..72c97e8d9 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -4,7 +4,7 @@ mod listener; mod socket; mod stream; -pub use self::addr::SocketAddr; +pub(crate) use self::addr::SocketAddr; pub use self::listener::UnixListener; pub use self::stream::UnixStream; diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 9d78f5383..f433772a5 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -71,6 +71,25 @@ cfg_os_poll! { Ok(UnixStream(inner)) } + /// Connects to the socket named by `socker_addr`. + pub fn connect_addr(socket_addr: &SocketAddr) -> io::Result { + let inner = Socket::new()?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &socket_addr.raw_sockaddr() as *const _ as *const _, + socket_addr.raw_socklen(), + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.kind() == std::io::ErrorKind::Other => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } + /// Moves the socket into or out of nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index f33312d50..84d96c179 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -2,8 +2,7 @@ use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; -use super::SocketAddr; -use crate::net::UnixStream; +use crate::net::{SocketAddr, UnixStream}; use crate::sys::windows::stdnet as net; pub(crate) fn bind(path: &Path) -> io::Result { @@ -12,6 +11,12 @@ pub(crate) fn bind(path: &Path) -> io::Result { Ok(listener) } +pub(crate) fn bind_addr(socket_addr: &SocketAddr) -> io::Result { + let listener = net::UnixListener::bind_addr(socket_addr)?; + listener.set_nonblocking(true)?; + Ok(listener) +} + pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { listener.set_nonblocking(true)?; let (stream, addr) = listener.accept()?; diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index b02f32e8f..ae1a98e78 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -10,6 +10,12 @@ pub(crate) fn connect(path: &Path) -> io::Result { Ok(socket) } +pub(crate) fn connect_addr(socker_addr: &SocketAddr) -> io::Result { + let socket = net::UnixStream::connect_addr(socker_addr)?; + socket.set_nonblocking(true)?; + Ok(socket) +} + pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { super::local_addr(socket.as_raw_socket()) }