Skip to content

Commit

Permalink
refactor: use bounded channel
Browse files Browse the repository at this point in the history
  • Loading branch information
WenyXu committed May 30, 2024
1 parent 6d1f8c9 commit d6494f5
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/mito2/src/wal/entry_distributor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use snafu::ensure;
use store_api::logstore::entry::Entry;
use store_api::logstore::provider::Provider;
use store_api::storage::RegionId;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio_stream::StreamExt;

Expand All @@ -38,7 +38,7 @@ pub(crate) struct WalEntryDistributor {
raw_wal_reader: Arc<dyn RawEntryReader>,
provider: Provider,
/// Sends [Entry] to receivers based on [RegionId]
senders: HashMap<RegionId, UnboundedSender<Entry>>,
senders: HashMap<RegionId, Sender<Entry>>,
/// Waits for the arg from the [WalEntryReader].
arg_receivers: Vec<(RegionId, oneshot::Receiver<EntryId>)>,
}
Expand Down Expand Up @@ -75,7 +75,7 @@ impl WalEntryDistributor {
for (region_id, start_id) in args {
subscribers.insert(
region_id,
Receiver {
EntryReceiver {
start_id,
sender: self.senders[&region_id].clone(),
},
Expand All @@ -89,9 +89,9 @@ impl WalEntryDistributor {
let entry_id = entry.entry_id();
let region_id = entry.region_id();

if let Some(Receiver { sender, start_id }) = subscribers.get(&region_id) {
if let Some(EntryReceiver { sender, start_id }) = subscribers.get(&region_id) {
if entry_id >= *start_id {
if let Err(err) = sender.send(entry) {
if let Err(err) = sender.send(entry).await {
error!(err; "Failed to distribute raw entry, entry_id:{}, region_id: {}", entry_id, region_id);
}
}
Expand All @@ -109,15 +109,15 @@ impl WalEntryDistributor {
pub(crate) struct WalEntryReceiver {
region_id: RegionId,
/// Receives the [Entry] from the [WalEntryDistributor].
entry_receiver: UnboundedReceiver<Entry>,
entry_receiver: Receiver<Entry>,
/// Sends the `start_id` to the [WalEntryDistributor].
arg_sender: oneshot::Sender<EntryId>,
}

impl WalEntryReceiver {
pub fn new(
region_id: RegionId,
entry_receiver: UnboundedReceiver<Entry>,
entry_receiver: Receiver<Entry>,
arg_sender: oneshot::Sender<EntryId>,
) -> Self {
Self {
Expand Down Expand Up @@ -170,9 +170,9 @@ impl WalEntryReader for WalEntryReceiver {
}
}

struct Receiver {
struct EntryReceiver {
start_id: EntryId,
sender: UnboundedSender<Entry>,
sender: Sender<Entry>,
}

/// Returns [WalEntryDistributor] and batch [WalEntryReceiver]s.
Expand All @@ -195,13 +195,14 @@ pub fn build_wal_entry_distributor_and_receivers(
provider: Provider,
raw_wal_reader: Arc<dyn RawEntryReader>,
region_ids: Vec<RegionId>,
buffer_size: usize,
) -> (WalEntryDistributor, Vec<WalEntryReceiver>) {
let mut senders = HashMap::with_capacity(region_ids.len());
let mut readers = Vec::with_capacity(region_ids.len());
let mut arg_receivers = Vec::with_capacity(region_ids.len());

for region_id in region_ids {
let (entry_sender, entry_receiver) = mpsc::unbounded_channel();
let (entry_sender, entry_receiver) = mpsc::channel(buffer_size);
let (arg_sender, arg_receiver) = oneshot::channel();

senders.insert(region_id, entry_sender);
Expand Down Expand Up @@ -266,6 +267,7 @@ mod tests {
provider,
reader,
vec![RegionId::new(1024, 1), RegionId::new(1025, 1)],
128,
);

// Drops all receivers
Expand Down Expand Up @@ -329,6 +331,7 @@ mod tests {
RegionId::new(1024, 2),
RegionId::new(1024, 3),
],
128,
);
assert_eq!(receivers.len(), 3);

Expand Down Expand Up @@ -434,6 +437,7 @@ mod tests {
provider.clone(),
Arc::new(corrupted_stream),
vec![region1, region2, region3],
128,
);
assert_eq!(receivers.len(), 3);
let mut streams = receivers
Expand Down Expand Up @@ -516,6 +520,7 @@ mod tests {
provider.clone(),
Arc::new(corrupted_stream),
vec![region1, region2],
128,
);
assert_eq!(receivers.len(), 2);
let mut streams = receivers
Expand Down Expand Up @@ -607,6 +612,7 @@ mod tests {
provider.clone(),
reader,
vec![RegionId::new(1024, 1), RegionId::new(1024, 2)],
128,
);
assert_eq!(receivers.len(), 2);
let mut streams = receivers
Expand Down

0 comments on commit d6494f5

Please sign in to comment.