diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index e2474cdd7a..d8c6dff9cc 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2455,14 +2455,38 @@ mod latency_awareness { None => return Either::Left(fallback), // noop, as no latency data has been collected yet }; - Either::Right(IteratorWithSkippedNodes::new( - self.node_avgs.read().unwrap().deref(), - fallback, - self.exclusion_threshold, - self.retry_period, - self.minimum_measurements, - min_avg_latency, - )) + let average_latencies = self.node_avgs.read().unwrap(); + let targets = fallback; + + let mut fast_targets = vec![]; + let mut penalised_targets = vec![]; + + for node_and_shard @ (node, _shard) in targets { + match fast_enough( + average_latencies.deref(), + node.host_id, + self.exclusion_threshold, + self.retry_period, + self.minimum_measurements, + min_avg_latency, + ) { + FastEnough::Yes => fast_targets.push(node_and_shard), + FastEnough::No { average } => { + trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", + node.address, node.datacenter, node.rack, self.exclusion_threshold, average.as_millis(), min_avg_latency.as_millis()); + penalised_targets.push(node_and_shard); + } + } + } + + let mut fast_targets = fast_targets.into_iter(); + let mut penalised_targets = penalised_targets.into_iter(); + + let skipping_penalised_targets_iterator = std::iter::from_fn(move || { + fast_targets.next().or_else(|| penalised_targets.next()) + }); + + Either::Right(skipping_penalised_targets_iterator) } pub(super) fn report_query(&self, node: &Node, latency: Duration) { @@ -2768,71 +2792,6 @@ mod latency_awareness { } } - struct IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Option)>, - Penalised: Iterator, Option)>, - { - fast_nodes: Fast, - penalised_nodes: Penalised, - } - - impl<'a> - IteratorWithSkippedNodes< - 'a, - std::vec::IntoIter<(NodeRef<'a>, Option)>, - std::vec::IntoIter<(NodeRef<'a>, Option)>, - > - { - fn new( - average_latencies: &HashMap>>, - nodes: impl Iterator, Option)>, - exclusion_threshold: f64, - retry_period: Duration, - minimum_measurements: usize, - min_avg: Duration, - ) -> Self { - let mut fast_nodes = vec![]; - let mut penalised_nodes = vec![]; - - for node_and_shard @ (node, _shard) in nodes { - match fast_enough( - average_latencies, - node.host_id, - exclusion_threshold, - retry_period, - minimum_measurements, - min_avg, - ) { - FastEnough::Yes => fast_nodes.push(node_and_shard), - FastEnough::No { average } => { - trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", - node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis()); - penalised_nodes.push(node_and_shard); - } - } - } - - Self { - fast_nodes: fast_nodes.into_iter(), - penalised_nodes: penalised_nodes.into_iter(), - } - } - } - - impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Option)>, - Penalised: Iterator, Option)>, - { - type Item = (NodeRef<'a>, Option); - - fn next(&mut self) -> Option { - self.fast_nodes - .next() - .or_else(|| self.penalised_nodes.next()) - } - } #[cfg(test)] mod tests { use scylla_cql::Consistency;