Skip to content

Commit

Permalink
Optimize probe shutdown logic implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n committed Apr 14, 2024
1 parent eff155c commit 8986f1c
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 125 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ members = [
resolver = "2"

[workspace.package]
version = "0.2.0"
version = "0.2.1"
edition = "2021"
license = "MPL-2.0"
repository = "https://github.com/al8n/memberlist"
Expand Down
5 changes: 1 addition & 4 deletions core/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,6 @@ where
return Err(e);
}

let mut futs = core::mem::take(&mut *self.handles.borrow_mut());
while futs.next().await.is_some() {}

Ok(())
}
}
Expand Down Expand Up @@ -486,7 +483,7 @@ where
loop {
futures::select! {
_ = shutdown_rx.recv().fuse() => {
tracing::info!("memberlist: broadcast queue checker exits");
tracing::debug!("memberlist: broadcast queue checker exits");
return;
},
_ = tick.next().fuse() => {
Expand Down
7 changes: 4 additions & 3 deletions core/src/network/packet/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ where
loop {
futures::select! {
_ = shutdown_rx.recv().fuse() => {
tracing::debug!("memberlist: packet listener exits");
return;
break;
}
packet = packet_rx.recv().fuse() => {
match packet {
Expand All @@ -60,12 +59,14 @@ where
}
// If we got an error, which means on the other side the transport has been closed,
// so we need to return and shutdown the packet listener
return;
break;
},
}
}
}
}

tracing::debug!("memberlist: packet listener exits");
})
}

Expand Down
102 changes: 55 additions & 47 deletions core/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,17 +652,24 @@ macro_rules! bail_trigger {
}

let mut timer = <T::Runtime as RuntimeLite>::interval(interval);
loop {
futures::select! {
_ = futures::StreamExt::next(&mut timer).fuse() => {
this.$fn(&stop_rx).await;
}
'outer: loop {
futures::select_biased! {
_ = stop_rx.recv().fuse() => {
tracing::debug!(concat!("memberlist.state: ", stringify!($fn), " trigger exits"));
return;
break 'outer;
}
_ = futures::StreamExt::next(&mut timer).fuse() => {
if this.inner.shutdown_tx.is_closed() {
break 'outer;
}
let shutdown = this.$fn(&stop_rx).await;
if shutdown {
break 'outer;
}
}
}
}

tracing::debug!(concat!("memberlist.state: ", stringify!($fn), " trigger exits"));
})
}
}
Expand Down Expand Up @@ -750,52 +757,52 @@ where
}

