Skip to content

Commit

Permalink
Make Windows UDS work with tests and clean implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
KolbyML committed May 3, 2023
1 parent 457f753 commit 12c4013
Show file tree
Hide file tree
Showing 20 changed files with 290 additions and 270 deletions.
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ pub mod windows {
//! Windows only extensions.
pub use crate::sys::named_pipe::NamedPipe;
// blocking windows uds which mimick std implementation used for tests
cfg_net! {
pub use crate::sys::windows::stdnet;
}
}

pub mod features {
Expand Down
20 changes: 10 additions & 10 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,49 +269,49 @@ impl TcpStream {

impl Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read(buf))
self.inner.do_io(|mut inner| inner.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}

impl<'a> Read for &'a TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read(buf))
self.inner.do_io(|mut inner| inner.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}

impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write(buf))
self.inner.do_io(|mut inner| inner.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|inner| (&*inner).flush())
self.inner.do_io(|mut inner| inner.flush())
}
}

impl<'a> Write for &'a TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write(buf))
self.inner.do_io(|mut inner| inner.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|inner| (&*inner).flush())
self.inner.do_io(|mut inner| inner.flush())
}
}

Expand Down
9 changes: 0 additions & 9 deletions src/net/uds/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,12 @@ impl UnixListener {
/// standard library in the Mio equivalent. The conversion assumes nothing
/// about the underlying listener; it is left up to the user to set it in
/// non-blocking mode.
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub fn from_std(listener: net::UnixListener) -> UnixListener {
UnixListener {
inner: IoSource::new(listener),
}
}

#[cfg(windows)]
pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener {
UnixListener {
inner: IoSource::new(listener),
}
}

/// Accepts a new incoming connection to this listener.
///
/// The call is responsible for ensuring that the listening socket is in
Expand Down
84 changes: 37 additions & 47 deletions src/net/uds/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,12 @@ impl UnixStream {
/// The Unix stream here will not have `connect` called on it, so it
/// should already be connected via some other means (be it manually, or
/// the standard library).
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub fn from_std(stream: net::UnixStream) -> UnixStream {
UnixStream {
inner: IoSource::new(stream),
}
}

#[cfg(windows)]
pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream {
UnixStream {
inner: IoSource::new(stream),
}
}

/// Creates an unnamed pair of connected sockets.
///
/// Returns two `UnixStream`s which are connected to each other.
Expand Down Expand Up @@ -170,34 +161,6 @@ impl UnixStream {
/// # let _ = std::fs::remove_file(&file_path);
/// let server = UnixListener::bind(&file_path).unwrap();
///
/// let handle = std::thread::spawn(move || {
/// if let Ok((stream2, _)) = server.accept() {
/// // Wait until the stream is readable...
///
/// // Read from the stream using a direct WinSock call, of course the
/// // `io::Read` implementation would be easier to use.
/// let mut buf = [0; 512];
/// let n = stream2.try_io(|| {
/// let res = unsafe {
/// WinSock::recv(
/// stream2.as_raw_socket().try_into().unwrap(),
/// &mut buf as *mut _ as *mut _,
/// buf.len() as c_int,
/// 0
/// )
/// };
/// if res != WinSock::SOCKET_ERROR {
/// Ok(res as usize)
/// } else {
/// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure
/// // should return `WouldBlock` error.
/// Err(io::Error::last_os_error())
/// }
/// }).unwrap();
/// eprintln!("read {} bytes", n);
/// }
/// });
///
/// let stream1 = UnixStream::connect(&file_path).unwrap();
///
/// // Wait until the stream is writable...
Expand Down Expand Up @@ -226,6 +189,33 @@ impl UnixStream {
/// })?;
/// eprintln!("write {} bytes", n);
///
/// let handle = std::thread::spawn(move || {
/// if let Ok((stream2, _)) = server.accept() {
/// // Wait until the stream is readable...
///
/// // Read from the stream using a direct WinSock call, of course the
/// // `io::Read` implementation would be easier to use.
/// let mut buf = [0; 512];
/// let n = stream2.try_io(|| {
/// let res = unsafe {
/// WinSock::recv(
/// stream2.as_raw_socket().try_into().unwrap(),
/// &mut buf as *mut _ as *mut _,
/// buf.len() as c_int,
/// 0
/// )
/// };
/// if res != WinSock::SOCKET_ERROR {
/// Ok(res as usize)
/// } else {
/// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure
/// // should return `WouldBlock` error.
/// Err(io::Error::last_os_error())
/// }
/// }).unwrap();
/// eprintln!("read {} bytes", n);
/// }
/// });
/// # handle.join().unwrap();
/// # Ok(())
/// # }
Expand All @@ -240,49 +230,49 @@ impl UnixStream {

impl Read for UnixStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read(buf))
self.inner.do_io(|mut inner| inner.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}

impl<'a> Read for &'a UnixStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read(buf))
self.inner.do_io(|mut inner| inner.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}

impl Write for UnixStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write(buf))
self.inner.do_io(|mut inner| inner.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|inner| (&*inner).flush())
self.inner.do_io(|mut inner| inner.flush())
}
}

