diff --git a/src/lib.rs b/src/lib.rs index b04d1d9..cf35fd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1847,12 +1847,10 @@ impl Async { /// # std::io::Result::Ok(()) }); /// ``` pub async fn connect>(path: P) -> io::Result> { + let address = convert_path_to_socket_address(path.as_ref())?; + // Begin async connect. - let socket = connect( - rn::SocketAddrUnix::new(path.as_ref())?.into(), - rn::AddressFamily::UNIX, - None, - )?; + let socket = connect(address.into(), rn::AddressFamily::UNIX, None)?; // Use new_nonblocking because connect already sets socket to non-blocking mode. let stream = Async::new_nonblocking(UnixStream::from(socket))?; @@ -2195,3 +2193,31 @@ fn set_nonblocking( Ok(()) } + +/// Converts a `Path` to its socket address representation. +/// +/// This function is abstract socket-aware. +#[cfg(unix)] +#[inline] +fn convert_path_to_socket_address(path: &Path) -> io::Result { + // SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in. + // However, some users expect to be able to pass in paths to abstract sockets, which + // triggers this error as it has a zero in it. Therefore, if a path starts with a zero, + // make it an abstract socket. + #[cfg(any(target_os = "linux", target_os = "android"))] + let address = { + use std::os::unix::ffi::OsStrExt; + + let path = path.as_os_str(); + match path.as_bytes().first() { + Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes().get(1..).unwrap())?, + _ => rn::SocketAddrUnix::new(path)?, + } + }; + + // Only Linux and Android support abstract sockets. + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let address = rn::SocketAddrUnix::new(path)?; + + Ok(address) +} diff --git a/tests/async.rs b/tests/async.rs index 9218e71..ea9f11b 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -383,3 +383,59 @@ fn duplicate_socket_insert() -> io::Result<()> { Ok(()) }) } + +#[cfg(any(target_os = "linux", target_os = "android"))] +#[test] +fn abstract_socket() -> io::Result<()> { + use std::ffi::OsStr; + #[cfg(target_os = "android")] + use std::os::android::net::SocketAddrExt; + #[cfg(target_os = "linux")] + use std::os::linux::net::SocketAddrExt; + use std::os::unix::ffi::OsStrExt; + use std::os::unix::net::{SocketAddr, UnixListener, UnixStream}; + + future::block_on(async { + // Bind a listener to a socket. + let path = OsStr::from_bytes(b"\0smolabstract"); + let addr = SocketAddr::from_abstract_name(b"smolabstract")?; + let listener = Async::new(UnixListener::bind_addr(&addr)?)?; + + // Future that connects to the listener. + let connector = async { + // Connect to the socket. + let mut stream = Async::::connect(path).await?; + + // Write some bytes to the stream. + stream.write_all(LOREM_IPSUM).await?; + + // Read some bytes from the stream. + let mut buf = vec![0; LOREM_IPSUM.len()]; + stream.read_exact(&mut buf).await?; + assert_eq!(buf.as_slice(), LOREM_IPSUM); + + io::Result::Ok(()) + }; + + // Future that drives the listener. + let driver = async { + // Wait for a new connection. + let (mut stream, _) = listener.accept().await?; + + // Read some bytes from the stream. + let mut buf = vec![0; LOREM_IPSUM.len()]; + stream.read_exact(&mut buf).await?; + assert_eq!(buf.as_slice(), LOREM_IPSUM); + + // Write some bytes to the stream. + stream.write_all(LOREM_IPSUM).await?; + + io::Result::Ok(()) + }; + + // Run both in parallel. + future::try_zip(connector, driver).await?; + + Ok(()) + }) +}