Skip to content

Commit

Permalink
use AsyncWaitGroup instead of JoinHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n committed Apr 27, 2024
1 parent a11ad62 commit 0b43401
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 104 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ rustdoc-args = ["--cfg", "docsrs"]

[workspace.dependencies]
auto_impl = "1"
atomic_refcell = "0.1"
agnostic-lite = { version = "0.3", features = ["time"] }
agnostic = "0.3.5"
async-lock = "3"
Expand Down Expand Up @@ -57,6 +56,7 @@ transformable = { version = "0.1.6", features = ["smol_str", "bytes"] }
thiserror = "1"
tracing = "0.1"
viewit = "0.1.5"
wg = { version = "0.9", default-features = false, features = ["future", "std", "triomphe"] }

memberlist-core = { version = "0.2", path = "core", default-features = false }
memberlist-net = { version = "0.2", path = "transports/net", default-features = false }
Expand Down
4 changes: 2 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ once_cell = "1.17"
rustix = { version = "0.38", features = ["system"] }

[target.'cfg(windows)'.dependencies]
hostname = "0.3"
hostname = "0.4"

[dependencies]
auto_impl.workspace = true
atomic_refcell.workspace = true
agnostic-lite.workspace = true
async-channel.workspace = true
async-lock.workspace = true
Expand All @@ -68,6 +67,7 @@ memberlist-types.workspace = true
thiserror.workspace = true
tracing.workspace = true
viewit.workspace = true
wg.workspace = true

base64 = { version = "0.22", optional = true }

Expand Down
30 changes: 12 additions & 18 deletions core/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ use std::{
},
};

use agnostic_lite::{AsyncSpawner, RuntimeLite};
use agnostic_lite::RuntimeLite;
use async_channel::{Receiver, Sender};
use async_lock::{Mutex, RwLock};

use atomic_refcell::AtomicRefCell;
use futures::stream::FuturesUnordered;
use nodecraft::{resolver::AddressResolver, CheapClone, Node};
use wg::AsyncWaitGroup;

use super::{
awareness::Awareness,
Expand Down Expand Up @@ -284,9 +283,7 @@ where
pub(crate) leave_broadcast_tx: Sender<()>,
pub(crate) leave_lock: Mutex<()>,
pub(crate) leave_broadcast_rx: Receiver<()>,
pub(crate) handles: AtomicRefCell<
FuturesUnordered<<<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()>>,
>,
pub(crate) wg: AsyncWaitGroup,
pub(crate) probe_index: AtomicUsize,
pub(crate) handoff_tx: Sender<()>,
pub(crate) handoff_rx: Receiver<()>,
Expand Down Expand Up @@ -416,7 +413,7 @@ where
leave_lock: Mutex::new(()),
leave_broadcast_rx,
probe_index: AtomicUsize::new(0),
handles: AtomicRefCell::new(FuturesUnordered::new()),
wg: AsyncWaitGroup::new(),
handoff_tx,
handoff_rx,
queue: Mutex::new(MessageQueue::new()),
Expand All @@ -431,12 +428,11 @@ where
};

{
let handles = this.inner.handles.borrow();
handles.push(this.stream_listener(shutdown_rx.clone()));
handles.push(this.packet_handler(shutdown_rx.clone()));
handles.push(this.packet_listener(shutdown_rx.clone()));
this.stream_listener(shutdown_rx.clone());
this.packet_handler(shutdown_rx.clone());
this.packet_listener(shutdown_rx.clone());
#[cfg(feature = "metrics")]
handles.push(this.check_broadcast_queue_depth(shutdown_rx.clone()));
this.check_broadcast_queue_depth(shutdown_rx.clone());
}

Ok((shutdown_rx, this.inner.advertise.cheap_clone(), this))
Expand Down Expand Up @@ -468,16 +464,14 @@ where
}