impl<'a> Write for &'a UnixStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write(buf))
self.inner.do_io(|mut inner| inner.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|inner| (&*inner).flush())
self.inner.do_io(|mut inner| inner.flush())
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/sys/unix/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,29 +321,29 @@ impl event::Source for Sender {

impl Write for Sender {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).write(buf))
self.inner.do_io(|mut sender| sender.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).write_vectored(bufs))
self.inner.do_io(|mut sender| sender.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|sender| (&*sender).flush())
self.inner.do_io(|mut sender| sender.flush())
}
}

impl Write for &Sender {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).write(buf))
self.inner.do_io(|mut sender| sender.write(buf))
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).write_vectored(bufs))
self.inner.do_io(|mut sender| sender.write_vectored(bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.inner.do_io(|sender| (&*sender).flush())
self.inner.do_io(|mut sender| sender.flush())
}
}

Expand Down Expand Up @@ -486,21 +486,21 @@ impl event::Source for Receiver {

impl Read for Receiver {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).read(buf))
self.inner.do_io(|mut sender| sender.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).read_vectored(bufs))
self.inner.do_io(|mut sender| sender.read_vectored(bufs))
}
}

impl Read for &Receiver {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).read(buf))
self.inner.do_io(|mut sender| sender.read(buf))
}

fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
self.inner.do_io(|sender| (&*sender).read_vectored(bufs))
self.inner.do_io(|mut sender| sender.read_vectored(bufs))
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/sys/unix/uds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ cfg_os_poll! {
sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t;

let bytes = path.as_os_str().as_bytes();
match (bytes.get(0), bytes.len().cmp(&sockaddr.sun_path.len())) {
match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) {
// Abstract paths don't need a null terminator
(Some(&0), Ordering::Greater) => {
return Err(io::Error::new(
Expand All @@ -64,7 +64,7 @@ cfg_os_poll! {
let offset = path_offset(&sockaddr);
let mut socklen = offset + bytes.len();

match bytes.get(0) {
match bytes.first() {
// The struct has already been zeroes so the null byte for pathname
// addresses is already there.
Some(&0) | None => {}
Expand Down
4 changes: 2 additions & 2 deletions src/sys/windows/iocp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl CompletionStatus {
/// A completion key is a per-handle key that is specified when it is added
/// to an I/O completion port via `add_handle` or `add_socket`.
pub fn token(&self) -> usize {
self.0.lpCompletionKey as usize
self.0.lpCompletionKey
}

/// Returns a pointer to the `Overlapped` structure that was specified when
Expand Down Expand Up @@ -268,6 +268,6 @@ mod tests {
}
assert_eq!(s[2].bytes_transferred(), 0);
assert_eq!(s[2].token(), 0);
assert_eq!(s[2].overlapped(), 0 as *mut _);
assert_eq!(s[2].overlapped(), std::ptr::null_mut());
}
}
2 changes: 1 addition & 1 deletion src/sys/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! wsa_syscall {
}

cfg_net! {
pub(crate) mod stdnet;
pub mod stdnet;
pub(crate) mod uds;
pub(crate) use self::uds::SocketAddr;
}
Expand Down
5 changes: 3 additions & 2 deletions src/sys/windows/net.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::sync::Once;

use windows_sys::Win32::Networking::WinSock::{
closesocket, ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0,
Expand Down Expand Up @@ -74,7 +75,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) {
};

let sockaddr_in = SOCKADDR_IN {
sin_family: AF_INET as u16, // 1
sin_family: AF_INET, // 1
sin_port: addr.port().to_be(),
sin_addr,
sin_zero: [0; 8],
Expand All @@ -96,7 +97,7 @@ pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, i32) {
};

let sockaddr_in6 = SOCKADDR_IN6 {
sin6_family: AF_INET6 as u16, // 23
sin6_family: AF_INET6, // 23
sin6_port: addr.port().to_be(),
sin6_addr,
sin6_flowinfo: addr.flowinfo(),
Expand Down
Loading

0 comments on commit 12c4013

Please sign in to comment.