Skip to content

Commit

Permalink
feat: add support for DSCP and TTL / Hop Limit
Browse files Browse the repository at this point in the history
* Support IP_RECVTTL and IPV6_RECVHOPLIMIT socket options
and related control messages for recvmsg.
* Support setting DSCP in control messages for both sendmsg
and recvmsg.

Signed-off-by: Bigo <[email protected]>
  • Loading branch information
crisidev committed Jun 3, 2024
1 parent 1939f92 commit 8a030dc
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 3 deletions.
89 changes: 89 additions & 0 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,32 @@ pub enum ControlMessageOwned {
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
Ipv6OrigDstAddr(libc::sockaddr_in6),

/// Time-to-Live (TTL) header field of the incoming IPv4 packet.
///
/// [Further reading](https://www.man7.org/linux/man-pages/man7/ip.7.html)
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
Ipv4RecvTtl(u8),

/// Hop Limit header field of the incoming IPv6 packet.
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
Ipv6RecvHopLimit(u8),

/// Retrieve the DSCP (ToS) header field of the incoming IPv4 packet.
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
Ipv4Tos(i32),

/// Retrieve the DSCP (Traffic Class) header field of the incoming IPv6 packet.
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
Ipv6TClass(i32),

/// UDP Generic Receive Offload (GRO) allows receiving multiple UDP
/// packets from a single sender.
/// Fixed-size payloads are following one by one in a receive buffer.
Expand Down Expand Up @@ -987,6 +1013,30 @@ impl ControlMessageOwned {
let content_type = unsafe { ptr::read_unaligned(p as *const u8) };
ControlMessageOwned::TlsGetRecordType(content_type.into())
},
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
(libc::IPPROTO_IP, libc::IP_TTL) => {
let ttl: u8 = unsafe { ptr::read_unaligned(p as *const u8) };
ControlMessageOwned::Ipv4RecvTtl(ttl)
},
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
(libc::IPPROTO_IPV6, libc::IPV6_HOPLIMIT) => {
let ttl: u8 = unsafe { ptr::read_unaligned(p as *const u8) };
ControlMessageOwned::Ipv6RecvHopLimit(ttl)
},
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
(libc::IPPROTO_IP, libc::IP_TOS) => {
let tos = unsafe { ptr::read_unaligned(p as *const i32) };
ControlMessageOwned::Ipv4Tos(tos)
},
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
(libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => {
let tc = unsafe { ptr::read_unaligned(p as *const i32) };
ControlMessageOwned::Ipv6TClass(tc)
},
(_, _) => {
let sl = unsafe { std::slice::from_raw_parts(p, len) };
let ucmsg = UnknownCmsg(*header, Vec::<u8>::from(sl));
Expand Down Expand Up @@ -1152,6 +1202,17 @@ pub enum ControlMessage<'a> {
/// page.
#[cfg(target_os = "linux")]
TxTime(&'a u64),

/// Configure DSCP / IP TOS for outgoing v4 packets.
///
/// Further information can be found [here](https://en.wikipedia.org/wiki/Differentiated_services).
#[cfg(target_os = "linux")]
IpTos(&'a i32),
/// Configure DSCP / IP TOS for outgoing v6 packets.
///
/// Further information can be found [here](https://en.wikipedia.org/wiki/Differentiated_services).
#[cfg(target_os = "linux")]
Ipv6TClass(&'a i32),
}

// An opaque structure used to prevent cmsghdr from being a public type
Expand Down Expand Up @@ -1256,6 +1317,14 @@ impl<'a> ControlMessage<'a> {
ControlMessage::TxTime(tx_time) => {
tx_time as *const _ as *const u8
},
#[cfg(target_os = "linux")]
ControlMessage::IpTos(tos) => {
tos as *const _ as *const u8
},
#[cfg(target_os = "linux")]
ControlMessage::Ipv6TClass(tclass) => {
tclass as *const _ as *const u8
},
};
unsafe {
ptr::copy_nonoverlapping(
Expand Down Expand Up @@ -1320,6 +1389,14 @@ impl<'a> ControlMessage<'a> {
ControlMessage::TxTime(tx_time) => {
mem::size_of_val(tx_time)
},
#[cfg(target_os = "linux")]
ControlMessage::IpTos(tos) => {
mem::size_of_val(tos)
},
#[cfg(target_os = "linux")]
ControlMessage::Ipv6TClass(tclass) => {
mem::size_of_val(tclass)
},
}
}

Expand Down Expand Up @@ -1354,6 +1431,10 @@ impl<'a> ControlMessage<'a> {
ControlMessage::RxqOvfl(_) => libc::SOL_SOCKET,
#[cfg(target_os = "linux")]
ControlMessage::TxTime(_) => libc::SOL_SOCKET,
#[cfg(target_os = "linux")]
ControlMessage::IpTos(_) => libc::IPPROTO_IP,
#[cfg(target_os = "linux")]
ControlMessage::Ipv6TClass(_) => libc::IPPROTO_IPV6,
}
}

Expand Down Expand Up @@ -1403,6 +1484,14 @@ impl<'a> ControlMessage<'a> {
ControlMessage::TxTime(_) => {
libc::SCM_TXTIME
},
#[cfg(target_os = "linux")]
ControlMessage::IpTos(_) => {
libc::IP_TOS
},
#[cfg(target_os = "linux")]
ControlMessage::Ipv6TClass(_) => {
libc::IPV6_TCLASS
},
}
}

Expand Down
51 changes: 48 additions & 3 deletions src/sys/socket/sockopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ sockopt_impl!(
#[cfg(feature = "net")]
sockopt_impl!(
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
/// Set or receive the Type-Of-Service (TOS) field that is
/// Set the Type-Of-Service (TOS) field that is
/// sent with every IP packet originating from this socket
IpTos,
Both,
Expand All @@ -418,13 +418,35 @@ sockopt_impl!(
#[cfg(feature = "net")]
sockopt_impl!(
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
/// Traffic class associated with outgoing packets
/// Receive the Type-Of-Service (TOS) associated with incoming packets.
IpRecvTos,
Both,
libc::IPPROTO_IP,
libc::IP_RECVTOS,
bool
);
#[cfg(target_os = "linux")]
#[cfg(feature = "net")]
sockopt_impl!(
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
/// Set the traffic class associated with outgoing packets.
Ipv6TClass,
Both,
libc::IPPROTO_IPV6,
libc::IPV6_TCLASS,
libc::c_int
);
#[cfg(target_os = "linux")]
#[cfg(feature = "net")]
sockopt_impl!(
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
/// Receive the traffic class associated with incoming packets.
Ipv6TRecvTClass,
Both,
libc::IPPROTO_IPV6,
libc::IPV6_RECVTCLASS,
bool
);
#[cfg(any(linux_android, target_os = "fuchsia"))]
#[cfg(feature = "net")]
sockopt_impl!(
Expand Down Expand Up @@ -1045,7 +1067,19 @@ sockopt_impl!(
libc::IP_TTL,
libc::c_int
);
#[cfg(any(apple_targets, linux_android, target_os = "freebsd"))]
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
sockopt_impl!(
/// Enables a receiving socket to retrieve the Time-to-Live (TTL) field
/// from incoming IPv4 packets.
Ipv4RecvTtl,
Both,
libc::IPPROTO_IP,
libc::IP_RECVTTL,
bool
);
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
sockopt_impl!(
/// Set the unicast hop limit for the socket.
Ipv6Ttl,
Expand All @@ -1056,6 +1090,17 @@ sockopt_impl!(
);
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
sockopt_impl!(
/// Enables a receiving socket to retrieve the Hop Limit field
/// (similar to TTL in IPv4) from incoming IPv6 packets.
Ipv6RecvHopLimit,
Both,
libc::IPPROTO_IPV6,
libc::IPV6_RECVHOPLIMIT,
bool
);
#[cfg(any(linux_android, target_os = "freebsd"))]
#[cfg(feature = "net")]
sockopt_impl!(
#[cfg_attr(docsrs, doc(cfg(feature = "net")))]
/// The `recvmsg(2)` call will return the destination IP address for a UDP
Expand Down
160 changes: 160 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2544,6 +2544,166 @@ fn test_recvmsg_rxq_ovfl() {
assert_eq!(drop_counter, 1);
}

#[cfg(target_os = "linux")]
#[cfg(feature = "net")]
#[cfg_attr(qemu, ignore)]
#[test]
pub fn test_ip_tos_udp() {
use nix::sys::socket::ControlMessageOwned;
use nix::sys::socket::{
bind, recvmsg, sendmsg, setsockopt, socket, sockopt, ControlMessage,
MsgFlags, SockFlag, SockType, SockaddrIn,
};

let sock_addr = SockaddrIn::from_str("127.0.0.1:6909").unwrap();
let rsock = socket(
AddressFamily::Inet,
SockType::Datagram,
SockFlag::empty(),
None,
)
.unwrap();
setsockopt(&rsock, sockopt::IpRecvTos, &true).unwrap();
bind(rsock.as_raw_fd(), &sock_addr).unwrap();

let sbuf = [0u8; 2048];
let iov1 = [std::io::IoSlice::new(&sbuf)];

let mut rbuf = [0u8; 2048];
let mut iov2 = [std::io::IoSliceMut::new(&mut rbuf)];
let mut rcmsg = cmsg_space!(libc::c_int);

let ssock = socket(
AddressFamily::Inet,
SockType::Datagram,
SockFlag::empty(),
None,
)
.expect("send socket failed");
setsockopt(&ssock, sockopt::IpTos, &20).unwrap();

// Test the sendmsg control message and check the received packet has the same TOS.
let scmsg = ControlMessage::IpTos(&20);
sendmsg(
ssock.as_raw_fd(),
&iov1,
&[scmsg],
MsgFlags::empty(),
Some(&sock_addr),
)
.unwrap();

// TODO: this test is weak, but testing for the actual ToS value results in sporadic
// failures in CI where the ToS in the message header is not the one set by the
// sender, so for now the test only checks for the presence of the ToS in the message
// header.
let mut tc = None;
let recv = recvmsg::<()>(
rsock.as_raw_fd(),
&mut iov2,
Some(&mut rcmsg),
MsgFlags::empty(),
)
.unwrap();
for c in recv.cmsgs().unwrap() {
println!("CMSG: {c:?}");
if let ControlMessageOwned::Ipv4Tos(t) = c {
tc = Some(t);
}
}
assert!(tc.is_some());
}

#[cfg(target_os = "linux")]
#[cfg(feature = "net")]
#[cfg_attr(qemu, ignore)]
#[test]
pub fn test_ipv6_tclass_udp() {
use nix::sys::socket::ControlMessageOwned;
use nix::sys::socket::{
bind, recvmsg, sendmsg, setsockopt, socket, sockopt, ControlMessage,
MsgFlags, SockFlag, SockType, SockaddrIn6,
};

let std_sa = SocketAddrV6::from_str("[::1]:6902").unwrap();
let sock_addr: SockaddrIn6 = SockaddrIn6::from(std_sa);
let rsock = socket(
AddressFamily::Inet6,
SockType::Datagram,
SockFlag::empty(),
None,
)
.unwrap();
setsockopt(&rsock, sockopt::Ipv6TRecvTClass, &true).unwrap();
bind(rsock.as_raw_fd(), &sock_addr).unwrap();

let sbuf = [0u8; 2048];
let iov1 = [std::io::IoSlice::new(&sbuf)];

let mut rbuf = [0u8; 2048];
let mut iov2 = [std::io::IoSliceMut::new(&mut rbuf)];
let mut rcmsg = cmsg_space!(libc::c_int);

let ssock = socket(
AddressFamily::Inet6,
SockType::Datagram,
SockFlag::empty(),
None,
)
.expect("send socket failed");
setsockopt(&ssock, sockopt::Ipv6TClass, &10).unwrap();

sendmsg(
ssock.as_raw_fd(),
&iov1,
&[],
MsgFlags::empty(),
Some(&sock_addr),
)
.unwrap();

let mut tc = None;
let recv = recvmsg::<()>(
rsock.as_raw_fd(),
&mut iov2,
Some(&mut rcmsg),
MsgFlags::empty(),
)
.unwrap();
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::Ipv6TClass(t) = c {
tc = Some(t);
}
}
assert_eq!(tc, Some(10));

let scmsg = ControlMessage::Ipv6TClass(&20);
sendmsg(
ssock.as_raw_fd(),
&iov1,
&[scmsg],
MsgFlags::empty(),
Some(&sock_addr),
)
.unwrap();

let mut tc = None;
let recv = recvmsg::<()>(
rsock.as_raw_fd(),
&mut iov2,
Some(&mut rcmsg),
MsgFlags::empty(),
)
.unwrap();
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::Ipv6TClass(t) = c {
tc = Some(t);
}
}

assert_eq!(tc, Some(20));
}

#[cfg(linux_android)]
mod linux_errqueue {
use super::FromStr;
Expand Down
Loading

0 comments on commit 8a030dc

Please sign in to comment.