// Used to perform a single round of failure detection and gossip
// TODO(al8n): maybe an infinite loop happening here when graceful shutdown.
// use shutdown_rx for temporary fix.
async fn probe(&self, shutdown_rx: &async_channel::Receiver<()>) {
// FIX(al8n): maybe an infinite loop happening here when graceful shutdown.
async fn probe(&self, shutdown_rx: &async_channel::Receiver<()>) -> bool {
// Track the number of indexes we've considered probing
let mut num_check = 0;
loop {
futures::select_biased! {
_ = shutdown_rx.recv().fuse() => return,
default => {
let memberlist = self.inner.nodes.read().await;
let num_nodes = memberlist.nodes.len();
// Make sure we don't wrap around infinitely
if num_check >= num_nodes {
return;
}
match shutdown_rx.try_recv() {
Ok(_) => return true,
Err(async_channel::TryRecvError::Empty) => {}
Err(async_channel::TryRecvError::Closed) => return true,
}

// Handle the wrap around case
let probe_index = self.inner.probe_index.load(Ordering::Acquire);
if probe_index >= num_nodes {
drop(memberlist);
self.reset_nodes().await;
self.inner.probe_index.store(0, Ordering::Release);
num_check += 1;
continue;
}
let memberlist = self.inner.nodes.read().await;
let num_nodes = memberlist.nodes.len();
// Make sure we don't wrap around infinitely
if num_check >= num_nodes {
return false;
}

// Determine if we should probe this node
let mut skip = false;
let node = memberlist.nodes[probe_index].state.clone();
if node.dead_or_left() || node.id() == self.local_id() {
skip = true;
}
// Handle the wrap around case
let probe_index = self.inner.probe_index.load(Ordering::Acquire);
if probe_index >= num_nodes {
drop(memberlist);
self.reset_nodes().await;
self.inner.probe_index.store(0, Ordering::Release);
num_check += 1;
continue;
}

// Potentially skip
drop(memberlist);
self.inner.probe_index.store(probe_index + 1, Ordering::Release);
if skip {
num_check += 1;
continue;
}
// Determine if we should probe this node
let mut skip = false;
let node = memberlist.nodes[probe_index].state.clone();
if node.dead_or_left() || node.id() == self.local_id() {
skip = true;
}

// Probe the specific node
self.probe_node(&node).await;
return;
}
// Potentially skip
drop(memberlist);
self.inner.probe_index.fetch_add(1, Ordering::AcqRel);
if skip {
num_check += 1;
continue;
}

// Probe the specific node
self.probe_node(&node).await;
return false;
}
}

Expand Down Expand Up @@ -1166,9 +1173,9 @@ where

/// Invoked every GossipInterval period to broadcast our gossip
/// messages to a few random nodes.
async fn gossip(&self, shutdown_rx: &async_channel::Receiver<()>) {
async fn gossip(&self, shutdown_rx: &async_channel::Receiver<()>) -> bool {
futures::select_biased! {
_ = shutdown_rx.recv().fuse() => {},
_ = shutdown_rx.recv().fuse() => true,
default => {
#[cfg(feature = "metrics")]
let now = Instant::now();
Expand Down Expand Up @@ -1255,6 +1262,7 @@ where
fut.await
})
.await;
false
},
}
}
Expand Down
32 changes: 8 additions & 24 deletions memberlist/src/async_std.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use agnostic::async_std::AsyncStdRuntime;

/// Memberlist type alias for using [`NetTransport`](memberlist_net::NetTransport) and [`Tcp`](memberlist_net::stream_layer::tcp::Tcp) stream layer with `async_std` runtime.
#[cfg(all(
any(feature = "tcp", feature = "tls", feature = "native-tls"),
Expand All @@ -16,9 +14,9 @@ pub type AsyncStdTcpMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::tcp::Tcp<AsyncStdRuntime>,
memberlist_net::stream_layer::tcp::Tcp<agnostic::async_std::AsyncStdRuntime>,
W,
AsyncStdRuntime,
agnostic::async_std::AsyncStdRuntime,
>,
D,
>;
Expand All @@ -30,9 +28,9 @@ pub type AsyncStdTlsMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::tls::Tls<AsyncStdRuntime>,
memberlist_net::stream_layer::tls::Tls<agnostic::async_std::AsyncStdRuntime>,
W,
AsyncStdRuntime,
agnostic::async_std::AsyncStdRuntime,
>,
D,
>;
Expand All @@ -47,9 +45,9 @@ pub type AsyncStdNativeTlsMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::native_tls::NativeTls<AsyncStdRuntime>,
memberlist_net::stream_layer::native_tls::NativeTls<agnostic::async_std::AsyncStdRuntime>,
W,
AsyncStdRuntime,
agnostic::async_std::AsyncStdRuntime,
>,
D,
>;
Expand All @@ -61,23 +59,9 @@ pub type AsyncStdQuicMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_quic::QuicTransport<
I,
A,
memberlist_quic::stream_layer::quinn::Quinn<AsyncStdRuntime>,
W,
AsyncStdRuntime,
>,
D,
>;

/// Memberlist type alias for using [`QuicTransport`](memberlist_quic::QuicTransport) and [`S2n`](memberlist_quic::stream_layer::s2n::S2n) stream layer with `async_std` runtime.
#[cfg(all(feature = "s2n", not(target_family = "wasm")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "s2n", not(target_family = "wasm")))))]
pub type AsyncStdS2nMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_quic::QuicTransport<
I,
A,
memberlist_quic::stream_layer::s2n::S2n<AsyncStdRuntime>,
memberlist_quic::stream_layer::quinn::Quinn<agnostic::async_std::AsyncStdRuntime>,
W,
AsyncStdRuntime,
agnostic::async_std::AsyncStdRuntime,
>,
D,
>;
32 changes: 8 additions & 24 deletions memberlist/src/smol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use agnostic::smol::SmolRuntime;

/// Memberlist type alias for using [`NetTransport`](memberlist_net::NetTransport) and [`Tcp`](memberlist_net::stream_layer::tcp::Tcp) stream layer with `smol` runtime.
#[cfg(all(
any(feature = "tcp", feature = "tls", feature = "native-tls"),
Expand All @@ -16,9 +14,9 @@ pub type SmolTcpMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::tcp::Tcp<SmolRuntime>,
memberlist_net::stream_layer::tcp::Tcp<agnostic::smol::SmolRuntime>,
W,
SmolRuntime,
agnostic::smol::SmolRuntime,
>,
D,
>;
Expand All @@ -30,9 +28,9 @@ pub type SmolTlsMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::tls::Tls<SmolRuntime>,
memberlist_net::stream_layer::tls::Tls<agnostic::smol::SmolRuntime>,
W,
SmolRuntime,
agnostic::smol::SmolRuntime,
>,
D,
>;
Expand All @@ -47,9 +45,9 @@ pub type SmolNativeTlsMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_net::NetTransport<
I,
A,
memberlist_net::stream_layer::native_tls::NativeTls<SmolRuntime>,
memberlist_net::stream_layer::native_tls::NativeTls<agnostic::smol::SmolRuntime>,
W,
SmolRuntime,
agnostic::smol::SmolRuntime,
>,
D,
>;
Expand All @@ -61,23 +59,9 @@ pub type SmolQuicMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_quic::QuicTransport<
I,
A,
memberlist_quic::stream_layer::quinn::Quinn<SmolRuntime>,
W,
SmolRuntime,
>,
D,
>;

/// Memberlist type alias for using [`QuicTransport`](memberlist_quic::QuicTransport) and [`S2n`](memberlist_quic::stream_layer::s2n::S2n) stream layer with `smol` runtime.
#[cfg(all(feature = "s2n", not(target_family = "wasm")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "s2n", not(target_family = "wasm")))))]
pub type SmolS2nMemberlist<I, A, W, D> = memberlist_core::Memberlist<
memberlist_quic::QuicTransport<
I,
A,
memberlist_quic::stream_layer::s2n::S2n<SmolRuntime>,
memberlist_quic::stream_layer::quinn::Quinn<agnostic::smol::SmolRuntime>,
W,
SmolRuntime,
agnostic::smol::SmolRuntime,
>,
D,
>;
5 changes: 2 additions & 3 deletions transports/net/src/io/read_from_promised.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,16 @@ where
// check if we should offload
let keys = enp.keys().await;
if encrypted_message_len <= self.opts.offload_size {
let buf = Self::decrypt(enp, encryption_algo, keys, stream_label.as_bytes(), buf)?;
let buf = Self::decrypt(encryption_algo, keys, stream_label.as_bytes(), buf)?;
let (_, msg) = W::decode_message(&buf).map_err(NetTransportError::Wire)?;
return Ok((readed, msg));
}

let (tx, rx) = futures::channel::oneshot::channel();
let enp = enp.clone();
rayon::spawn(move || {
if tx
.send(
Self::decrypt(&enp, encryption_algo, keys, stream_label.as_bytes(), buf)
Self::decrypt(encryption_algo, keys, stream_label.as_bytes(), buf)
.and_then(|b| W::decode_message(&b).map_err(NetTransportError::Wire)),
)
.is_err()
Expand Down
1 change: 1 addition & 0 deletions transports/net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ where

let mut handles = core::mem::take(&mut *self.handles.borrow_mut());
while handles.next().await.is_some() {}

Ok(())
}
}
Expand Down
25 changes: 6 additions & 19 deletions transports/net/src/packet_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,32 +356,19 @@ where
};
let keys = encryptor.keys().await;
if encrypted_message_size <= offload_size {
return Self::decrypt(
encryptor,
algo,
keys,
packet_label.as_bytes(),
&mut encrypted_message,
)
.and_then(|_| Self::read_from_packet_without_compression_and_encryption(encrypted_message));
return Self::decrypt(algo, keys, packet_label.as_bytes(), &mut encrypted_message).and_then(
|_| Self::read_from_packet_without_compression_and_encryption(encrypted_message),
);
}

let (tx, rx) = futures::channel::oneshot::channel();
let encryptor = encryptor.clone();

rayon::spawn(move || {
if tx
.send(
Self::decrypt(
&encryptor,
algo,
keys,
packet_label.as_bytes(),
&mut encrypted_message,
)
.and_then(|_| {
Self::read_from_packet_without_compression_and_encryption(encrypted_message)
}),
Self::decrypt(algo, keys, packet_label.as_bytes(), &mut encrypted_message).and_then(
|_| Self::read_from_packet_without_compression_and_encryption(encrypted_message),
),
)
.is_err()
{
Expand Down

0 comments on commit 8986f1c

Please sign in to comment.