From aefff651b36e49a99a9f32f844939856b12fd68c Mon Sep 17 00:00:00 2001 From: Andy Grover Date: Thu, 9 May 2024 17:35:29 -0700 Subject: [PATCH] recvmsg: Check if CMSG buffer was too small and return an error If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they could have been truncated. Change RecvMsg::cmsgs() to return a Result, and to check for this flag (an API change). Update tests for API change. Add test for too-small buffer. --- changelog/2413.changed.md | 1 + src/sys/socket/mod.rs | 12 ++++++--- test/sys/test_socket.rs | 57 +++++++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 26 deletions(-) create mode 100644 changelog/2413.changed.md diff --git a/changelog/2413.changed.md b/changelog/2413.changed.md new file mode 100644 index 0000000000..7bae72f7d8 --- /dev/null +++ b/changelog/2413.changed.md @@ -0,0 +1 @@ +`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated. diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 3d1651bd3f..cba2938c87 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t}; #[cfg(all(feature = "uio", not(target_os = "redox")))] use libc::{ c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE, + MSG_CTRUNC, }; #[cfg(not(target_os = "redox"))] use std::io::{IoSlice, IoSliceMut}; @@ -601,11 +602,16 @@ pub struct RecvMsg<'a, 's, S> { impl<'a, S> RecvMsg<'a, '_, S> { /// Iterate over the valid control messages pointed to by this /// msghdr. - pub fn cmsgs(&self) -> CmsgIterator { - CmsgIterator { + pub fn cmsgs(&self) -> Result { + + if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC { + return Err(Errno::ENOBUFS); + } + + Ok(CmsgIterator { cmsghdr: self.cmsghdr, mhdr: &self.mhdr - } + }) } } diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index ee60e62b45..2eeb667eb8 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -4,6 +4,7 @@ use libc::c_char; use nix::sys::socket::{getsockname, AddressFamily, UnixAddr}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; +use std::io; use std::net::{SocketAddrV4, SocketAddrV6}; use std::os::unix::io::{AsRawFd, RawFd}; use std::path::Path; @@ -55,7 +56,7 @@ pub fn test_timestamping() { .unwrap(); let mut ts = None; - for c in recv.cmsgs() { + for c in recv.cmsgs().unwrap() { if let ControlMessageOwned::ScmTimestampsns(timestamps) = c { ts = Some(timestamps.system); } @@ -889,7 +890,7 @@ pub fn test_scm_rights() { ) .unwrap(); - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { if let ControlMessageOwned::ScmRights(fd) = cmsg { assert_eq!(received_r, None); assert_eq!(fd.len(), 1); @@ -1330,7 +1331,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); match cmsgs.next() { Some(ControlMessageOwned::ScmRights(fds)) => { assert_eq!( @@ -1399,7 +1400,7 @@ pub fn test_sendmsg_empty_cmsgs() { ) .unwrap(); - if msg.cmsgs().next().is_some() { + if msg.cmsgs().unwrap().next().is_some() { panic!("unexpected cmsg"); } assert!(!msg @@ -1466,7 +1467,7 @@ fn test_scm_credentials() { .unwrap(); let mut received_cred = None; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { let cred = match cmsg { #[cfg(linux_android)] ControlMessageOwned::ScmCredentials(cred) => cred, @@ -1497,7 +1498,7 @@ fn test_scm_credentials() { #[test] fn test_scm_credentials_and_rights() { let space = cmsg_space!(libc::ucred, RawFd); - test_impl_scm_credentials_and_rights(space); + test_impl_scm_credentials_and_rights(space).unwrap(); } /// Ensure that passing a an oversized control message buffer to recvmsg @@ -1509,11 +1510,20 @@ fn test_scm_credentials_and_rights() { #[test] fn test_too_large_cmsgspace() { let space = vec![0u8; 1024]; - test_impl_scm_credentials_and_rights(space); + test_impl_scm_credentials_and_rights(space).unwrap(); +} + +#[cfg(linux_android)] +#[test] +fn test_too_small_cmsgspace() { + let space = vec![0u8; 4]; + assert!(test_impl_scm_credentials_and_rights(space).is_err()); } #[cfg(linux_android)] -fn test_impl_scm_credentials_and_rights(mut space: Vec) { +fn test_impl_scm_credentials_and_rights( + mut space: Vec, +) -> Result<(), io::Error> { use libc::ucred; use nix::sys::socket::sockopt::PassCred; use nix::sys::socket::{ @@ -1573,9 +1583,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec) { .unwrap(); let mut received_cred = None; - assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs"); + assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs"); - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs()? { match cmsg { ControlMessageOwned::ScmRights(fds) => { assert_eq!(received_r, None, "already received fd"); @@ -1606,6 +1616,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec) { read(received_r.as_raw_fd(), &mut buf).unwrap(); assert_eq!(&buf[..], b"world"); close(received_r).unwrap(); + + Ok(()) } // Test creating and using named unix domain sockets @@ -1742,7 +1754,6 @@ fn loopback_address( use nix::ifaddrs::getifaddrs; use nix::net::if_::*; use nix::sys::socket::SockaddrLike; - use std::io; use std::io::Write; let mut addrs = match getifaddrs() { @@ -1837,7 +1848,7 @@ pub fn test_recv_ipv4pktinfo() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next() { let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex"); @@ -1929,7 +1940,7 @@ pub fn test_recvif() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs"); let mut rx_recvif = false; let mut rx_recvdstaddr = false; @@ -2027,10 +2038,10 @@ pub fn test_recvif_ipv4() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs"); let mut rx_recvorigdstaddr = false; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { match cmsg { ControlMessageOwned::Ipv4OrigDstAddr(addr) => { rx_recvorigdstaddr = true; @@ -2113,10 +2124,10 @@ pub fn test_recvif_ipv6() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs"); let mut rx_recvorigdstaddr = false; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { match cmsg { ControlMessageOwned::Ipv6OrigDstAddr(addr) => { rx_recvorigdstaddr = true; @@ -2214,7 +2225,7 @@ pub fn test_recv_ipv6pktinfo() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next() { let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex"); @@ -2357,7 +2368,7 @@ fn test_recvmsg_timestampns() { flags, ) .unwrap(); - let rtime = match r.cmsgs().next() { + let rtime = match r.cmsgs().unwrap().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"), None => panic!("No control message"), @@ -2418,7 +2429,7 @@ fn test_recvmmsg_timestampns() { ) .unwrap() .collect(); - let rtime = match r[0].cmsgs().next() { + let rtime = match r[0].cmsgs().unwrap().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"), None => panic!("No control message"), @@ -2508,7 +2519,7 @@ fn test_recvmsg_rxq_ovfl() { MsgFlags::MSG_DONTWAIT, ) { Ok(r) => { - drop_counter = match r.cmsgs().next() { + drop_counter = match r.cmsgs().unwrap().next() { Some(ControlMessageOwned::RxqOvfl(drop_counter)) => { drop_counter } @@ -2687,7 +2698,7 @@ mod linux_errqueue { assert_eq!(msg.address, Some(sock_addr)); // Check for expected control message. - let ext_err = match msg.cmsgs().next() { + let ext_err = match msg.cmsgs().unwrap().next() { Some(cmsg) => testf(&cmsg), None => panic!("No control message"), }; @@ -2878,7 +2889,7 @@ fn test_recvmm2() -> nix::Result<()> { #[cfg(not(any(qemu, target_arch = "aarch64")))] let mut saw_time = false; let mut recvd = 0; - for cmsg in rmsg.cmsgs() { + for cmsg in rmsg.cmsgs().unwrap() { if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg { let ts = timestamps.system;