From 21a83adb63b8385859a10bd63868eb2cec29fbcf Mon Sep 17 00:00:00 2001 From: John Nunley Date: Wed, 6 Dec 2023 17:22:15 -0800 Subject: [PATCH] Add test and fix a bug Signed-off-by: John Nunley --- src/lib.rs | 47 +++++++++++++++++++++++++++----------------- tests/async.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 03917ad..cf35fd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1847,24 +1847,7 @@ impl Async { /// # std::io::Result::Ok(()) }); /// ``` pub async fn connect>(path: P) -> io::Result> { - // SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in. - // However, some users expect to be able to pass in paths to abstract sockets, which - // triggers this error as it has a zero in it. Therefore, if a path starts with a zero, - // make it an abstract socket. - #[cfg(any(target_os = "linux", target_os = "android"))] - let address = { - use std::os::unix::ffi::OsStrExt; - - let path = path.as_ref().as_os_str(); - match path.as_bytes().first() { - Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes())?, - _ => rn::SocketAddrUnix::new(path)?, - } - }; - - // Only Linux and Android support abstract sockets. - #[cfg(not(any(target_os = "linux", target_os = "android")))] - let address = rn::SocketAddrUnix::new(path.as_ref())?; + let address = convert_path_to_socket_address(path.as_ref())?; // Begin async connect. let socket = connect(address.into(), rn::AddressFamily::UNIX, None)?; @@ -2210,3 +2193,31 @@ fn set_nonblocking( Ok(()) } + +/// Converts a `Path` to its socket address representation. +/// +/// This function is abstract socket-aware. +#[cfg(unix)] +#[inline] +fn convert_path_to_socket_address(path: &Path) -> io::Result { + // SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in. + // However, some users expect to be able to pass in paths to abstract sockets, which + // triggers this error as it has a zero in it. Therefore, if a path starts with a zero, + // make it an abstract socket. + #[cfg(any(target_os = "linux", target_os = "android"))] + let address = { + use std::os::unix::ffi::OsStrExt; + + let path = path.as_os_str(); + match path.as_bytes().first() { + Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes().get(1..).unwrap())?, + _ => rn::SocketAddrUnix::new(path)?, + } + }; + + // Only Linux and Android support abstract sockets. + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let address = rn::SocketAddrUnix::new(path)?; + + Ok(address) +} diff --git a/tests/async.rs b/tests/async.rs index 9218e71..bea3a33 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -383,3 +383,56 @@ fn duplicate_socket_insert() -> io::Result<()> { Ok(()) }) } + +#[cfg(any(target_os = "linux", target_os = "android"))] +#[test] +fn abstract_socket() -> io::Result<()> { + use std::ffi::OsStr; + use std::os::linux::net::SocketAddrExt; + use std::os::unix::ffi::OsStrExt; + use std::os::unix::net::{SocketAddr, UnixListener, UnixStream}; + + future::block_on(async { + // Bind a listener to a socket. + let path = OsStr::from_bytes(b"\0smolabstract"); + let addr = SocketAddr::from_abstract_name(b"smolabstract")?; + let listener = Async::new(UnixListener::bind_addr(&addr)?)?; + + // Future that connects to the listener. + let connector = async { + // Connect to the socket. + let mut stream = Async::::connect(path).await?; + + // Write some bytes to the stream. + stream.write_all(LOREM_IPSUM).await?; + + // Read some bytes from the stream. + let mut buf = vec![0; LOREM_IPSUM.len()]; + stream.read_exact(&mut buf).await?; + assert_eq!(buf.as_slice(), LOREM_IPSUM); + + io::Result::Ok(()) + }; + + // Future that drives the listener. + let driver = async { + // Wait for a new connection. + let (mut stream, _) = listener.accept().await?; + + // Read some bytes from the stream. + let mut buf = vec![0; LOREM_IPSUM.len()]; + stream.read_exact(&mut buf).await?; + assert_eq!(buf.as_slice(), LOREM_IPSUM); + + // Write some bytes to the stream. + stream.write_all(LOREM_IPSUM).await?; + + io::Result::Ok(()) + }; + + // Run both in parallel. + future::try_zip(connector, driver).await?; + + Ok(()) + }) +}