Skip to content

Commit

Permalink
default_lbp: replace boilerplate with from_iter()
Browse files Browse the repository at this point in the history
As there is a brilliant `std::iter::from_iter()` function that creates a
new iterator based on a closure, it can be used instead of verbose
boilerplate incurred by introducing IteratorWithSkippedNodes.
  • Loading branch information
wprzytula authored and Lorak-mmk committed Mar 27, 2024
1 parent 9156191 commit 53801ae
Showing 1 changed file with 32 additions and 73 deletions.
105 changes: 32 additions & 73 deletions scylla/src/transport/load_balancing/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2455,14 +2455,38 @@ mod latency_awareness {
None => return Either::Left(fallback), // noop, as no latency data has been collected yet
};

Either::Right(IteratorWithSkippedNodes::new(
self.node_avgs.read().unwrap().deref(),
fallback,
self.exclusion_threshold,
self.retry_period,
self.minimum_measurements,
min_avg_latency,
))
let average_latencies = self.node_avgs.read().unwrap();
let targets = fallback;

let mut fast_targets = vec![];
let mut penalised_targets = vec![];

for node_and_shard @ (node, _shard) in targets {
match fast_enough(
average_latencies.deref(),
node.host_id,
self.exclusion_threshold,
self.retry_period,
self.minimum_measurements,
min_avg_latency,
) {
FastEnough::Yes => fast_targets.push(node_and_shard),
FastEnough::No { average } => {
trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).",
node.address, node.datacenter, node.rack, self.exclusion_threshold, average.as_millis(), min_avg_latency.as_millis());
penalised_targets.push(node_and_shard);
}
}
}

let mut fast_targets = fast_targets.into_iter();
let mut penalised_targets = penalised_targets.into_iter();

let skipping_penalised_targets_iterator = std::iter::from_fn(move || {
fast_targets.next().or_else(|| penalised_targets.next())
});

Either::Right(skipping_penalised_targets_iterator)
}

pub(super) fn report_query(&self, node: &Node, latency: Duration) {
Expand Down Expand Up @@ -2768,71 +2792,6 @@ mod latency_awareness {
}
}

struct IteratorWithSkippedNodes<'a, Fast, Penalised>
where
Fast: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
Penalised: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
{
fast_nodes: Fast,
penalised_nodes: Penalised,
}

impl<'a>
IteratorWithSkippedNodes<
'a,
std::vec::IntoIter<(NodeRef<'a>, Option<Shard>)>,
std::vec::IntoIter<(NodeRef<'a>, Option<Shard>)>,
>
{
fn new(
average_latencies: &HashMap<Uuid, RwLock<Option<TimestampedAverage>>>,
nodes: impl Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
exclusion_threshold: f64,
retry_period: Duration,
minimum_measurements: usize,
min_avg: Duration,
) -> Self {
let mut fast_nodes = vec![];
let mut penalised_nodes = vec![];

for node_and_shard @ (node, _shard) in nodes {
match fast_enough(
average_latencies,
node.host_id,
exclusion_threshold,
retry_period,
minimum_measurements,
min_avg,
) {
FastEnough::Yes => fast_nodes.push(node_and_shard),
FastEnough::No { average } => {
trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).",
node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis());
penalised_nodes.push(node_and_shard);
}
}
}

Self {
fast_nodes: fast_nodes.into_iter(),
penalised_nodes: penalised_nodes.into_iter(),
}
}
}

impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised>
where
Fast: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
Penalised: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
{
type Item = (NodeRef<'a>, Option<Shard>);

fn next(&mut self) -> Option<Self::Item> {
self.fast_nodes
.next()
.or_else(|| self.penalised_nodes.next())
}
}
#[cfg(test)]
mod tests {
use scylla_cql::Consistency;
Expand Down

0 comments on commit 53801ae

Please sign in to comment.