diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs index fb1ae0cb7c..5c279f2331 100644 --- a/examples/custom_load_balancing_policy.rs +++ b/examples/custom_load_balancing_policy.rs @@ -18,12 +18,12 @@ struct CustomLoadBalancingPolicy { fav_datacenter_name: String, } -fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) { +fn with_random_shard(node: NodeRef) -> (NodeRef, Option) { let nr_shards = node .sharder() .map(|sharder| sharder.nr_shards.get()) .unwrap_or(1); - (node, thread_rng().gen_range(0..nr_shards) as Shard) + (node, Some(thread_rng().gen_range(0..nr_shards) as Shard)) } impl LoadBalancingPolicy for CustomLoadBalancingPolicy { @@ -31,7 +31,7 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy { &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.fallback(_info, cluster).next() } diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index e3c1f97377..60b8cbea4f 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -75,7 +75,7 @@ pub struct DefaultPolicy { preferences: NodeLocationPreference, is_token_aware: bool, permit_dc_failover: bool, - pick_predicate: Box, Shard)) -> bool + Send + Sync>, + pick_predicate: Box, Option) -> bool + Send + Sync>, latency_awareness: Option, fixed_seed: Option, } @@ -97,7 +97,7 @@ impl LoadBalancingPolicy for DefaultPolicy { &'a self, query: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { let routing_info = self.routing_info(query, cluster); if let Some(ref token_with_strategy) = routing_info.token_with_strategy { if self.preferences.datacenter().is_some() @@ -126,13 +126,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_rack_picked = self.pick_replica( ts, NodeLocationCriteria::DatacenterAndRack(dc, rack), - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_local_rack_replica) = local_rack_picked { - return Some(alive_local_rack_replica); + if let Some((alive_local_rack_replica, shard)) = local_rack_picked { + return Some((alive_local_rack_replica, Some(shard))); } } @@ -143,13 +143,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Datacenter(dc), - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_local_replica) = picked { - return Some(alive_local_replica); + if let Some((alive_local_replica, shard)) = picked { + return Some((alive_local_replica, Some(shard))); } } @@ -161,12 +161,12 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Any, - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_remote_replica) = picked { - return Some(alive_remote_replica); + if let Some((alive_remote_replica, shard)) = picked { + return Some((alive_remote_replica, Some(shard))); } } }; @@ -179,47 +179,47 @@ or refrain from preferring datacenters (which may ban all other datacenters, if if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { // Try to pick some alive local rack random node. let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - let local_rack_picked = self.pick_node(nodes, rack_predicate); + let local_rack_picked = self.pick_node(nodes, |node| rack_predicate(&node)); if let Some(alive_local_rack) = local_rack_picked { - return Some(alive_local_rack); + return Some((alive_local_rack, None)); } } // Try to pick some alive local random node. - if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) { - return Some(alive_local); + if let Some(alive_local) = self.pick_node(nodes, |node| (self.pick_predicate)(node, None)) { + return Some((alive_local, None)); } let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, &self.pick_predicate); + let picked = self.pick_node(all_nodes, |node| (self.pick_predicate)(node, None)); if let Some(alive_maybe_remote) = picked { - return Some(alive_maybe_remote); + return Some((alive_maybe_remote, None)); } } // Previous checks imply that every node we could have selected is down. // Let's try to return a down node that wasn't disabled. - let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(nodes, |node| node.is_enabled()); if let Some(down_but_enabled_local_node) = picked { - return Some(down_but_enabled_local_node); + return Some((down_but_enabled_local_node, None)); } // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(all_nodes, |node| node.is_enabled()); if let Some(down_but_enabled_maybe_remote_node) = picked { - return Some(down_but_enabled_maybe_remote_node); + return Some((down_but_enabled_maybe_remote_node, None)); } } // Every node is disabled. This could be due to a bad host filter - configuration error. - nodes.first().map(|node| self.with_random_shard(node)) + nodes.first().map(|node| (node, None)) } fn fallback<'a>( @@ -241,7 +241,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_rack_replicas = self.fallback_replicas( ts, NodeLocationCriteria::DatacenterAndRack(dc, rack), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -257,7 +257,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Datacenter(dc), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -273,7 +273,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let remote_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Any, - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -287,10 +287,11 @@ or refrain from preferring datacenters (which may ban all other datacenters, if Either::Left( maybe_local_rack_replicas .chain(maybe_local_replicas) - .chain(maybe_remote_replicas), + .chain(maybe_remote_replicas) + .map(|(node, shard)| (node, Some(shard))), ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Get a list of all local alive nodes, and apply a round robin to it @@ -299,31 +300,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let maybe_local_rack_nodes = if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate)) + Either::Left( + self.round_robin_nodes(local_nodes, rack_predicate) + .map(|node| (node, None)), + ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; - let robinned_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive); + let robinned_local_nodes = self + .round_robin_nodes(local_nodes, |node| Self::is_alive(node, None)) + .map(|node| (node, None)); let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) { - let robinned_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive); + let robinned_all_nodes = + self.round_robin_nodes(all_nodes, |node| Self::is_alive(node, None)); - Either::Left(robinned_all_nodes) + Either::Left(robinned_all_nodes.map(|node| (node, None))) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort. let maybe_down_local_nodes = local_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)); + .map(|node| (node, None)); // If a datacenter failover is possible, loosen restriction about locality. let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) { @@ -331,7 +338,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if all_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)), + .map(|node| (node, None)), ) } else { Either::Right(std::iter::empty()) @@ -433,15 +440,28 @@ impl DefaultPolicy { /// Wraps the provided predicate, adding the requirement for rack to match. fn make_rack_predicate<'a>( - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>) -> bool + 'a, + replica_location: NodeLocationCriteria<'a>, + ) -> impl Fn(&NodeRef<'a>) -> bool { + move |node| match replica_location { + NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node), + NodeLocationCriteria::DatacenterAndRack(_, rack) => { + predicate(node) && node.rack.as_deref() == Some(rack) + } + } + } + + /// Wraps the provided predicate, adding the requirement for rack to match. + fn make_sharded_rack_predicate<'a>( + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, replica_location: NodeLocationCriteria<'a>, - ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool { - move |node_and_shard @ (node, _shard)| match replica_location { + ) -> impl Fn(NodeRef<'a>, Shard) -> bool { + move |node, shard| match replica_location { NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => { - predicate(node_and_shard) + predicate(node, shard) } NodeLocationCriteria::DatacenterAndRack(_, rack) => { - predicate(node_and_shard) && node.rack.as_deref() == Some(rack) + predicate(node, shard) && node.rack.as_deref() == Some(rack) } } } @@ -450,11 +470,11 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, order: ReplicaOrder, ) -> impl Iterator, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_iter = match order { ReplicaOrder::Arbitrary => Either::Left( @@ -467,14 +487,14 @@ impl DefaultPolicy { .into_iter(), ), }; - replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard)) + replica_iter.filter(move |(node, shard): &(NodeRef<'a>, Shard)| predicate(node, *shard)) } fn pick_replica<'a>( &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> Option<(NodeRef<'a>, Shard)> { @@ -502,7 +522,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { match replica_location { @@ -521,8 +541,8 @@ impl DefaultPolicy { .into_replicas_ordered() .into_iter() .next() - .and_then(|primary_replica| { - predicate(&primary_replica).then_some(primary_replica) + .and_then(|(primary_replica, shard)| { + predicate(primary_replica, shard).then_some((primary_replica, shard)) }) } NodeLocationCriteria::Datacenter(_) | NodeLocationCriteria::DatacenterAndRack(_, _) => { @@ -534,7 +554,7 @@ impl DefaultPolicy { self.replicas( ts, replica_location, - move |node_and_shard| predicate(node_and_shard), + predicate, cluster, ReplicaOrder::RingOrder, ) @@ -547,18 +567,18 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster); if let Some(fixed) = self.fixed_seed { let mut gen = Pcg32::new(fixed, 0); - replica_set.choose_filtered(&mut gen, predicate) + replica_set.choose_filtered(&mut gen, |(node, shard)| predicate(node, *shard)) } else { - replica_set.choose_filtered(&mut thread_rng(), predicate) + replica_set.choose_filtered(&mut thread_rng(), |(node, shard)| predicate(node, *shard)) } } @@ -566,7 +586,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'_>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> impl Iterator, Shard)> { @@ -604,22 +624,21 @@ impl DefaultPolicy { fn pick_node<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> Option<(NodeRef<'_>, Shard)> { + predicate: impl Fn(NodeRef<'a>) -> bool, + ) -> Option> { // Select the first node that matches the predicate - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .find(predicate) + Self::randomly_rotated_nodes(nodes).find(|&node| predicate(node)) } - fn round_robin_nodes_with_shards<'a>( + fn round_robin_nodes<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> impl Iterator, Shard)> { - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .filter(predicate) + // I wanted this to be + // impl Fn(&NodeRef<'a>) -> bool + // but I have no idea how to make this work with borrow checker + predicate: impl Fn(&NodeRef<'a>) -> bool, + ) -> impl Iterator> { + Self::randomly_rotated_nodes(nodes).filter(predicate) } fn shuffle<'a>( @@ -638,23 +657,7 @@ impl DefaultPolicy { vec.into_iter() } - fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) { - let nr_shards = node - .sharder() - .map(|sharder| sharder.nr_shards.get()) - .unwrap_or(1); - ( - node, - (if let Some(fixed) = self.fixed_seed { - let mut gen = Pcg32::new(fixed, 0); - gen.gen_range(0..nr_shards) - } else { - thread_rng().gen_range(0..nr_shards) - }) as Shard, - ) - } - - fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool { + fn is_alive(node: NodeRef, _shard: Option) -> bool { // For now, we leave this as stub, until we have time to improve node events. // node.is_enabled() && !node.is_down() node.is_enabled() @@ -720,11 +723,10 @@ impl DefaultPolicyBuilder { let latency_awareness = self.latency_awareness.map(|builder| builder.build()); let pick_predicate = if let Some(ref latency_awareness) = latency_awareness { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> } else { Box::new(DefaultPolicy::is_alive) }; @@ -2401,8 +2403,8 @@ mod latency_awareness { pub(super) fn wrap<'a>( &self, - fallback: impl Iterator, Shard)>, - ) -> impl Iterator, Shard)> { + fallback: impl Iterator, Option)>, + ) -> impl Iterator, Option)> { let min_avg_latency = match self.last_min_latency.load() { Some(min_avg) => min_avg, None => return Either::Left(fallback), // noop, as no latency data has been collected yet @@ -2723,8 +2725,8 @@ mod latency_awareness { struct IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, + Fast: Iterator, Option)>, + Penalised: Iterator, Option)>, { fast_nodes: Fast, penalised_nodes: Penalised, @@ -2733,13 +2735,13 @@ mod latency_awareness { impl<'a> IteratorWithSkippedNodes< 'a, - std::vec::IntoIter<(NodeRef<'a>, Shard)>, - std::vec::IntoIter<(NodeRef<'a>, Shard)>, + std::vec::IntoIter<(NodeRef<'a>, Option)>, + std::vec::IntoIter<(NodeRef<'a>, Option)>, > { fn new( average_latencies: &HashMap>>, - nodes: impl Iterator, Shard)>, + nodes: impl Iterator, Option)>, exclusion_threshold: f64, retry_period: Duration, minimum_measurements: usize, @@ -2775,10 +2777,10 @@ mod latency_awareness { impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, + Fast: Iterator, Option)>, + Penalised: Iterator, Option)>, { - type Item = (NodeRef<'a>, Shard); + type Item = (NodeRef<'a>, Option); fn next(&mut self) -> Option { self.fast_nodes @@ -2860,12 +2862,10 @@ mod latency_awareness { ) -> DefaultPolicy { let pick_predicate = { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) - as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> }; DefaultPolicy { diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index 977e3d508f..f1cd5bdf27 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> { /// /// It is computed on-demand, only if querying the most preferred node fails /// (or when speculative execution is triggered). -pub type FallbackPlan<'a> = Box, Shard)> + Send + Sync + 'a>; +pub type FallbackPlan<'a> = + Box, Option)> + Send + Sync + 'a>; /// Policy that decides which nodes and shards to contact for each query. /// @@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug { &'a self, query: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)>; + ) -> Option<(NodeRef<'a>, Option)>; /// Returns all contact-appropriate nodes for a given query. fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs index 5fc6294467..c88748642f 100644 --- a/scylla/src/transport/load_balancing/plan.rs +++ b/scylla/src/transport/load_balancing/plan.rs @@ -1,3 +1,4 @@ +use rand::{thread_rng, Rng}; use tracing::error; use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; @@ -6,10 +7,10 @@ use crate::{routing::Shard, transport::ClusterData}; enum PlanState<'a> { Created, PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements. - Picked((NodeRef<'a>, Shard)), + Picked((NodeRef<'a>, Option)), Fallback { iter: FallbackPlan<'a>, - node_to_filter_out: (NodeRef<'a>, Shard), + node_to_filter_out: (NodeRef<'a>, Option), }, } @@ -19,7 +20,52 @@ enum PlanState<'a> { /// eagerly in the first place and the remaining nodes computed on-demand /// (all at once). /// This significantly reduces the allocation overhead on "the happy path" -/// (when the first node successfully handles the request), +/// (when the first node successfully handles the request). +/// +/// `Plan` implements `Iterator, Shard)>` but LoadBalancingPolicy +/// returns `Option` instead of `Shard` both in `pick` and in `fallback`. +/// `Plan` handles the `None` case by using random shard for a given node. +/// There is currently no way to configure RNG used by `Plan`. +/// If you don't want `Plan` to do randomize shards or you want to control the RNG, +/// use custom LBP that will always return non-`None` shards. +/// Example of LBP that always uses shard 0, preventing `Plan` from using random numbers: +/// +/// ``` +/// # use std::sync::Arc; +/// # use scylla::load_balancing::LoadBalancingPolicy; +/// # use scylla::load_balancing::RoutingInfo; +/// # use scylla::transport::ClusterData; +/// # use scylla::transport::NodeRef; +/// # use scylla::routing::Shard; +/// # use scylla::load_balancing::FallbackPlan; +/// +/// #[derive(Debug)] +/// struct NonRandomLBP { +/// inner: Arc, +/// } +/// impl LoadBalancingPolicy for NonRandomLBP { +/// fn pick<'a>( +/// &'a self, +/// info: &'a RoutingInfo, +/// cluster: &'a ClusterData, +/// ) -> Option<(NodeRef<'a>, Option)> { +/// self.inner +/// .pick(info, cluster) +/// .map(|(node, shard)| (node, shard.or(Some(0)))) +/// } +/// +/// fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> { +/// Box::new(self.inner +/// .fallback(info, cluster) +/// .map(|(node, shard)| (node, shard.or(Some(0))))) +/// } +/// +/// fn name(&self) -> String { +/// "NonRandomLBP".to_string() +/// } +/// } +/// ``` + pub struct Plan<'a> { policy: &'a dyn LoadBalancingPolicy, routing_info: &'a RoutingInfo<'a>, @@ -41,6 +87,21 @@ impl<'a> Plan<'a> { state: PlanState::Created, } } + + fn with_random_shard_if_unknown( + (node, shard): (NodeRef<'_>, Option), + ) -> (NodeRef<'_>, Shard) { + ( + node, + shard.unwrap_or_else(|| { + let nr_shards = node + .sharder() + .map(|sharder| sharder.nr_shards.get()) + .unwrap_or(1); + thread_rng().gen_range(0..nr_shards).into() + }), + ) + } } impl<'a> Iterator for Plan<'a> { @@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> { let picked = self.policy.pick(self.routing_info, self.cluster); if let Some(picked) = picked { self.state = PlanState::Picked(picked); - Some(picked) + Some(Self::with_random_shard_if_unknown(picked)) } else { // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_. // This, however, does not imply that fallback would return an empty plan, too. @@ -66,7 +127,7 @@ impl<'a> Iterator for Plan<'a> { iter, node_to_filter_out: node, }; - Some(node) + Some(Self::with_random_shard_if_unknown(node)) } else { error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info); self.state = PlanState::PickedNone; @@ -90,7 +151,7 @@ impl<'a> Iterator for Plan<'a> { if node == *node_to_filter_out { continue; } else { - return Some(node); + return Some(Self::with_random_shard_if_unknown(node)); } } @@ -135,7 +196,7 @@ mod tests { &'a self, _query: &'a RoutingInfo, _cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { None } @@ -147,7 +208,7 @@ mod tests { Box::new( self.expected_nodes .iter() - .map(|(node_ref, shard)| (node_ref, *shard)), + .map(|(node_ref, shard)| (node_ref, Some(*shard))), ) } diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs index 5f178a3bea..a96e4450bb 100644 --- a/scylla/tests/integration/consistency.rs +++ b/scylla/tests/integration/consistency.rs @@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper { &'a self, query: &'a RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.routing_info_tx .send(OwnedRoutingInfo::from(query.clone())) .unwrap(); diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs index c0d1964f0a..59f95dfa88 100644 --- a/scylla/tests/integration/execution_profiles.rs +++ b/scylla/tests/integration/execution_profiles.rs @@ -51,9 +51,13 @@ impl LoadBalancingPolicy for BoundToPredefinedNodePolicy { &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.report_node(Report::LoadBalancing); - cluster.get_nodes_info().iter().next().map(|node| (node, 0)) + cluster + .get_nodes_info() + .iter() + .next() + .map(|node| (node, None)) } fn fallback<'a>( diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs index b32be090af..f30021db33 100644 --- a/scylla/tests/integration/utils.rs +++ b/scylla/tests/integration/utils.rs @@ -19,12 +19,12 @@ pub(crate) fn setup_tracing() { .try_init(); } -fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) { +fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Option) { let nr_shards = node .sharder() .map(|sharder| sharder.nr_shards.get()) .unwrap_or(1); - (node, ((nr_shards - 1) % 42) as Shard) + (node, Some(((nr_shards - 1) % 42) as Shard)) } #[derive(Debug)] @@ -34,7 +34,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { &'a self, _info: &'a scylla::load_balancing::RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { cluster .get_nodes_info() .iter()