#[cfg(feature = "metrics")]
fn check_broadcast_queue_depth(
&self,
shutdown_rx: Receiver<()>,
) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
fn check_broadcast_queue_depth(&self, shutdown_rx: Receiver<()>) {
use futures::{FutureExt, StreamExt};

let queue_check_interval = self.inner.opts.queue_check_interval;
let this = self.clone();

<T::Runtime as RuntimeLite>::spawn(async move {
let wg = this.inner.wg.add(1);
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
let tick = <T::Runtime as RuntimeLite>::interval(queue_check_interval);
futures::pin_mut!(tick);
loop {
Expand Down
13 changes: 5 additions & 8 deletions core/src/network/packet/handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::base::MessageHandoff;

use agnostic_lite::AsyncSpawner;

use super::*;

impl<D, T> Memberlist<T, D>
Expand All @@ -12,13 +10,12 @@ where
/// a long running thread that processes messages received
/// over the packet interface, but is decoupled from the listener to avoid
/// blocking the listener which may cause ping/ack messages to be delayed.
pub(crate) fn packet_handler(
&self,
shutdown_rx: async_channel::Receiver<()>,
) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
pub(crate) fn packet_handler(&self, shutdown_rx: async_channel::Receiver<()>) {
let this = self.clone();
let handoff_rx = this.inner.handoff_rx.clone();
<T::Runtime as RuntimeLite>::spawn(async move {
let wg = this.inner.wg.add(1);
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
loop {
futures::select! {
_ = shutdown_rx.recv().fuse() => {
Expand All @@ -38,7 +35,7 @@ where
}
}
}
})
});
}

/// Returns the next message to process in priority order, using LIFO
Expand Down
10 changes: 4 additions & 6 deletions core/src/network/packet/listener.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{base::MessageHandoff, transport::Wire};
use agnostic_lite::AsyncSpawner;
use either::Either;

use super::*;
Expand Down Expand Up @@ -35,13 +34,12 @@ where
D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
T: Transport,
{
pub(crate) fn packet_listener(
&self,
shutdown_rx: async_channel::Receiver<()>,
) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
pub(crate) fn packet_listener(&self, shutdown_rx: async_channel::Receiver<()>) {
let this = self.clone();
let packet_rx = this.inner.transport.packet();
<T::Runtime as RuntimeLite>::spawn(async move {
let wg = this.inner.wg.add(1);
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
'outer: loop {
futures::select! {
_ = shutdown_rx.recv().fuse() => {
Expand Down
12 changes: 5 additions & 7 deletions core/src/network/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::Arc;

use agnostic_lite::AsyncSpawner;
use smol_str::SmolStr;

use crate::delegate::DelegateError;
Expand All @@ -15,13 +14,12 @@ where
{
/// A long running thread that pulls incoming streams from the
/// transport and hands them off for processing.
pub(crate) fn stream_listener(
&self,
shutdown_rx: async_channel::Receiver<()>,
) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
pub(crate) fn stream_listener(&self, shutdown_rx: async_channel::Receiver<()>) {
let this = self.clone();
let transport_rx = this.inner.transport.stream();
<T::Runtime as RuntimeLite>::spawn(async move {
let wg = this.inner.wg.add(1);
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
tracing::debug!("memberlist: stream listener start");
loop {
futures::select! {
Expand Down Expand Up @@ -49,7 +47,7 @@ where
}
}
}
})
});
}

/// Used to merge the remote state with our local state
Expand Down
56 changes: 26 additions & 30 deletions core/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::{
Member, Members,
};

use agnostic_lite::{AsyncSpawner, RuntimeLite};
use agnostic_lite::RuntimeLite;

use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use nodecraft::{resolver::AddressResolver, CheapClone, Node};
Expand Down Expand Up @@ -635,12 +635,14 @@ where
macro_rules! bail_trigger {
($fn:ident) => {
paste::paste! {
async fn [<trigger _ $fn>](&self, stagger: Duration, interval: Duration, stop_rx: async_channel::Receiver<()>) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()>
async fn [<trigger _ $fn>](&self, stagger: Duration, interval: Duration, stop_rx: async_channel::Receiver<()>)
{
let this = self.clone();
// Use a random stagger to avoid syncronizing
let rand_stagger = random_stagger(stagger);
<T::Runtime as RuntimeLite>::spawn(async move {
let wg = this.inner.wg.add(1);
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
let delay = <T::Runtime as RuntimeLite>::sleep(rand_stagger);

futures::select! {
Expand All @@ -667,7 +669,7 @@ macro_rules! bail_trigger {
}

tracing::debug!(concat!("memberlist.state: ", stringify!($fn), " trigger exits"));
})
});
}
}
};
Expand All @@ -680,54 +682,48 @@ where
{
/// Used to ensure the Tick is performed periodically.
pub(crate) async fn schedule(&self, shutdown_rx: async_channel::Receiver<()>) {
let handles = self.inner.handles.borrow();
// Create a new probeTicker
if self.inner.opts.probe_interval > Duration::ZERO {
handles.push(
self
.trigger_probe(
self.inner.opts.probe_interval,
self.inner.opts.probe_interval,
shutdown_rx.clone(),
)
.await,
);
self
.trigger_probe(
self.inner.opts.probe_interval,
self.inner.opts.probe_interval,
shutdown_rx.clone(),
)
.await;
}

// Create a push pull ticker if needed
if self.inner.opts.push_pull_interval > Duration::ZERO {
handles.push(self.trigger_push_pull(shutdown_rx.clone()).await);
self.trigger_push_pull(shutdown_rx.clone()).await;
}

// Create a gossip ticker if needed
if self.inner.opts.gossip_interval > Duration::ZERO && self.inner.opts.gossip_nodes > 0 {
handles.push(
self
.trigger_gossip(
self.inner.opts.gossip_interval,
self.inner.opts.gossip_interval,
shutdown_rx.clone(),
)
.await,
);
self
.trigger_gossip(
self.inner.opts.gossip_interval,
self.inner.opts.gossip_interval,
shutdown_rx.clone(),
)
.await;
}
}

bail_trigger!(probe);

bail_trigger!(gossip);

async fn trigger_push_pull(
&self,
stop_rx: async_channel::Receiver<()>,
) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
async fn trigger_push_pull(&self, stop_rx: async_channel::Receiver<()>) {
let interval = self.inner.opts.push_pull_interval;
let this = self.clone();
let wg = this.inner.wg.add(1);
// Use a random stagger to avoid syncronizing
let mut rng = rand::thread_rng();
let rand_stagger = Duration::from_millis(rng.gen_range(0..interval.as_millis() as u64));

<T::Runtime as RuntimeLite>::spawn(async move {
<T::Runtime as RuntimeLite>::spawn_detach(async move {
scopeguard::defer!(wg.done(););
futures::select! {
_ = <T::Runtime as RuntimeLite>::sleep(rand_stagger).fuse() => {},
_ = stop_rx.recv().fuse() => {
Expand All @@ -750,7 +746,7 @@ where
},
}
}
})
});
}

// Used to perform a single round of failure detection and gossip
Expand Down
4 changes: 2 additions & 2 deletions transports/net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ dnssec = ["dns", "nodecraft/dnssec"]
getrandom = { version = "0.2", features = ["js"] }

[dependencies]
atomic_refcell.workspace = true
agnostic.workspace = true
async-channel.workspace = true
async-lock.workspace = true
Expand All @@ -83,9 +82,10 @@ memberlist-core.workspace = true
thiserror.workspace = true
tracing.workspace = true
viewit.workspace = true
wg.workspace = true

# tls
futures-rustls = { version = "0.25", optional = true }
futures-rustls = { version = "0.26", optional = true }

# native-tls
async-native-tls = { version = "0.5", optional = true }
Expand Down
Loading

0 comments on commit 0b43401

Please sign in to comment.