diff --git a/src/backend/libc/net/read_sockaddr.rs b/src/backend/libc/net/read_sockaddr.rs index 575102c27..4ce4f0ab2 100644 --- a/src/backend/libc/net/read_sockaddr.rs +++ b/src/backend/libc/net/read_sockaddr.rs @@ -186,23 +186,51 @@ unsafe fn inner_read_sockaddr_os( if len == offsetof_sun_path { SocketAddrAny::Unix(SocketAddrUnix::new(&[][..]).unwrap()) } else { + #[cfg(not(any(target_os = "android", target_os = "linux")))] + fn try_decode_abstract_socket( + _sockaddr: &c::sockaddr_un, + _len: usize, + ) -> Option { + None + } + #[cfg(any(target_os = "android", target_os = "linux"))] + fn try_decode_abstract_socket( + decode: &c::sockaddr_un, + len: usize, + ) -> Option { + if decode.sun_path[0] != 0 { + None + } else { + let offsetof_sun_path = super::addr::offsetof_sun_path(); + let address_bytes = &decode.sun_path[1..len - offsetof_sun_path]; + Some( + SocketAddrUnix::new_abstract_name( + &address_bytes.iter().map(|c| *c as u8).collect::>(), + ) + .unwrap(), + ) + } + } + let decode = *storage.cast::(); - assert_eq!( - decode.sun_path[len - 1 - offsetof_sun_path], - b'\0' as c::c_char - ); - let path_bytes = &decode.sun_path[..len - 1 - offsetof_sun_path]; - - // FreeBSD sometimes sets the length to longer than the length - // of the NUL-terminated string. Find the NUL and truncate the - // string accordingly. - #[cfg(target_os = "freebsd")] - let path_bytes = &path_bytes[..path_bytes.iter().position(|b| *b == 0).unwrap()]; - - SocketAddrAny::Unix( + let result = try_decode_abstract_socket(&decode, len).unwrap_or_else(|| { + assert_eq!( + decode.sun_path[len - 1 - offsetof_sun_path], + b'\0' as c::c_char + ); + let path_bytes = &decode.sun_path[..len - 1 - offsetof_sun_path]; + + // FreeBSD sometimes sets the length to longer than the length + // of the NUL-terminated string. Find the NUL and truncate the + // string accordingly. + #[cfg(target_os = "freebsd")] + let path_bytes = + &path_bytes[..path_bytes.iter().position(|b| *b == 0).unwrap()]; + SocketAddrUnix::new(path_bytes.iter().map(|c| *c as u8).collect::>()) - .unwrap(), - ) + .unwrap() + }); + SocketAddrAny::Unix(result) } } other => unimplemented!("{:?}", other), diff --git a/src/backend/linux_raw/net/read_sockaddr.rs b/src/backend/linux_raw/net/read_sockaddr.rs index b9bc09b96..fdab8d108 100644 --- a/src/backend/linux_raw/net/read_sockaddr.rs +++ b/src/backend/linux_raw/net/read_sockaddr.rs @@ -155,19 +155,31 @@ pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) - SocketAddrAny::Unix(SocketAddrUnix::new(&[][..]).unwrap()) } else { let decode = *storage.cast::(); - assert_eq!( - decode.sun_path[len - 1 - offsetof_sun_path], - b'\0' as c::c_char - ); - SocketAddrAny::Unix( - SocketAddrUnix::new( - decode.sun_path[..len - 1 - offsetof_sun_path] - .iter() - .map(|c| *c as u8) - .collect::>(), + if decode.sun_path[0] == 0 { + SocketAddrAny::Unix( + SocketAddrUnix::new_abstract_name( + &decode.sun_path[1..len - offsetof_sun_path] + .iter() + .map(|c| *c as u8) + .collect::>(), + ) + .unwrap(), + ) + } else { + assert_eq!( + decode.sun_path[len - 1 - offsetof_sun_path], + b'\0' as c::c_char + ); + SocketAddrAny::Unix( + SocketAddrUnix::new( + decode.sun_path[..len - 1 - offsetof_sun_path] + .iter() + .map(|c| *c as u8) + .collect::>(), + ) + .unwrap(), ) - .unwrap(), - ) + } } } other => unimplemented!("{:?}", other), diff --git a/tests/net/unix.rs b/tests/net/unix.rs index b69fc8453..e63bfabc7 100644 --- a/tests/net/unix.rs +++ b/tests/net/unix.rs @@ -142,39 +142,21 @@ fn test_unix() { } #[cfg(not(any(target_os = "redox", target_os = "wasi")))] -#[test] -fn test_unix_msg() { +fn do_test_unix_msg(addr: SocketAddrUnix) { use rustix::io::{IoSlice, IoSliceMut}; use rustix::net::{recvmsg, sendmsg_noaddr, RecvFlags, SendFlags}; - use std::string::ToString; - - let tmpdir = tempfile::tempdir().unwrap(); - let path = tmpdir.path().join("scp_4804"); - let ready = Arc::new((Mutex::new(false), Condvar::new())); let server = { - let ready = ready.clone(); - let path = path.clone(); + let connection_socket = socket( + AddressFamily::UNIX, + SocketType::SEQPACKET, + Protocol::default(), + ) + .unwrap(); + bind_unix(&connection_socket, &addr).unwrap(); + listen(&connection_socket, 1).unwrap(); move || { - let connection_socket = socket( - AddressFamily::UNIX, - SocketType::SEQPACKET, - Protocol::default(), - ) - .unwrap(); - - let name = SocketAddrUnix::new(&path).unwrap(); - bind_unix(&connection_socket, &name).unwrap(); - listen(&connection_socket, 1).unwrap(); - - { - let (lock, cvar) = &*ready; - let mut started = lock.lock().unwrap(); - *started = true; - cvar.notify_all(); - } - let mut buffer = vec![0; BUFFER_SIZE]; 'exit: loop { let data_socket = accept(&connection_socket).unwrap(); @@ -208,21 +190,10 @@ fn test_unix_msg() { ) .unwrap(); } - - unlinkat(cwd(), path, AtFlags::empty()).unwrap(); } }; let client = move || { - { - let (lock, cvar) = &*ready; - let mut started = lock.lock().unwrap(); - while !*started { - started = cvar.wait(started).unwrap(); - } - } - - let addr = SocketAddrUnix::new(path).unwrap(); let mut buffer = vec![0; BUFFER_SIZE]; let runs: &[(&[&str], i32)] = &[ (&["1", "2"], 3), @@ -257,18 +228,25 @@ fn test_unix_msg() { ) .unwrap(); - let nread = recvmsg( + let result = recvmsg( &data_socket, &mut [IoSliceMut::new(&mut buffer)], &mut Default::default(), RecvFlags::empty(), ) - .unwrap() - .bytes; + .unwrap(); + let nread = result.bytes; assert_eq!( i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(), *sum ); + // Don't ask me why, but this was seen to fail on FreeBSD. SocketAddrUnix::path() + // returned None for some reason. + #[cfg(not(target_os = "freebsd"))] + assert_eq!( + Some(rustix::net::SocketAddrAny::Unix(addr.clone())), + result.address + ); } let data_socket = socket( @@ -305,6 +283,30 @@ fn test_unix_msg() { server.join().unwrap(); } +#[cfg(not(any(target_os = "redox", target_os = "wasi")))] +#[test] +fn test_unix_msg() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new(&path).unwrap(); + do_test_unix_msg(name); + + unlinkat(cwd(), path, AtFlags::empty()).unwrap(); +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn test_abstract_unix_msg() { + use std::os::unix::ffi::OsStrExt; + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new_abstract_name(path.as_os_str().as_bytes()).unwrap(); + do_test_unix_msg(name); +} + #[cfg(not(any(target_os = "redox", target_os = "wasi")))] #[test] fn test_unix_msg_with_scm_rights() { @@ -318,31 +320,23 @@ fn test_unix_msg_with_scm_rights() { let tmpdir = tempfile::tempdir().unwrap(); let path = tmpdir.path().join("scp_4804"); - let ready = Arc::new((Mutex::new(false), Condvar::new())); let server = { - let ready = ready.clone(); let path = path.clone(); - move || { - let connection_socket = socket( - AddressFamily::UNIX, - SocketType::SEQPACKET, - Protocol::default(), - ) - .unwrap(); - let mut pipe_end = None; + let connection_socket = socket( + AddressFamily::UNIX, + SocketType::SEQPACKET, + Protocol::default(), + ) + .unwrap(); - let name = SocketAddrUnix::new(&path).unwrap(); - bind_unix(&connection_socket, &name).unwrap(); - listen(&connection_socket, 1).unwrap(); + let name = SocketAddrUnix::new(&path).unwrap(); + bind_unix(&connection_socket, &name).unwrap(); + listen(&connection_socket, 1).unwrap(); - { - let (lock, cvar) = &*ready; - let mut started = lock.lock().unwrap(); - *started = true; - cvar.notify_all(); - } + move || { + let mut pipe_end = None; let mut buffer = vec![0; BUFFER_SIZE]; let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))]; @@ -403,14 +397,6 @@ fn test_unix_msg_with_scm_rights() { }; let client = move || { - { - let (lock, cvar) = &*ready; - let mut started = lock.lock().unwrap(); - while !*started { - started = cvar.wait(started).unwrap(); - } - } - let addr = SocketAddrUnix::new(path).unwrap(); let (read_end, write_end) = pipe().unwrap(); let mut buffer = vec![0; BUFFER_SIZE];