diff --git a/transports/net/Cargo.toml b/transports/net/Cargo.toml index e206d1c..8759b98 100644 --- a/transports/net/Cargo.toml +++ b/transports/net/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "memberlist-net" -version = "0.3.0" +version = "0.3.1" edition.workspace = true license.workspace = true repository.workspace = true diff --git a/transports/net/src/error.rs b/transports/net/src/error.rs index 688c97d..8cdc906 100644 --- a/transports/net/src/error.rs +++ b/transports/net/src/error.rs @@ -140,6 +140,7 @@ where /// Connection kind. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(u8)] +#[non_exhaustive] pub enum ConnectionKind { /// Promised connection, e.g. TCP, QUIC. Promised, @@ -167,6 +168,7 @@ impl ConnectionKind { /// Connection error kind. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(u8)] +#[non_exhaustive] pub enum ConnectionErrorKind { /// Failed to accept a connection. Accept, @@ -278,4 +280,12 @@ impl ConnectionError { error: err, } } + + pub(super) fn packet_write_on_transport_shutdown(err: std::io::Error) -> Self { + Self { + kind: ConnectionKind::Packet, + error_kind: ConnectionErrorKind::Close, + error: err, + } + } } diff --git a/transports/net/src/io/send_by_packet.rs b/transports/net/src/io/send_by_packet.rs index 6c8d148..6663236 100644 --- a/transports/net/src/io/send_by_packet.rs +++ b/transports/net/src/io/send_by_packet.rs @@ -536,13 +536,24 @@ where addr: &A::ResolvedAddress, buf: &[u8], ) -> Result> { - self - .next_socket(addr) - .send_to(buf, addr) - .await - .inspect(|num| { - tracing::trace!(remote=%addr, total_bytes = %num, sent=?buf, "memberlist_net.packet"); - }) - .map_err(|e| ConnectionError::packet_write(e).into()) + match self.next_socket(addr) { + Some(skt) => skt + .send_to(buf, addr) + .await + .inspect(|num| { + tracing::trace!(remote=%addr, total_bytes = %num, sent=?buf, "memberlist_net.packet"); + }) + .map_err(|e| ConnectionError::packet_write(e).into()), + None => { + tracing::error!("memberlist_net.packet: transport is being shutdown"); + Err( + ConnectionError::packet_write_on_transport_shutdown(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "transport is being shutdown", + )) + .into(), + ) + } + } } } diff --git a/transports/net/src/lib.rs b/transports/net/src/lib.rs index dce667e..154e32c 100644 --- a/transports/net/src/lib.rs +++ b/transports/net/src/lib.rs @@ -141,10 +141,12 @@ where local_addr: A::Address, packet_rx: PacketSubscriber, stream_rx: StreamSubscriber, + num_v4_sockets: usize, v4_round_robin: AtomicUsize, - v4_sockets: SmallVec::Net as Net>::UdpSocket>>, + v4_sockets: AtomicRefCell::Net as Net>::UdpSocket>>>, + num_v6_sockets: usize, v6_round_robin: AtomicUsize, - v6_sockets: SmallVec::Net as Net>::UdpSocket>>, + v6_sockets: AtomicRefCell::Net as Net>::UdpSocket>>>, stream_layer: Arc, #[cfg(feature = "encryption")] encryptor: Option, @@ -367,9 +369,11 @@ where packet_rx, stream_rx, handles: AtomicRefCell::new(handles), - v4_sockets: v4_sockets.into_iter().map(|(ln, _)| ln).collect(), + num_v4_sockets: v4_sockets.len(), + v4_sockets: AtomicRefCell::new(v4_sockets.into_iter().map(|(ln, _)| ln).collect()), v4_round_robin: AtomicUsize::new(0), - v6_sockets: v6_sockets.into_iter().map(|(ln, _)| ln).collect(), + num_v6_sockets: v6_sockets.len(), + v6_sockets: AtomicRefCell::new(v6_sockets.into_iter().map(|(ln, _)| ln).collect()), v6_round_robin: AtomicUsize::new(0), stream_layer, #[cfg(feature = "encryption")] @@ -383,23 +387,47 @@ where fn next_socket( &self, addr: &A::ResolvedAddress, - ) -> &<::Net as Net>::UdpSocket { - if addr.is_ipv4() { + ) -> Option::Net as Net>::UdpSocket>> { + enum Kind { + V4(usize), + V6(usize), + } + + let kind = if addr.is_ipv4() { // if there's no v4 sockets, we assume remote addr can accept both v4 and v6 // give a try on v6 - if self.v4_sockets.is_empty() { - let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_sockets.len(); - &self.v6_sockets[idx] + if self.num_v4_sockets == 0 { + let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets; + Kind::V6(idx) } else { - let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_sockets.len(); - &self.v4_sockets[idx] + let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets; + Kind::V4(idx) } - } else if self.v6_sockets.is_empty() { - let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_sockets.len(); - &self.v4_sockets[idx] + } else if self.num_v6_sockets == 0 { + let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets; + Kind::V4(idx) } else { - let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_sockets.len(); - &self.v6_sockets[idx] + let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets; + Kind::V6(idx) + }; + + // if we failed to borrow, it means that this transport is being shut down. + + match kind { + Kind::V4(idx) => { + if let Ok(sockets) = self.v4_sockets.try_borrow() { + Some(sockets[idx].clone()) + } else { + None + } + } + Kind::V6(idx) => { + if let Ok(sockets) = self.v6_sockets.try_borrow() { + Some(sockets[idx].clone()) + } else { + None + } + } } } } @@ -666,6 +694,21 @@ where return Ok(()); } + // clear all udp sockets + loop { + if let Ok(mut s) = self.v4_sockets.try_borrow_mut() { + s.clear(); + break; + } + } + + loop { + if let Ok(mut s) = self.v6_sockets.try_borrow_mut() { + s.clear(); + break; + } + } + let mut handles = core::mem::take(&mut *self.handles.borrow_mut()); while handles.next().await.is_some() {} Ok(())