From 80d648f779faf77b97509c035dfcb5a82bf9beaf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Karol=20Bary=C5=82a?=
Date: Fri, 22 Mar 2024 11:35:14 +0100
Subject: [PATCH] LBP: Return Option<Shard> instead of Shard

This was already documented as such, but due to an oversight the code was
in disagreement with the documentation. The documented approach is better,
because the previously implemented one prevented deduplication in Plan from
working correctly.
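To illustrate the new contract, here is a minimal, hypothetical policy written
against the changed trait (the policy below is only an illustration and is not
part of the changed files). Returning `None` as the shard makes `Plan` pick a
random shard of the chosen node:

```
use scylla::load_balancing::{FallbackPlan, LoadBalancingPolicy, RoutingInfo};
use scylla::routing::Shard;
use scylla::transport::{ClusterData, NodeRef};

// Hypothetical policy: always routes to the first known node and leaves
// shard selection to `Plan` by returning `None` as the shard.
#[derive(Debug)]
struct FirstNodePolicy;

impl LoadBalancingPolicy for FirstNodePolicy {
    fn pick<'a>(
        &'a self,
        _info: &'a RoutingInfo,
        cluster: &'a ClusterData,
    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
        // `None` means: let `Plan` substitute a random shard of this node.
        cluster.get_nodes_info().iter().next().map(|node| (node, None))
    }

    fn fallback<'a>(&'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> {
        Box::new(cluster.get_nodes_info().iter().map(|node| (node, None)))
    }

    fn name(&self) -> String {
        "FirstNodePolicy".to_string()
    }
}
```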
---
 examples/custom_load_balancing_policy.rs      |   6 +-
 .../src/transport/load_balancing/default.rs   | 208 +++++++++---------
 scylla/src/transport/load_balancing/mod.rs    |   5 +-
 scylla/src/transport/load_balancing/plan.rs   |  77 ++++++-
 scylla/tests/integration/consistency.rs       |   2 +-
 .../tests/integration/execution_profiles.rs   |   8 +-
 scylla/tests/integration/utils.rs             |   6 +-
 7 files changed, 189 insertions(+), 123 deletions(-)

diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs
index fb1ae0cb7c..5c279f2331 100644
--- a/examples/custom_load_balancing_policy.rs
+++ b/examples/custom_load_balancing_policy.rs
@@ -18,12 +18,12 @@ struct CustomLoadBalancingPolicy {
     fav_datacenter_name: String,
 }
 
-fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
+fn with_random_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
     let nr_shards = node
         .sharder()
         .map(|sharder| sharder.nr_shards.get())
         .unwrap_or(1);
-    (node, thread_rng().gen_range(0..nr_shards) as Shard)
+    (node, Some(thread_rng().gen_range(0..nr_shards) as Shard))
 }
 
 impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
@@ -31,7 +31,7 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
         &'a self,
         _info: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         self.fallback(_info, cluster).next()
     }
 
diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs
index e3c1f97377..60b8cbea4f 100644
--- a/scylla/src/transport/load_balancing/default.rs
+++ b/scylla/src/transport/load_balancing/default.rs
@@ -75,7 +75,7 @@ pub struct DefaultPolicy {
     preferences: NodeLocationPreference,
     is_token_aware: bool,
     permit_dc_failover: bool,
-    pick_predicate: Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync>,
+    pick_predicate: Box<dyn Fn(NodeRef<'_>, Option<Shard>) -> bool + Send + Sync>,
     latency_awareness: Option<LatencyAwareness>,
     fixed_seed: Option<u64>,
 }
@@ -97,7 +97,7 @@ impl LoadBalancingPolicy for DefaultPolicy {
         &'a self,
         query: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         let routing_info = self.routing_info(query, cluster);
         if let Some(ref token_with_strategy) = routing_info.token_with_strategy {
             if self.preferences.datacenter().is_some()
@@ -126,13 +126,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let local_rack_picked = self.pick_replica(
                     ts,
                     NodeLocationCriteria::DatacenterAndRack(dc, rack),
-                    &self.pick_predicate,
+                    |node, shard| (self.pick_predicate)(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
 
-                if let Some(alive_local_rack_replica) = local_rack_picked {
-                    return Some(alive_local_rack_replica);
+                if let Some((alive_local_rack_replica, shard)) = local_rack_picked {
+                    return Some((alive_local_rack_replica, Some(shard)));
                 }
             }
 
@@ -143,13 +143,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let picked = self.pick_replica(
                     ts,
                     NodeLocationCriteria::Datacenter(dc),
-                    &self.pick_predicate,
+                    |node, shard| (self.pick_predicate)(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
 
-                if let Some(alive_local_replica) = picked {
-                    return Some(alive_local_replica);
+                if let Some((alive_local_replica, shard)) = picked {
+                    return Some((alive_local_replica, Some(shard)));
                 }
             }
 
@@ -161,12 +161,12 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let picked = self.pick_replica(
                     ts,
                     NodeLocationCriteria::Any,
-                    &self.pick_predicate,
+                    |node, shard| (self.pick_predicate)(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
 
-                if let Some(alive_remote_replica) = picked {
-                    return Some(alive_remote_replica);
+                if let Some((alive_remote_replica, shard)) = picked {
+                    return Some((alive_remote_replica, Some(shard)));
                 }
             }
         };
@@ -179,47 +179,47 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
         if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences {
             // Try to pick some alive local rack random node.
             let rack_predicate = Self::make_rack_predicate(
-                &self.pick_predicate,
+                |node| (self.pick_predicate)(node, None),
                 NodeLocationCriteria::DatacenterAndRack(dc, rack),
             );
-            let local_rack_picked = self.pick_node(nodes, rack_predicate);
+            let local_rack_picked = self.pick_node(nodes, |node| rack_predicate(&node));
 
             if let Some(alive_local_rack) = local_rack_picked {
-                return Some(alive_local_rack);
+                return Some((alive_local_rack, None));
             }
         }
 
         // Try to pick some alive local random node.
-        if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) {
-            return Some(alive_local);
+        if let Some(alive_local) = self.pick_node(nodes, |node| (self.pick_predicate)(node, None)) {
+            return Some((alive_local, None));
        }
 
         let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring();
         // If a datacenter failover is possible, loosen restriction about locality.
         if self.is_datacenter_failover_possible(&routing_info) {
-            let picked = self.pick_node(all_nodes, &self.pick_predicate);
+            let picked = self.pick_node(all_nodes, |node| (self.pick_predicate)(node, None));
             if let Some(alive_maybe_remote) = picked {
-                return Some(alive_maybe_remote);
+                return Some((alive_maybe_remote, None));
             }
         }
 
         // Previous checks imply that every node we could have selected is down.
         // Let's try to return a down node that wasn't disabled.
-        let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled());
+        let picked = self.pick_node(nodes, |node| node.is_enabled());
         if let Some(down_but_enabled_local_node) = picked {
-            return Some(down_but_enabled_local_node);
+            return Some((down_but_enabled_local_node, None));
         }
 
         // If a datacenter failover is possible, loosen restriction about locality.
         if self.is_datacenter_failover_possible(&routing_info) {
-            let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled());
+            let picked = self.pick_node(all_nodes, |node| node.is_enabled());
             if let Some(down_but_enabled_maybe_remote_node) = picked {
-                return Some(down_but_enabled_maybe_remote_node);
+                return Some((down_but_enabled_maybe_remote_node, None));
             }
         }
 
         // Every node is disabled. This could be due to a bad host filter - configuration error.
-        nodes.first().map(|node| self.with_random_shard(node))
+        nodes.first().map(|node| (node, None))
     }
 
     fn fallback<'a>(
@@ -241,7 +241,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let local_rack_replicas = self.fallback_replicas(
                     ts,
                     NodeLocationCriteria::DatacenterAndRack(dc, rack),
-                    Self::is_alive,
+                    |node, shard| Self::is_alive(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
@@ -257,7 +257,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let local_replicas = self.fallback_replicas(
                     ts,
                     NodeLocationCriteria::Datacenter(dc),
-                    Self::is_alive,
+                    |node, shard| Self::is_alive(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
@@ -273,7 +273,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let remote_replicas = self.fallback_replicas(
                     ts,
                     NodeLocationCriteria::Any,
-                    Self::is_alive,
+                    |node, shard| Self::is_alive(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
@@ -287,10 +287,11 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
             Either::Left(
                 maybe_local_rack_replicas
                     .chain(maybe_local_replicas)
-                    .chain(maybe_remote_replicas),
+                    .chain(maybe_remote_replicas)
+                    .map(|(node, shard)| (node, Some(shard))),
             )
         } else {
-            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Option<Shard>)>())
         };
 
         // Get a list of all local alive nodes, and apply a round robin to it
@@ -299,31 +300,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
         let maybe_local_rack_nodes = if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences {
             let rack_predicate = Self::make_rack_predicate(
-                &self.pick_predicate,
+                |node| (self.pick_predicate)(node, None),
                 NodeLocationCriteria::DatacenterAndRack(dc, rack),
             );
-            Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate))
+            Either::Left(
+                self.round_robin_nodes(local_nodes, rack_predicate)
+                    .map(|node| (node, None)),
+            )
         } else {
-            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Option<Shard>)>())
         };
-        let robinned_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive);
+        let robinned_local_nodes = self
+            .round_robin_nodes(local_nodes, |node| Self::is_alive(node, None))
+            .map(|node| (node, None));
 
         let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring();
 
         // If a datacenter failover is possible, loosen restriction about locality.
         let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) {
-            let robinned_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive);
+            let robinned_all_nodes =
+                self.round_robin_nodes(all_nodes, |node| Self::is_alive(node, None));
 
-            Either::Left(robinned_all_nodes)
+            Either::Left(robinned_all_nodes.map(|node| (node, None)))
         } else {
-            Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>())
+            Either::Right(std::iter::empty::<(NodeRef<'a>, Option<Shard>)>())
         };
 
         // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort.
         let maybe_down_local_nodes = local_nodes
             .iter()
             .filter(|node| node.is_enabled())
-            .map(|node| self.with_random_shard(node));
+            .map(|node| (node, None));
 
         // If a datacenter failover is possible, loosen restriction about locality.
         let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) {
@@ -331,7 +338,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 all_nodes
                     .iter()
                     .filter(|node| node.is_enabled())
-                    .map(|node| self.with_random_shard(node)),
+                    .map(|node| (node, None)),
             )
         } else {
             Either::Right(std::iter::empty())
         };
@@ -433,15 +440,28 @@ impl DefaultPolicy {
     /// Wraps the provided predicate, adding the requirement for rack to match.
     fn make_rack_predicate<'a>(
-        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a,
+        predicate: impl Fn(NodeRef<'a>) -> bool + 'a,
+        replica_location: NodeLocationCriteria<'a>,
+    ) -> impl Fn(&NodeRef<'a>) -> bool {
+        move |node| match replica_location {
+            NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node),
+            NodeLocationCriteria::DatacenterAndRack(_, rack) => {
+                predicate(node) && node.rack.as_deref() == Some(rack)
+            }
+        }
+    }
+
+    /// Wraps the provided predicate, adding the requirement for rack to match.
+    fn make_sharded_rack_predicate<'a>(
+        predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a,
         replica_location: NodeLocationCriteria<'a>,
-    ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool {
-        move |node_and_shard @ (node, _shard)| match replica_location {
+    ) -> impl Fn(NodeRef<'a>, Shard) -> bool {
+        move |node, shard| match replica_location {
             NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => {
-                predicate(node_and_shard)
+                predicate(node, shard)
             }
             NodeLocationCriteria::DatacenterAndRack(_, rack) => {
-                predicate(node_and_shard) && node.rack.as_deref() == Some(rack)
+                predicate(node, shard) && node.rack.as_deref() == Some(rack)
             }
         }
     }
@@ -450,11 +470,11 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a,
+        predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a,
         cluster: &'a ClusterData,
         order: ReplicaOrder,
     ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
-        let predicate = Self::make_rack_predicate(predicate, replica_location);
+        let predicate = Self::make_sharded_rack_predicate(predicate, replica_location);
 
         let replica_iter = match order {
             ReplicaOrder::Arbitrary => Either::Left(
@@ -467,14 +487,14 @@ impl DefaultPolicy {
                 .into_iter(),
             ),
         };
-        replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard))
+        replica_iter.filter(move |(node, shard): &(NodeRef<'a>, Shard)| predicate(node, *shard))
     }
 
     fn pick_replica<'a>(
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
+        predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a,
         cluster: &'a ClusterData,
         statement_type: StatementType,
     ) -> Option<(NodeRef<'a>, Shard)> {
@@ -502,7 +522,7 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
+        predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a,
         cluster: &'a ClusterData,
     ) -> Option<(NodeRef<'a>, Shard)> {
         match replica_location {
@@ -521,8 +541,8 @@ impl DefaultPolicy {
                     .into_replicas_ordered()
                     .into_iter()
                     .next()
-                    .and_then(|primary_replica| {
-                        predicate(&primary_replica).then_some(primary_replica)
+                    .and_then(|(primary_replica, shard)| {
+                        predicate(primary_replica, shard).then_some((primary_replica, shard))
                     })
             }
             NodeLocationCriteria::Datacenter(_) | NodeLocationCriteria::DatacenterAndRack(_, _) => {
@@ -534,7 +554,7 @@ impl DefaultPolicy {
                 self.replicas(
                     ts,
                     replica_location,
-                    move |node_and_shard| predicate(node_and_shard),
+                    predicate,
                     cluster,
                     ReplicaOrder::RingOrder,
                 )
@@ -547,18 +567,18 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool,
+        predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a,
         cluster: &'a ClusterData,
     ) -> Option<(NodeRef<'a>, Shard)> {
-        let predicate = Self::make_rack_predicate(predicate, replica_location);
+        let predicate = Self::make_sharded_rack_predicate(predicate, replica_location);
 
         let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster);
 
         if let Some(fixed) = self.fixed_seed {
             let mut gen = Pcg32::new(fixed, 0);
-            replica_set.choose_filtered(&mut gen, predicate)
+            replica_set.choose_filtered(&mut gen, |(node, shard)| predicate(node, *shard))
         } else {
-            replica_set.choose_filtered(&mut thread_rng(), predicate)
+            replica_set.choose_filtered(&mut thread_rng(), |(node, shard)| predicate(node, *shard))
         }
     }
 
@@ -566,7 +586,7 @@ impl DefaultPolicy {
         &'a self,
         ts: &TokenWithStrategy<'a>,
         replica_location: NodeLocationCriteria<'a>,
-        predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a,
+        predicate: impl Fn(NodeRef<'_>, Shard) -> bool + 'a,
         cluster: &'a ClusterData,
         statement_type: StatementType,
     ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
@@ -604,22 +624,21 @@ impl DefaultPolicy {
     fn pick_node<'a>(
         &'a self,
         nodes: &'a [Arc<Node>],
-        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool,
-    ) -> Option<(NodeRef<'_>, Shard)> {
+        predicate: impl Fn(NodeRef<'a>) -> bool,
+    ) -> Option<NodeRef<'a>> {
         // Select the first node that matches the predicate
-        Self::randomly_rotated_nodes(nodes)
-            .map(|node| self.with_random_shard(node))
-            .find(predicate)
+        Self::randomly_rotated_nodes(nodes).find(|&node| predicate(node))
     }
 
-    fn round_robin_nodes_with_shards<'a>(
+    fn round_robin_nodes<'a>(
         &'a self,
         nodes: &'a [Arc<Node>],
-        predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool,
-    ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
-        Self::randomly_rotated_nodes(nodes)
-            .map(|node| self.with_random_shard(node))
-            .filter(predicate)
+        // Ideally this would be
+        //   impl Fn(&NodeRef<'a>) -> bool
+        // but it is unclear how to make that work with the borrow checker.
+        predicate: impl Fn(&NodeRef<'a>) -> bool,
+    ) -> impl Iterator<Item = NodeRef<'a>> {
+        Self::randomly_rotated_nodes(nodes).filter(predicate)
     }
 
     fn shuffle<'a>(
@@ -638,23 +657,7 @@ impl DefaultPolicy {
         vec.into_iter()
     }
 
-    fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) {
-        let nr_shards = node
-            .sharder()
-            .map(|sharder| sharder.nr_shards.get())
-            .unwrap_or(1);
-        (
-            node,
-            (if let Some(fixed) = self.fixed_seed {
-                let mut gen = Pcg32::new(fixed, 0);
-                gen.gen_range(0..nr_shards)
-            } else {
-                thread_rng().gen_range(0..nr_shards)
-            }) as Shard,
-        )
-    }
-
-    fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool {
+    fn is_alive(node: NodeRef, _shard: Option<Shard>) -> bool {
         // For now, we leave this as stub, until we have time to improve node events.
         // node.is_enabled() && !node.is_down()
         node.is_enabled()
@@ -720,11 +723,10 @@ impl DefaultPolicyBuilder {
         let latency_awareness = self.latency_awareness.map(|builder| builder.build());
         let pick_predicate = if let Some(ref latency_awareness) = latency_awareness {
             let latency_predicate = latency_awareness.generate_predicate();
-            Box::new(
-                move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| {
-                    DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node)
-                },
-            ) as Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync + 'static>
+            Box::new(move |node: NodeRef<'_>, shard| {
+                DefaultPolicy::is_alive(node, shard) && latency_predicate(node)
+            })
+                as Box<dyn Fn(NodeRef<'_>, Option<Shard>) -> bool + Send + Sync + 'static>
         } else {
             Box::new(DefaultPolicy::is_alive)
         };
@@ -2401,8 +2403,8 @@ mod latency_awareness {
         pub(super) fn wrap<'a>(
             &self,
-            fallback: impl Iterator<Item = (NodeRef<'a>, Shard)>,
-        ) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
+            fallback: impl Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
+        ) -> impl Iterator<Item = (NodeRef<'a>, Option<Shard>)> {
             let min_avg_latency = match self.last_min_latency.load() {
                 Some(min_avg) => min_avg,
                 None => return Either::Left(fallback), // noop, as no latency data has been collected yet
@@ -2723,8 +2725,8 @@ mod latency_awareness {
     struct IteratorWithSkippedNodes<'a, Fast, Penalised>
     where
-        Fast: Iterator<Item = (NodeRef<'a>, Shard)>,
-        Penalised: Iterator<Item = (NodeRef<'a>, Shard)>,
+        Fast: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
+        Penalised: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
     {
         fast_nodes: Fast,
         penalised_nodes: Penalised,
@@ -2733,13 +2735,13 @@ mod latency_awareness {
     impl<'a>
         IteratorWithSkippedNodes<
             'a,
-            std::vec::IntoIter<(NodeRef<'a>, Shard)>,
-            std::vec::IntoIter<(NodeRef<'a>, Shard)>,
+            std::vec::IntoIter<(NodeRef<'a>, Option<Shard>)>,
+            std::vec::IntoIter<(NodeRef<'a>, Option<Shard>)>,
         >
     {
         fn new(
            average_latencies: &HashMap<Uuid, RwLock<Option<TimestampedAverage>>>,
-            nodes: impl Iterator<Item = (NodeRef<'a>, Shard)>,
+            nodes: impl Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
             exclusion_threshold: f64,
             retry_period: Duration,
             minimum_measurements: usize,
@@ -2775,10 +2777,10 @@ mod latency_awareness {
     impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised>
     where
-        Fast: Iterator<Item = (NodeRef<'a>, Shard)>,
-        Penalised: Iterator<Item = (NodeRef<'a>, Shard)>,
+        Fast: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
+        Penalised: Iterator<Item = (NodeRef<'a>, Option<Shard>)>,
     {
-        type Item = (NodeRef<'a>, Shard);
+        type Item = (NodeRef<'a>, Option<Shard>);
 
         fn next(&mut self) -> Option<Self::Item> {
             self.fast_nodes
@@ -2860,12 +2862,10 @@ mod latency_awareness {
     ) -> DefaultPolicy {
         let pick_predicate = {
             let latency_predicate = latency_awareness.generate_predicate();
-            Box::new(
-                move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| {
-                    DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node)
-                },
-            )
-                as Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync + 'static>
+            Box::new(move |node: NodeRef<'_>, shard| {
+                DefaultPolicy::is_alive(node, shard) && latency_predicate(node)
+            })
+                as Box<dyn Fn(NodeRef<'_>, Option<Shard>) -> bool + Send + Sync + 'static>
         };
 
         DefaultPolicy {
diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs
index 977e3d508f..f1cd5bdf27 100644
--- a/scylla/src/transport/load_balancing/mod.rs
+++ b/scylla/src/transport/load_balancing/mod.rs
@@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> {
 ///
 /// It is computed on-demand, only if querying the most preferred node fails
 /// (or when speculative execution is triggered).
-pub type FallbackPlan<'a> = Box<dyn Iterator<Item = (NodeRef<'a>, Shard)> + Send + Sync + 'a>;
+pub type FallbackPlan<'a> =
+    Box<dyn Iterator<Item = (NodeRef<'a>, Option<Shard>)> + Send + Sync + 'a>;
 
 /// Policy that decides which nodes and shards to contact for each query.
 ///
@@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug {
         &'a self,
         query: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)>;
+    ) -> Option<(NodeRef<'a>, Option<Shard>)>;
 
     /// Returns all contact-appropriate nodes for a given query.
     fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData)
diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs
index 5fc6294467..c88748642f 100644
--- a/scylla/src/transport/load_balancing/plan.rs
+++ b/scylla/src/transport/load_balancing/plan.rs
@@ -1,3 +1,4 @@
+use rand::{thread_rng, Rng};
 use tracing::error;
 
 use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
@@ -6,10 +7,10 @@
 use crate::{routing::Shard, transport::ClusterData};
 
 enum PlanState<'a> {
     Created,
     PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements.
-    Picked((NodeRef<'a>, Shard)),
+    Picked((NodeRef<'a>, Option<Shard>)),
     Fallback {
         iter: FallbackPlan<'a>,
-        node_to_filter_out: (NodeRef<'a>, Shard),
+        node_to_filter_out: (NodeRef<'a>, Option<Shard>),
     },
 }
 
@@ -19,7 +20,52 @@ enum PlanState<'a> {
 /// eagerly in the first place and the remaining nodes computed on-demand
 /// (all at once).
 /// This significantly reduces the allocation overhead on "the happy path"
-/// (when the first node successfully handles the request),
+/// (when the first node successfully handles the request).
+///
+/// `Plan` implements `Iterator<Item = (NodeRef<'a>, Shard)>`, but `LoadBalancingPolicy`
+/// returns `Option<Shard>` instead of `Shard`, both in `pick` and in `fallback`.
+/// `Plan` handles the `None` case by using a random shard for the given node.
+/// There is currently no way to configure the RNG used by `Plan`.
+/// If you don't want `Plan` to randomize shards, or you want to control the RNG,
+/// use a custom LBP that always returns non-`None` shards.
+/// Example of an LBP that always uses shard 0, preventing `Plan` from using random numbers:
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use scylla::load_balancing::LoadBalancingPolicy;
+/// # use scylla::load_balancing::RoutingInfo;
+/// # use scylla::transport::ClusterData;
+/// # use scylla::transport::NodeRef;
+/// # use scylla::routing::Shard;
+/// # use scylla::load_balancing::FallbackPlan;
+///
+/// #[derive(Debug)]
+/// struct NonRandomLBP {
+///     inner: Arc<dyn LoadBalancingPolicy>,
+/// }
+/// impl LoadBalancingPolicy for NonRandomLBP {
+///     fn pick<'a>(
+///         &'a self,
+///         info: &'a RoutingInfo,
+///         cluster: &'a ClusterData,
+///     ) -> Option<(NodeRef<'a>, Option<Shard>)> {
+///         self.inner
+///             .pick(info, cluster)
+///             .map(|(node, shard)| (node, shard.or(Some(0))))
+///     }
+///
+///     fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> {
+///         Box::new(self.inner
+///             .fallback(info, cluster)
+///             .map(|(node, shard)| (node, shard.or(Some(0)))))
+///     }
+///
+///     fn name(&self) -> String {
+///         "NonRandomLBP".to_string()
+///     }
+/// }
+/// ```
+
 pub struct Plan<'a> {
     policy: &'a dyn LoadBalancingPolicy,
     routing_info: &'a RoutingInfo<'a>,
@@ -41,6 +87,21 @@ impl<'a> Plan<'a> {
             state: PlanState::Created,
         }
     }
+
+    fn with_random_shard_if_unknown(
+        (node, shard): (NodeRef<'_>, Option<Shard>),
+    ) -> (NodeRef<'_>, Shard) {
+        (
+            node,
+            shard.unwrap_or_else(|| {
+                let nr_shards = node
+                    .sharder()
+                    .map(|sharder| sharder.nr_shards.get())
+                    .unwrap_or(1);
+                thread_rng().gen_range(0..nr_shards).into()
+            }),
+        )
+    }
 }
 
 impl<'a> Iterator for Plan<'a> {
@@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> {
                 let picked = self.policy.pick(self.routing_info, self.cluster);
                 if let Some(picked) = picked {
                     self.state = PlanState::Picked(picked);
-                    Some(picked)
+                    Some(Self::with_random_shard_if_unknown(picked))
                 } else {
                     // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
                     // This, however, does not imply that fallback would return an empty plan, too.
@@ -66,7 +127,7 @@ impl<'a> Iterator for Plan<'a> {
                             iter,
                             node_to_filter_out: node,
                         };
-                        Some(node)
+                        Some(Self::with_random_shard_if_unknown(node))
                     } else {
                         error!("Load balancing policy returned an empty plan! The query cannot be executed.
Routing info: {:?}", self.routing_info);
                         self.state = PlanState::PickedNone;
@@ -90,7 +151,7 @@ impl<'a> Iterator for Plan<'a> {
                     if node == *node_to_filter_out {
                         continue;
                     } else {
-                        return Some(node);
+                        return Some(Self::with_random_shard_if_unknown(node));
                     }
                 }
 
@@ -135,7 +196,7 @@ mod tests {
             &'a self,
             _query: &'a RoutingInfo,
             _cluster: &'a ClusterData,
-        ) -> Option<(NodeRef<'a>, Shard)> {
+        ) -> Option<(NodeRef<'a>, Option<Shard>)> {
             None
         }
 
@@ -147,7 +208,7 @@ mod tests {
             Box::new(
                 self.expected_nodes
                     .iter()
-                    .map(|(node_ref, shard)| (node_ref, *shard)),
+                    .map(|(node_ref, shard)| (node_ref, Some(*shard))),
             )
         }
 
diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs
index 5f178a3bea..a96e4450bb 100644
--- a/scylla/tests/integration/consistency.rs
+++ b/scylla/tests/integration/consistency.rs
@@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper {
         &'a self,
         query: &'a RoutingInfo,
         cluster: &'a scylla::transport::ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         self.routing_info_tx
             .send(OwnedRoutingInfo::from(query.clone()))
             .unwrap();
diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs
index c0d1964f0a..59f95dfa88 100644
--- a/scylla/tests/integration/execution_profiles.rs
+++ b/scylla/tests/integration/execution_profiles.rs
@@ -51,9 +51,13 @@ impl LoadBalancingPolicy for BoundToPredefinedNodePolicy {
         &'a self,
         _info: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         self.report_node(Report::LoadBalancing);
-        cluster.get_nodes_info().iter().next().map(|node| (node, 0))
+        cluster
+            .get_nodes_info()
+            .iter()
+            .next()
+            .map(|node| (node, None))
     }
 
     fn fallback<'a>(
diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs
index b32be090af..f30021db33 100644
--- a/scylla/tests/integration/utils.rs
+++ b/scylla/tests/integration/utils.rs
@@ -19,12 +19,12 @@ pub(crate) fn setup_tracing() {
         .try_init();
 }
 
-fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) {
+fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
     let nr_shards = node
         .sharder()
        .map(|sharder| sharder.nr_shards.get())
         .unwrap_or(1);
-    (node, ((nr_shards - 1) % 42) as Shard)
+    (node, Some(((nr_shards - 1) % 42) as Shard))
 }
 
 #[derive(Debug)]
@@ -34,7 +34,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer {
         &'a self,
         _info: &'a scylla::load_balancing::RoutingInfo,
         cluster: &'a scylla::transport::ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         cluster
             .get_nodes_info()
             .iter()