From dbc52fa730224b89038e3e1d4ef12973e88da8d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Karol=20Bary=C5=82a?=
Date: Fri, 1 Mar 2024 16:49:56 +0100
Subject: [PATCH] Make LoadBalancingPolicy shard-aware
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Wojciech Przytuła
---
 examples/Cargo.toml                           |   1 +
 examples/custom_load_balancing_policy.rs      |  23 +-
 .../src/transport/load_balancing/default.rs   | 212 +++++++++++-------
 scylla/src/transport/load_balancing/mod.rs    |  10 +-
 scylla/src/transport/load_balancing/plan.rs   |  47 ++--
 scylla/tests/integration/consistency.rs       |   4 +-
 .../tests/integration/execution_profiles.rs   |   9 +-
 scylla/tests/integration/utils.rs             |  15 +-
 8 files changed, 206 insertions(+), 115 deletions(-)

diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 467963f93f..b068ee9e3c 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -20,6 +20,7 @@ uuid = "1.0"
 tower = "0.4"
 stats_alloc = "0.1"
 clap = { version = "3.2.4", features = ["derive"] }
+rand = "0.8.5"
 
 [[example]]
 name = "auth"
diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs
index 523d093efe..fb1ae0cb7c 100644
--- a/examples/custom_load_balancing_policy.rs
+++ b/examples/custom_load_balancing_policy.rs
@@ -1,20 +1,37 @@
 use anyhow::Result;
+use rand::thread_rng;
+use rand::Rng;
 use scylla::transport::NodeRef;
 use scylla::{
     load_balancing::{LoadBalancingPolicy, RoutingInfo},
+    routing::Shard,
     transport::{ClusterData, ExecutionProfile},
     Session, SessionBuilder,
 };
 use std::{env, sync::Arc};
 
 /// Example load balancing policy that prefers nodes from favorite datacenter
+/// This is, of course, very naive, as it is completely non-token-aware.
+/// For a more realistic implementation, see [`DefaultPolicy`](scylla::load_balancing::DefaultPolicy).
 #[derive(Debug)]
 struct CustomLoadBalancingPolicy {
     fav_datacenter_name: String,
 }
 
+fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
+    let nr_shards = node
+        .sharder()
+        .map(|sharder| sharder.nr_shards.get())
+        .unwrap_or(1);
+    (node, thread_rng().gen_range(0..nr_shards) as Shard)
+}
+
 impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
-    fn pick<'a>(&'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
+    fn pick<'a>(
+        &'a self,
+        _info: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Shard)> {
         self.fallback(_info, cluster).next()
     }
 
@@ -28,9 +45,9 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
             .unique_nodes_in_datacenter_ring(&self.fav_datacenter_name);
 
         match fav_dc_nodes {
-            Some(nodes) => Box::new(nodes.iter()),
+            Some(nodes) => Box::new(nodes.iter().map(with_random_shard)),
             // If there is no dc with provided name, fallback to other datacenters
-            None => Box::new(cluster.get_nodes_info().iter()),
+            None => Box::new(cluster.get_nodes_info().iter().map(with_random_shard)),
         }
     }
diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs
index a03b7efd56..25460ee636 100644
--- a/scylla/src/transport/load_balancing/default.rs
+++ b/scylla/src/transport/load_balancing/default.rs
@@ -3,7 +3,7 @@
 pub use self::latency_awareness::LatencyAwarenessBuilder;
 use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
 use crate::{
-    routing::Token,
+    routing::{Shard, Token},
     transport::{cluster::ClusterData, locator::ReplicaSet, node::Node, topology::Strategy},
 };
 use itertools::{Either, Itertools};
@@ -70,13 +70,14 @@ enum StatementType {
 /// It can be configured to be datacenter-aware and token-aware.
 /// Datacenter failover for queries with non-local consistency mode is also supported.
 /// Latency awareness is available, although not recommended.
+#[allow(clippy::type_complexity)]
 pub struct DefaultPolicy {
     preferences: NodeLocationPreference,
     is_token_aware: bool,
     permit_dc_failover: bool,
-    pick_predicate: Box<dyn Fn(&NodeRef) -> bool + Send + Sync>,
+    pick_predicate: Box<dyn Fn(&(NodeRef, Shard)) -> bool + Send + Sync>,
     latency_awareness: Option<LatencyAwareness>,
-    fixed_shuffle_seed: Option<u64>,
+    fixed_seed: Option<u64>,
 }
 
 impl fmt::Debug for DefaultPolicy {
@@ -86,13 +87,17 @@ impl fmt::Debug for DefaultPolicy {
             .field("is_token_aware", &self.is_token_aware)
             .field("permit_dc_failover", &self.permit_dc_failover)
             .field("latency_awareness", &self.latency_awareness)
-            .field("fixed_shuffle_seed", &self.fixed_shuffle_seed)
+            .field("fixed_shuffle_seed", &self.fixed_seed)
             .finish_non_exhaustive()
     }
 }
 
 impl LoadBalancingPolicy for DefaultPolicy {
-    fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
+    fn pick<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Shard)> {
         let routing_info = self.routing_info(query, cluster);
         if let Some(ref token_with_strategy) = routing_info.token_with_strategy {
             if self.preferences.datacenter().is_some()
@@ -177,7 +182,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                     &self.pick_predicate,
                     NodeLocationCriteria::DatacenterAndRack(dc, rack),
                 );
-                let local_rack_picked = Self::pick_node(nodes, rack_predicate);
+                let local_rack_picked = self.pick_node(nodes, rack_predicate);
 
                 if let Some(alive_local_rack) = local_rack_picked {
                     return Some(alive_local_rack);
@@ -185,14 +190,14 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
         }
 
         // Try to pick some alive local random node.
-        if let Some(alive_local) = Self::pick_node(nodes, &self.pick_predicate) {
+        if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) {
             return Some(alive_local);
         }
 
         let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring();
         // If a datacenter failover is possible, loosen restriction about locality.
         if self.is_datacenter_failover_possible(&routing_info) {
-            let picked = Self::pick_node(all_nodes, &self.pick_predicate);
+            let picked = self.pick_node(all_nodes, &self.pick_predicate);
             if let Some(alive_maybe_remote) = picked {
                 return Some(alive_maybe_remote);
             }
@@ -200,21 +205,21 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
 
         // Previous checks imply that every node we could have selected is down.
         // Let's try to return a down node that wasn't disabled.
-        let picked = Self::pick_node(nodes, |node| node.is_enabled());
+        let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled());
        if let Some(down_but_enabled_local_node) = picked {
            return Some(down_but_enabled_local_node);
        }
 
        // If a datacenter failover is possible, loosen restriction about locality.
        if self.is_datacenter_failover_possible(&routing_info) {
-            let picked = Self::pick_node(all_nodes, |node| node.is_enabled());
+            let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled());
            if let Some(down_but_enabled_maybe_remote_node) = picked {
                return Some(down_but_enabled_maybe_remote_node);
            }
        }
 
        // Every node is disabled. This could be due to a bad host filter - configuration error.
-        nodes.first()
+        nodes.first().map(|node| self.with_random_shard(node))
     }
 
     fn fallback<'a>(
@@ -285,7 +290,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                     .chain(maybe_remote_replicas),
             )
         } else {
-            Either::Right(std::iter::empty::<NodeRef<'a>>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
         };
 
         // Get a list of all local alive nodes, and apply a round robin to it
@@ -297,29 +302,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 &self.pick_predicate,
                 NodeLocationCriteria::DatacenterAndRack(dc, rack),
             );
-            Either::Left(Self::round_robin_nodes(local_nodes, rack_predicate))
+            Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate))
         } else {
-            Either::Right(std::iter::empty::<NodeRef<'a>>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
         };
 
-        let robined_local_nodes = Self::round_robin_nodes(local_nodes, Self::is_alive);
+        let robined_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive);
 
         let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring();
 
         // If a datacenter failover is possible, loosen restriction about locality.
         let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) {
-            let robined_all_nodes = Self::round_robin_nodes(all_nodes, Self::is_alive);
+            let robined_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive);
 
             Either::Left(robined_all_nodes)
         } else {
-            Either::Right(std::iter::empty::<NodeRef<'a>>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
         };
 
         // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort.
-        let maybe_down_local_nodes = local_nodes.iter().filter(|node| node.is_enabled());
+        let maybe_down_local_nodes = local_nodes
+            .iter()
+            .filter(|node| node.is_enabled())
+            .map(|node| self.with_random_shard(node));
 
         // If a datacenter failover is possible, loosen restriction about locality.
         let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) {
-            Either::Left(all_nodes.iter().filter(|node| node.is_enabled()))
+            Either::Left(
+                all_nodes
+                    .iter()
+                    .filter(|node| node.is_enabled())
+                    .map(|node| self.with_random_shard(node)),
+            )
         } else {
             Either::Right(std::iter::empty())
         };
@@ -420,13 +433,15 @@ impl DefaultPolicy {
 
     /// Wraps the provided predicate, adding the requirement for rack to match.
     fn make_rack_predicate<'a>(
-        predicate: impl Fn(&NodeRef<'a>) -> bool + 'a,
+        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a,
         replica_location: NodeLocationCriteria<'a>,
-    ) -> impl Fn(&NodeRef<'a>) -> bool {
-        move |node| match replica_location {
-            NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node),
+    ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool {
+        move |node_and_shard @ (node, _shard)| match replica_location {
+            NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => {
+                predicate(node_and_shard)
+            }
             NodeLocationCriteria::DatacenterAndRack(_, rack) => {
-                predicate(node) && node.rack.as_deref() == Some(rack)
+                predicate(node_and_shard) && node.rack.as_deref() == Some(rack)
             }
         }
     }
@@ -435,10 +450,10 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: impl Fn(&NodeRef<'a>) -> bool + 'a,
+        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a,
         cluster: &'a ClusterData,
         order: ReplicaOrder,
-    ) -> impl Iterator<Item = NodeRef<'a>> {
+    ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
         let predicate = Self::make_rack_predicate(predicate, replica_location);
 
         let replica_iter = match order {
@@ -452,19 +467,17 @@ impl DefaultPolicy {
                     .into_iter(),
             ),
         };
-        replica_iter
-            .filter(move |(node, _shard)| predicate(node))
-            .map(|(node, _shard)| node)
+        replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard))
     }
 
     fn pick_replica<'a>(
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&NodeRef<'a>) -> bool,
+        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
         cluster: &'a ClusterData,
         statement_type: StatementType,
-    ) -> Option<NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         match statement_type {
             StatementType::Lwt => self.pick_first_replica(ts, replica_location, predicate, cluster),
             StatementType::NonLwt => {
@@ -489,9 +502,9 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&NodeRef<'a>) -> bool,
+        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
         cluster: &'a ClusterData,
-    ) -> Option<NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         match replica_location {
             NodeLocationCriteria::Any => {
                 // ReplicaSet returned by ReplicaLocator for this case:
@@ -508,7 +521,7 @@ impl DefaultPolicy {
                     .into_replicas_ordered()
                     .into_iter()
                     .next()
-                    .and_then(|(primary_replica, _shard)| {
+                    .and_then(|primary_replica| {
                         predicate(&primary_replica).then_some(primary_replica)
                     })
             }
@@ -521,7 +534,7 @@ impl DefaultPolicy {
                 self.replicas(
                     ts,
                     replica_location,
-                    move |node| predicate(node),
+                    move |node_and_shard| predicate(node_and_shard),
                     cluster,
                     ReplicaOrder::RingOrder,
                 )
@@ -534,22 +547,18 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&NodeRef<'a>) -> bool,
+        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
         cluster: &'a ClusterData,
-    ) -> Option<NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         let predicate = Self::make_rack_predicate(predicate, replica_location);
 
         let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster);
 
-        if let Some(fixed) = self.fixed_shuffle_seed {
+        if let Some(fixed) = self.fixed_seed {
             let mut gen = Pcg32::new(fixed, 0);
-            replica_set
-                .choose_filtered(&mut gen, |(node, _shard)| predicate(node))
-                .map(|(node, _shard)| node)
+            replica_set.choose_filtered(&mut gen, predicate)
         } else {
-            replica_set
-                .choose_filtered(&mut thread_rng(), |(node, _shard)| predicate(node))
-                .map(|(node, _shard)| node)
+            replica_set.choose_filtered(&mut thread_rng(), predicate)
         }
     }
@@ -557,10 +566,10 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: impl Fn(&NodeRef<'a>) -> bool + 'a,
+        predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a,
         cluster: &'a ClusterData,
         statement_type: StatementType,
-    ) -> impl Iterator<Item = NodeRef<'a>> {
+    ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
         let order = match statement_type {
             StatementType::Lwt => ReplicaOrder::RingOrder,
             StatementType::NonLwt => ReplicaOrder::Arbitrary,
@@ -593,27 +602,33 @@ impl DefaultPolicy {
     }
 
     fn pick_node<'a>(
+        &'a self,
         nodes: &'a [Arc<Node>],
-        predicate: impl Fn(&NodeRef<'a>) -> bool,
-    ) -> Option<NodeRef<'a>> {
+        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool,
+    ) -> Option<(NodeRef<'_>, Shard)> {
         // Select the first node that matches the predicate
-        Self::randomly_rotated_nodes(nodes).find(predicate)
+        Self::randomly_rotated_nodes(nodes)
+            .map(|node| self.with_random_shard(node))
+            .find(predicate)
     }
 
-    fn round_robin_nodes<'a>(
+    fn round_robin_nodes_with_shards<'a>(
+        &'a self,
         nodes: &'a [Arc<Node>],
-        predicate: impl Fn(&NodeRef<'a>) -> bool,
-    ) -> impl Iterator<Item = NodeRef<'a>> {
-        Self::randomly_rotated_nodes(nodes).filter(predicate)
+        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool,
+    ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
+        Self::randomly_rotated_nodes(nodes)
+            .map(|node| self.with_random_shard(node))
+            .filter(predicate)
     }
 
     fn shuffle<'a>(
         &self,
-        iter: impl Iterator<Item = NodeRef<'a>>,
-    ) -> impl Iterator<Item = NodeRef<'a>> {
-        let mut vec: Vec<NodeRef<'a>> = iter.collect();
+        iter: impl Iterator<Item = (NodeRef<'a>, Shard)>,
+    ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
+        let mut vec: Vec<(NodeRef<'_>, Shard)> = iter.collect();
 
-        if let Some(fixed) = self.fixed_shuffle_seed {
+        if let Some(fixed) = self.fixed_seed {
             let mut gen = Pcg32::new(fixed, 0);
             vec.shuffle(&mut gen);
         } else {
@@ -623,7 +638,23 @@ impl DefaultPolicy {
         vec.into_iter()
     }
 
-    fn is_alive(node: &NodeRef<'_>) -> bool {
+    fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) {
+        let nr_shards = node
+            .sharder()
+            .map(|sharder| sharder.nr_shards.get())
+            .unwrap_or(1);
+        (
+            node,
+            (if let Some(fixed) = self.fixed_seed {
+                let mut gen = Pcg32::new(fixed, 0);
+                gen.gen_range(0..nr_shards)
+            } else {
+                thread_rng().gen_range(0..nr_shards)
+            }) as Shard,
+        )
+    }
+
+    fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool {
         // For now, we leave this as a stub, until we have time to improve node events.
        // node.is_enabled() && !node.is_down()
        node.is_enabled()
     }
@@ -644,7 +675,7 @@ impl Default for DefaultPolicy {
             permit_dc_failover: false,
             pick_predicate: Box::new(Self::is_alive),
             latency_awareness: None,
-            fixed_shuffle_seed: None,
+            fixed_seed: None,
         }
     }
 }
@@ -689,8 +720,11 @@ impl DefaultPolicyBuilder {
         let latency_awareness = self.latency_awareness.map(|builder| builder.build());
         let pick_predicate = if let Some(ref latency_awareness) = latency_awareness {
             let latency_predicate = latency_awareness.generate_predicate();
-            Box::new(move |node: &NodeRef| DefaultPolicy::is_alive(node) && latency_predicate(node))
-                as Box<dyn Fn(&NodeRef) -> bool + Send + Sync + 'static>
+            Box::new(
+                move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| {
+                    DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node)
+                },
+            ) as Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync + 'static>
         } else {
             Box::new(DefaultPolicy::is_alive)
         };
@@ -701,7 +735,7 @@ impl DefaultPolicyBuilder {
             permit_dc_failover: self.permit_dc_failover,
             pick_predicate,
             latency_awareness,
-            fixed_shuffle_seed: (!self.enable_replica_shuffle).then(rand::random),
+            fixed_seed: (!self.enable_replica_shuffle).then(rand::random),
         })
     }
@@ -1245,7 +1279,7 @@ mod tests {
                 preferences: NodeLocationPreference::Datacenter("eu".to_owned()),
                 is_token_aware: true,
                 permit_dc_failover: true,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -1338,7 +1372,7 @@ mod tests {
                 preferences: NodeLocationPreference::Datacenter("eu".to_owned()),
                 is_token_aware: true,
                 permit_dc_failover: true,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -1584,7 +1618,7 @@ mod tests {
                 ),
                 is_token_aware: true,
                 permit_dc_failover: false,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -1716,7 +1750,7 @@ mod tests {
                 preferences: NodeLocationPreference::Datacenter("eu".to_owned()),
                 is_token_aware: true,
                 permit_dc_failover: true,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -1813,7 +1847,7 @@ mod tests {
                 preferences: NodeLocationPreference::Datacenter("eu".to_owned()),
                 is_token_aware: true,
                 permit_dc_failover: true,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -2070,7 +2104,7 @@ mod tests {
                 ),
                 is_token_aware: true,
                 permit_dc_failover: false,
-                fixed_shuffle_seed: Some(123),
+                fixed_seed: Some(123),
                 ..Default::default()
             },
             routing_info: RoutingInfo {
@@ -2142,7 +2176,7 @@ mod latency_awareness {
     use tracing::{instrument::WithSubscriber, trace, warn};
     use uuid::Uuid;
 
-    use crate::{load_balancing::NodeRef, transport::node::Node};
+    use crate::{load_balancing::NodeRef, routing::Shard, transport::node::Node};
     use std::{
         collections::HashMap,
         ops::Deref,
@@ -2353,8 +2387,8 @@ mod latency_awareness {
 
         pub(super) fn wrap<'a>(
             &self,
-            fallback: impl Iterator<Item = NodeRef<'a>>,
-        ) -> impl Iterator<Item = NodeRef<'a>> {
+            fallback: impl Iterator<Item = (NodeRef<'a>, Shard)>,
+        ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
             let min_avg_latency = match self.last_min_latency.load() {
                 Some(min_avg) => min_avg,
                 None => return Either::Left(fallback), // noop, as no latency data has been collected yet
@@ -2675,8 +2709,8 @@ mod latency_awareness {
 
     struct IteratorWithSkippedNodes<'a, Fast, Penalised>
     where
-        Fast: Iterator<Item = NodeRef<'a>>,
-        Penalised: Iterator<Item = NodeRef<'a>>,
+        Fast: Iterator<Item = (NodeRef<'a>, Shard)>,
+        Penalised: Iterator<Item = (NodeRef<'a>, Shard)>,
     {
        fast_nodes: Fast,
        penalised_nodes: Penalised,
@@ -2685,13 +2719,13 @@ mod latency_awareness {
     impl<'a>
         IteratorWithSkippedNodes<
             'a,
-            std::vec::IntoIter<NodeRef<'a>>,
-            std::vec::IntoIter<NodeRef<'a>>,
+            std::vec::IntoIter<(NodeRef<'a>, Shard)>,
+            std::vec::IntoIter<(NodeRef<'a>, Shard)>,
         >
     {
         fn new(
             average_latencies: &HashMap<Uuid, RwLock<TimestampedAverage>>,
-            nodes: impl Iterator<Item = NodeRef<'a>>,
+            nodes: impl Iterator<Item = (NodeRef<'a>, Shard)>,
             exclusion_threshold: f64,
             retry_period: Duration,
             minimum_measurements: usize,
@@ -2700,7 +2734,7 @@ mod latency_awareness {
             let mut fast_nodes = vec![];
             let mut penalised_nodes = vec![];
 
-            for node in nodes {
+            for node_and_shard @ (node, _shard) in nodes {
                 match fast_enough(
                     average_latencies,
                     node.host_id,
@@ -2709,11 +2743,11 @@ mod latency_awareness {
                     minimum_measurements,
                     min_avg,
                 ) {
-                    FastEnough::Yes => fast_nodes.push(node),
+                    FastEnough::Yes => fast_nodes.push(node_and_shard),
                     FastEnough::No { average } => {
                         trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis());
-                        penalised_nodes.push(node);
+                        penalised_nodes.push(node_and_shard);
                     }
                 }
             }
@@ -2727,10 +2761,10 @@ mod latency_awareness {
 
     impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised>
     where
-        Fast: Iterator<Item = NodeRef<'a>>,
-        Penalised: Iterator<Item = NodeRef<'a>>,
+        Fast: Iterator<Item = (NodeRef<'a>, Shard)>,
+        Penalised: Iterator<Item = (NodeRef<'a>, Shard)>,
    {
-        type Item = &'a Arc<Node>;
+        type Item = (NodeRef<'a>, Shard);
 
         fn next(&mut self) -> Option<Self::Item> {
             self.fast_nodes
@@ -2749,7 +2783,8 @@ mod latency_awareness {
     };
 
     use crate::{
-        load_balancing::default::NodeLocationPreference, test_utils::create_new_session_builder,
+        load_balancing::default::NodeLocationPreference, routing::Shard,
+        test_utils::create_new_session_builder,
     };
     use crate::{
         load_balancing::{
@@ -2810,9 +2845,12 @@ mod latency_awareness {
        ) -> DefaultPolicy {
            let pick_predicate = {
                let latency_predicate = latency_awareness.generate_predicate();
-                Box::new(move |node: &NodeRef| {
-                    DefaultPolicy::is_alive(node) && latency_predicate(node)
-                }) as Box<dyn Fn(&NodeRef) -> bool + Send + Sync + 'static>
+                Box::new(
+                    move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| {
+                        DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node)
+                    },
+                )
+                    as Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync + 'static>
            };
 
            DefaultPolicy {
@@ -2821,7 +2859,7 @@ mod latency_awareness {
                is_token_aware: true,
                pick_predicate,
                latency_awareness: Some(latency_awareness),
-                fixed_shuffle_seed: None,
+                fixed_seed: None,
            }
        }
diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs
index d4095743c3..ad056a841f 100644
--- a/scylla/src/transport/load_balancing/mod.rs
+++ b/scylla/src/transport/load_balancing/mod.rs
@@ -3,7 +3,7 @@
 //! See [the book](https://rust-driver.docs.scylladb.com/stable/load-balancing/load-balancing.html) for more information
 
 use super::{cluster::ClusterData, NodeRef};
-use crate::routing::Token;
+use crate::routing::{Shard, Token};
 use scylla_cql::{errors::QueryError, frame::types};
 use std::time::Duration;
 
@@ -39,7 +39,7 @@ pub struct RoutingInfo<'a> {
 ///
 /// It is computed on-demand, only if querying the most preferred node fails
 /// (or when speculative execution is triggered).
-pub type FallbackPlan<'a> = Box<dyn Iterator<Item = NodeRef<'a>> + Send + Sync + 'a>;
+pub type FallbackPlan<'a> = Box<dyn Iterator<Item = (NodeRef<'a>, Shard)> + Send + Sync + 'a>;
 
 /// Policy that decides which nodes to contact for each query.
 ///
@@ -62,7 +62,11 @@ pub type FallbackPlan<'a> = Box<dyn Iterator<Item = NodeRef<'a>> + Send + Sync +
 /// This trait is used to produce an iterator of nodes to contact for a given query.
 pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug {
     /// Returns the first node to contact for a given query.
-    fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>>;
+    fn pick<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Shard)>;
 
     /// Returns all contact-appropriate nodes for a given query.
     fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData)
diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs
index e49d4cb012..83e95080f3 100644
--- a/scylla/src/transport/load_balancing/plan.rs
+++ b/scylla/src/transport/load_balancing/plan.rs
@@ -51,8 +51,8 @@ impl<'a> Iterator for Plan<'a> {
             PlanState::Created => {
                 let picked = self.policy.pick(self.routing_info, self.cluster);
                 if let Some(picked) = picked {
-                    self.state = PlanState::Picked(picked);
-                    Some(picked)
+                    self.state = PlanState::Picked(picked.0);
+                    Some(picked.0)
                 } else {
                     // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
                     // This, however, does not imply that fallback would return an empty plan, too.
@@ -64,9 +64,9 @@ impl<'a> Iterator for Plan<'a> {
                     if let Some(node) = first_fallback_node {
                         self.state = PlanState::Fallback {
                             iter,
-                            node_to_filter_out: node,
+                            node_to_filter_out: node.0,
                         };
-                        Some(node)
+                        Some(node.0)
                     } else {
                         error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info);
                         self.state = PlanState::PickedNone;
@@ -87,10 +87,10 @@ impl<'a> Iterator for Plan<'a> {
                 node_to_filter_out,
             } => {
                 for node in iter {
-                    if node == *node_to_filter_out {
+                    if node.0 == *node_to_filter_out {
                         continue;
                     } else {
-                        return Some(node);
+                        return Some(node.0);
                     }
                 }
 
@@ -105,6 +105,7 @@ impl<'a> Iterator for Plan<'a> {
 mod tests {
     use std::{net::SocketAddr, str::FromStr, sync::Arc};
 
+    use crate::routing::Shard;
     use crate::transport::{
         locator::test::{create_locator, mock_metadata_for_token_aware_tests},
         Node, NodeAddr,
@@ -112,24 +113,27 @@ mod tests {
 
     use super::*;
 
-    fn expected_nodes() -> Vec<Arc<Node>> {
-        vec![Arc::new(Node::new_for_test(
-            NodeAddr::Translatable(SocketAddr::from_str("127.0.0.1:9042").unwrap()),
-            None,
-            None,
-        ))]
+    fn expected_nodes() -> Vec<(Arc<Node>, Shard)> {
+        vec![(
+            Arc::new(Node::new_for_test(
+                NodeAddr::Translatable(SocketAddr::from_str("127.0.0.1:9042").unwrap()),
+                None,
+                None,
+            )),
+            42,
+        )]
     }
 
     #[derive(Debug)]
     struct PickingNonePolicy {
-        expected_nodes: Vec<Arc<Node>>,
+        expected_nodes: Vec<(Arc<Node>, Shard)>,
     }
 
     impl LoadBalancingPolicy for PickingNonePolicy {
         fn pick<'a>(
             &'a self,
             _query: &'a RoutingInfo,
             _cluster: &'a ClusterData,
-        ) -> Option<NodeRef<'a>> {
+        ) -> Option<(NodeRef<'a>, Shard)> {
             None
         }
 
@@ -138,7 +142,11 @@ mod tests {
             _query: &'a RoutingInfo,
             _cluster: &'a ClusterData,
         ) -> FallbackPlan<'a> {
-            Box::new(self.expected_nodes.iter())
+            Box::new(
+                self.expected_nodes
+                    .iter()
+                    .map(|(node_ref, shard)| (node_ref, *shard)),
+            )
         }
 
         fn name(&self) -> String {
@@ -159,6 +167,13 @@ mod tests {
         };
         let routing_info = RoutingInfo::default();
         let plan = Plan::new(&policy, &routing_info, &cluster_data);
-        assert_eq!(Vec::from_iter(plan.cloned()), policy.expected_nodes);
+        assert_eq!(
+            Vec::from_iter(plan.cloned()),
+            policy
+                .expected_nodes
+                .into_iter()
+                .map(|(node, _shard)| node)
+                .collect::<Vec<_>>()
+        );
     }
 }
diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs
index 54c81823e0..a85a4bb60d 100644
--- a/scylla/tests/integration/consistency.rs
+++ b/scylla/tests/integration/consistency.rs
@@ -4,7 +4,7 @@ use scylla::execution_profile::{ExecutionProfileBuilder, ExecutionProfileHandle}
 use scylla::load_balancing::{DefaultPolicy, LoadBalancingPolicy, RoutingInfo};
 use scylla::prepared_statement::PreparedStatement;
 use scylla::retry_policy::FallthroughRetryPolicy;
-use scylla::routing::Token;
+use scylla::routing::{Shard, Token};
 use scylla::test_utils::unique_keyspace_name;
 use scylla::transport::session::Session;
 use scylla::transport::NodeRef;
@@ -378,7 +378,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper {
         &'a self,
         query: &'a RoutingInfo,
         cluster: &'a scylla::transport::ClusterData,
-    ) -> Option<NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         self.routing_info_tx
             .send(OwnedRoutingInfo::from(query.clone()))
             .unwrap();
diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs
index 119487a609..6e6ceb5f00 100644
--- a/scylla/tests/integration/execution_profiles.rs
+++ b/scylla/tests/integration/execution_profiles.rs
@@ -6,6 +6,7 @@ use assert_matches::assert_matches;
 use scylla::batch::BatchStatement;
 use scylla::batch::{Batch, BatchType};
 use scylla::query::Query;
+use scylla::routing::Shard;
 use scylla::statement::SerialConsistency;
 use scylla::transport::NodeRef;
 use scylla::{
@@ -46,9 +47,13 @@ impl BoundToPredefinedNodePolicy {
 }
 
 impl LoadBalancingPolicy for BoundToPredefinedNodePolicy {
-    fn pick<'a>(&'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
+    fn pick<'a>(
+        &'a self,
+        _info: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Shard)> {
         self.report_node(Report::LoadBalancing);
-        cluster.get_nodes_info().iter().next()
+        cluster.get_nodes_info().iter().next().map(|node| (node, 0)) // FIXME: hardcoded shard 0
     }
 
     fn fallback<'a>(
diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs
index 70abdd2c9a..52270c8942 100644
--- a/scylla/tests/integration/utils.rs
+++ b/scylla/tests/integration/utils.rs
@@ -1,6 +1,7 @@
 use futures::Future;
 use itertools::Itertools;
 use scylla::load_balancing::LoadBalancingPolicy;
+use scylla::routing::Shard;
 use scylla::transport::NodeRef;
 use std::collections::HashMap;
 use std::env;
@@ -17,6 +18,14 @@ pub fn init_logger() {
         .try_init();
 }
 
+fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) {
+    let nr_shards = node
+        .sharder()
+        .map(|sharder| sharder.nr_shards.get())
+        .unwrap_or(1);
+    (node, ((nr_shards - 1) % 42) as Shard)
+}
+
 #[derive(Debug)]
 pub struct FixedOrderLoadBalancer;
 impl LoadBalancingPolicy for FixedOrderLoadBalancer {
@@ -24,12 +33,13 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer {
         &'a self,
         _info: &'a scylla::load_balancing::RoutingInfo,
         cluster: &'a scylla::transport::ClusterData,
-    ) -> Option<NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         cluster
             .get_nodes_info()
             .iter()
             .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address))
             .next()
+            .map(with_pseudorandom_shard)
     }
 
     fn fallback<'a>(
@@ -41,7 +51,8 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer {
         cluster
             .get_nodes_info()
             .iter()
-            .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)),
+            .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address))
+            .map(with_pseudorandom_shard),
        )
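
Migration note: with this patch, `pick()` returns `Option<(NodeRef<'a>, Shard)>` and a `FallbackPlan` yields `(NodeRef, Shard)` tuples, so any `LoadBalancingPolicy` implemented outside the driver must be adapted. Below is a minimal sketch of what such a migration might look like, modeled on the `with_random_shard` helper from examples/custom_load_balancing_policy.rs above; `FirstNodePolicy` is a hypothetical downstream policy used only for illustration, not a type introduced by this patch.

    use rand::{thread_rng, Rng};
    use scylla::load_balancing::{FallbackPlan, LoadBalancingPolicy, RoutingInfo};
    use scylla::routing::Shard;
    use scylla::transport::{ClusterData, NodeRef};

    // Hypothetical downstream policy: always prefers the first known node.
    #[derive(Debug)]
    struct FirstNodePolicy;

    // Same approach as the example above: pick a random shard on the node,
    // falling back to a single shard for nodes without a sharder (e.g. Cassandra).
    fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
        let nr_shards = node
            .sharder()
            .map(|sharder| sharder.nr_shards.get())
            .unwrap_or(1);
        (node, thread_rng().gen_range(0..nr_shards) as Shard)
    }

    impl LoadBalancingPolicy for FirstNodePolicy {
        // Previously returned Option<NodeRef<'a>>; now a target shard comes with it.
        fn pick<'a>(
            &'a self,
            _info: &'a RoutingInfo,
            cluster: &'a ClusterData,
        ) -> Option<(NodeRef<'a>, Shard)> {
            cluster.get_nodes_info().first().map(with_random_shard)
        }

        // Previously an iterator of NodeRef; now an iterator of (NodeRef, Shard).
        fn fallback<'a>(
            &'a self,
            _info: &'a RoutingInfo,
            cluster: &'a ClusterData,
        ) -> FallbackPlan<'a> {
            Box::new(cluster.get_nodes_info().iter().map(with_random_shard))
        }

        fn name(&self) -> String {
            "FirstNodePolicy".to_string()
        }
    }

Choosing a random shard for statements that are not token-aware mirrors what this patch's `DefaultPolicy::with_random_shard` does: it spreads load across a node's per-shard connections instead of pinning every such statement to shard 0.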