diff --git a/scylla/src/routing.rs b/scylla/src/routing.rs index 581c83a080..d1e8273efc 100644 --- a/scylla/src/routing.rs +++ b/scylla/src/routing.rs @@ -3,7 +3,7 @@ use rand::Rng; use std::collections::HashMap; use std::convert::TryFrom; use std::net::SocketAddr; -use std::num::Wrapping; +use std::num::{NonZeroU16, Wrapping}; use thiserror::Error; #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] @@ -19,11 +19,18 @@ pub struct Token { } pub type Shard = u32; +pub type ShardCount = NonZeroU16; #[derive(PartialEq, Eq, Clone, Debug)] pub struct ShardInfo { pub shard: u16, - pub nr_shards: u16, + pub nr_shards: ShardCount, + pub msb_ignore: u8, +} + +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct Sharder { + pub nr_shards: ShardCount, pub msb_ignore: u8, } @@ -41,8 +48,7 @@ pub fn murmur3_token(pk: Bytes) -> Token { } impl ShardInfo { - pub fn new(shard: u16, nr_shards: u16, msb_ignore: u8) -> Self { - assert!(nr_shards > 0); + pub fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self { ShardInfo { shard, nr_shards, @@ -50,24 +56,38 @@ impl ShardInfo { } } + pub fn get_sharder(&self) -> Sharder { + Sharder::new(self.nr_shards, self.msb_ignore) + } +} + +impl Sharder { + pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self { + Sharder { + nr_shards, + msb_ignore, + } + } + pub fn shard_of(&self, token: Token) -> Shard { let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63); biased_token <<= self.msb_ignore; - (((biased_token as u128) * (self.nr_shards as u128)) >> 64) as Shard + (((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard } /// If we connect to Scylla using Scylla's shard aware port, then Scylla assigns a shard to the /// connection based on the source port. This calculates the assigned shard. pub fn shard_of_source_port(&self, source_port: u16) -> Shard { - (source_port % self.nr_shards) as Shard + (source_port % self.nr_shards.get()) as Shard } /// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`. pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 { - assert!(shard < self.nr_shards as u32); - rand::thread_rng().gen_range((49152 + self.nr_shards - 1)..(65535 - self.nr_shards + 1)) - / self.nr_shards - * self.nr_shards + assert!(shard < self.nr_shards.get() as u32); + rand::thread_rng() + .gen_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1)) + / self.nr_shards.get() + * self.nr_shards.get() + shard as u16 } @@ -75,25 +95,22 @@ impl ShardInfo { /// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around. 
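+    /// For example (illustrative), with `nr_shards = 4` and `shard = 2` the candidate ports are
+    /// 49154, 49158, ..., 65534, i.e. every port in the ephemeral range congruent to 2 modulo 4.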
/// Stops once all possibile ports have been returned pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator { - assert!(shard < self.nr_shards as u32); + assert!(shard < self.nr_shards.get() as u32); // Randomly choose a port to start at let starting_port = self.draw_source_port_for_shard(shard); // Choose smallest available port number to begin at after wrapping // apply the formula from draw_source_port_for_shard for lowest possible gen_range result - let first_valid_port = - (49152 + self.nr_shards - 1) / self.nr_shards * self.nr_shards + shard as u16; + let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get() + * self.nr_shards.get() + + shard as u16; - let before_wrap = (starting_port..=65535).step_by(self.nr_shards.into()); - let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.into()); + let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into()); + let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into()); before_wrap.chain(after_wrap) } - - pub fn get_nr_shards(&self) -> u16 { - self.nr_shards - } } #[derive(Error, Debug)] @@ -102,6 +119,8 @@ pub enum ShardingError { MissingShardInfoParameter, #[error("ShardInfo parameters missing after unwraping")] MissingUnwrapedShardInfoParameter, + #[error("ShardInfo contains an invalid number of shards (0)")] + ZeroShards, #[error("ParseIntError encountered while getting ShardInfo")] ParseIntError(#[from] std::num::ParseIntError), } @@ -123,6 +142,7 @@ impl<'a> TryFrom<&'a HashMap>> for ShardInfo { } let shard = shard_entry.unwrap().first().unwrap().parse::()?; let nr_shards = nr_shards_entry.unwrap().first().unwrap().parse::()?; + let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?; let msb_ignore = msb_ignore_entry.unwrap().first().unwrap().parse::()?; Ok(ShardInfo::new(shard, nr_shards, msb_ignore)) } @@ -223,22 +243,22 @@ fn fmix(mut k: Wrapping) -> Wrapping { #[cfg(test)] mod tests { - use super::ShardInfo; use super::Token; + use super::{ShardCount, Sharder}; use std::collections::HashSet; #[test] fn test_shard_of() { /* Test values taken from the gocql driver. 
*/ - let shard_info = ShardInfo::new(0, 4, 12); + let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12); assert_eq!( - shard_info.shard_of(Token { + sharder.shard_of(Token { value: -9219783007514621794 }), 3 ); assert_eq!( - shard_info.shard_of(Token { + sharder.shard_of(Token { value: 9222582454147032830 }), 3 @@ -251,7 +271,7 @@ mod tests { let max_port_num = 65535; let min_port_num = (49152 + nr_shards - 1) / nr_shards * nr_shards; - let shard_info = ShardInfo::new(0, nr_shards, 12); + let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12); // Test for each shard for shard in 0..nr_shards { @@ -265,7 +285,7 @@ mod tests { let possible_ports_number: usize = ((max_port_num - lowest_port) / nr_shards + 1).into(); - let port_iter = shard_info.iter_source_ports_for_shard(shard.into()); + let port_iter = sharder.iter_source_ports_for_shard(shard.into()); let mut returned_ports: HashSet = HashSet::new(); for port in port_iter { diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 726033fc07..985e59cff7 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -1,9 +1,10 @@ use crate::frame::response::event::{Event, StatusChangeEvent}; /// Cluster manages up to date information and connections to database nodes use crate::routing::Token; -use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName}; +use crate::transport::connection::{Connection, VerifiedKeyspaceName}; +use crate::transport::connection_pool::PoolConfig; use crate::transport::errors::QueryError; -use crate::transport::node::{Node, NodeConnections}; +use crate::transport::node::Node; use crate::transport::topology::{Keyspace, TopologyInfo, TopologyReader}; use arc_swap::ArcSwap; @@ -50,7 +51,7 @@ struct ClusterWorker { // Cluster connections topology_reader: TopologyReader, - connection_config: ConnectionConfig, + pool_config: PoolConfig, // To listen for refresh requests refresh_channel: tokio::sync::mpsc::Receiver, @@ -79,7 +80,7 @@ struct UseKeyspaceRequest { impl Cluster { pub async fn new( initial_peers: &[SocketAddr], - connection_config: ConnectionConfig, + pool_config: PoolConfig, ) -> Result { let cluster_data = Arc::new(ArcSwap::from(Arc::new(ClusterData { known_peers: HashMap::new(), @@ -98,10 +99,10 @@ impl Cluster { topology_reader: TopologyReader::new( initial_peers, - connection_config.clone(), + pool_config.connection_config.clone(), server_events_sender, ), - connection_config, + pool_config, refresh_channel: refresh_receiver, server_events_channel: server_events_receiver, @@ -173,26 +174,10 @@ impl Cluster { let mut last_error: Option = None; - // Takes result of ConnectionKeeper::get_connection() and pushes it onto result list or sets last_error - let mut push_to_result = |get_conn_res: Result, QueryError>| { - match get_conn_res { - Ok(conn) => result.push(conn), - Err(e) => last_error = Some(e), - }; - }; - for node in peers.values() { - let connections: Arc = node.connections.read().unwrap().clone(); - - match &*connections { - NodeConnections::Single(conn_keeper) => { - push_to_result(conn_keeper.get_connection().await) - } - NodeConnections::Sharded { shard_conns, .. 
} => { - for conn_keeper in shard_conns { - push_to_result(conn_keeper.get_connection().await); - } - } + match node.get_working_connections() { + Ok(conns) => result.extend(conns), + Err(e) => last_error = Some(e), } } @@ -226,11 +211,17 @@ impl ClusterData { } } + pub async fn wait_until_all_pools_are_initialized(&self) { + for node in self.all_nodes.iter() { + node.wait_until_pool_initialized().await; + } + } + /// Creates new ClusterData using information about topology held in `info`. /// Uses provided `known_peers` hashmap to recycle nodes if possible. pub fn new( info: TopologyInfo, - connection_config: &ConnectionConfig, + pool_config: &PoolConfig, known_peers: &HashMap>, used_keyspace: &Option, ) -> Self { @@ -251,7 +242,7 @@ impl ClusterData { } _ => Arc::new(Node::new( peer.address, - connection_config.clone(), + pool_config.clone(), peer.datacenter, peer.rack, used_keyspace.clone(), @@ -442,11 +433,15 @@ impl ClusterWorker { let new_cluster_data = Arc::new(ClusterData::new( topo_info, - &self.connection_config, + &self.pool_config, &cluster_data.known_peers, &self.used_keyspace, )); + new_cluster_data + .wait_until_all_pools_are_initialized() + .await; + self.update_cluster_data(new_cluster_data); Ok(()) diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index c337bafb5a..2264074442 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -56,7 +56,7 @@ pub struct Connection { source_port: u16, shard_info: Option, config: ConnectionConfig, - is_shard_aware: bool, + shard_aware_port: Option, } type ResponseHandler = oneshot::Sender>; @@ -184,6 +184,18 @@ impl Default for ConnectionConfig { } } +impl ConnectionConfig { + #[cfg(feature = "ssl")] + pub fn is_ssl(&self) -> bool { + self.ssl_context.is_some() + } + + #[cfg(not(feature = "ssl"))] + pub fn is_ssl(&self) -> bool { + false + } +} + // Used to listen for fatal error in connection pub type ErrorReceiver = tokio::sync::oneshot::Receiver; @@ -225,7 +237,7 @@ impl Connection { connect_address: addr, shard_info: None, config, - is_shard_aware: false, + shard_aware_port: None, }; Ok((connection, error_receiver)) @@ -772,7 +784,11 @@ impl Connection { /// Are we connected to Scylla's shard aware port? // TODO: couple this with shard_info? 
pub fn get_is_shard_aware(&self) -> bool { - self.is_shard_aware + Some(self.connect_address.port()) == self.shard_aware_port + } + + pub fn get_shard_aware_port(&self) -> Option { + self.shard_aware_port } pub fn get_source_port(&self) -> u16 { @@ -783,8 +799,8 @@ impl Connection { self.shard_info = shard_info } - fn set_is_shard_aware(&mut self, is_shard_aware: bool) { - self.is_shard_aware = is_shard_aware; + fn set_shard_aware_port(&mut self, shard_aware_port: Option) { + self.shard_aware_port = shard_aware_port; } pub fn get_connect_address(&self) -> SocketAddr { @@ -818,6 +834,11 @@ pub async fn open_named_connection( let options_result = connection.get_options().await?; + let shard_aware_port_key = match config.is_ssl() { + true => "SCYLLA_SHARD_AWARE_PORT_SSL", + false => "SCYLLA_SHARD_AWARE_PORT", + }; + let (shard_info, supported_compression, shard_aware_port) = match options_result { Response::Supported(mut supported) => { let shard_info = ShardInfo::try_from(&supported.options).ok(); @@ -827,7 +848,7 @@ pub async fn open_named_connection( .unwrap_or_else(Vec::new); let shard_aware_port = supported .options - .remove("SCYLLA_SHARD_AWARE_PORT") + .remove(shard_aware_port_key) .unwrap_or_else(Vec::new) .into_iter() .next() @@ -837,7 +858,7 @@ pub async fn open_named_connection( _ => (None, Vec::new(), None), }; connection.set_shard_info(shard_info); - connection.set_is_shard_aware(Some(addr.port()) == shard_aware_port); + connection.set_shard_aware_port(shard_aware_port); let mut options = HashMap::new(); options.insert("CQL_VERSION".to_string(), "4.0.0".to_string()); // FIXME: hardcoded values diff --git a/scylla/src/transport/connection_keeper.rs b/scylla/src/transport/connection_keeper.rs deleted file mode 100644 index 692eec9d3f..0000000000 --- a/scylla/src/transport/connection_keeper.rs +++ /dev/null @@ -1,345 +0,0 @@ -/// ConnectionKeeper keeps a Connection to some address and works to keep it open -use crate::routing::ShardInfo; -use crate::transport::errors::QueryError; -use crate::transport::{ - connection, - connection::{Connection, ConnectionConfig, ErrorReceiver, VerifiedKeyspaceName}, -}; - -use futures::{future::RemoteHandle, FutureExt}; -use std::io::ErrorKind; -use std::net::SocketAddr; -use std::sync::Arc; - -/// ConnectionKeeper keeps a Connection to some address and works to keep it open -pub struct ConnectionKeeper { - conn_state_receiver: tokio::sync::watch::Receiver, - use_keyspace_channel: tokio::sync::mpsc::Sender, - _worker_handle: RemoteHandle<()>, -} - -#[derive(Clone)] -pub enum ConnectionState { - Initializing, // First connect attempt ongoing - Connected(Arc), - Broken(QueryError), -} - -/// Works in the background to keep the connection open -struct ConnectionKeeperWorker { - address: SocketAddr, - config: ConnectionConfig, - shard_info: Option, - - shard_info_sender: Option, - conn_state_sender: tokio::sync::watch::Sender, - - // Channel used to receive use keyspace requests - use_keyspace_channel: tokio::sync::mpsc::Receiver, - - // Keyspace send in "USE " when opening each connection - used_keyspace: Option, -} - -pub type ShardInfoSender = Arc>>>; - -#[derive(Debug)] -struct UseKeyspaceRequest { - keyspace_name: VerifiedKeyspaceName, - response_chan: tokio::sync::oneshot::Sender>, -} - -impl ConnectionKeeper { - /// Creates new ConnectionKeeper that starts a connection in the background - /// # Arguments - /// - /// * `address` - IP address to connect to - /// * `compression` - preferred compression method to use - /// * `shard_info` - ShardInfo 
to use, will connect to shard number `shard_info.shard` - /// * `shard_info_sender` - channel to send new ShardInfo after each connection creation - pub fn new( - address: SocketAddr, - config: ConnectionConfig, - shard_info: Option, - shard_info_sender: Option, - keyspace_name: Option, - ) -> Self { - let (conn_state_sender, conn_state_receiver) = - tokio::sync::watch::channel(ConnectionState::Initializing); - - let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(1); - - let worker = ConnectionKeeperWorker { - address, - config, - shard_info, - shard_info_sender, - conn_state_sender, - use_keyspace_channel: use_keyspace_receiver, - used_keyspace: keyspace_name, - }; - - let (fut, worker_handle) = worker.work().remote_handle(); - tokio::spawn(fut); - - ConnectionKeeper { - conn_state_receiver, - use_keyspace_channel: use_keyspace_sender, - _worker_handle: worker_handle, - } - } - - /// Get current connection state, returns immediately - pub fn connection_state(&self) -> ConnectionState { - self.conn_state_receiver.borrow().clone() - } - - pub async fn wait_until_initialized(&self) { - match &*self.conn_state_receiver.borrow() { - ConnectionState::Initializing => {} - _ => return, - }; - - let mut my_receiver = self.conn_state_receiver.clone(); - - my_receiver - .changed() - .await - .expect("Bug in ConnectionKeeper::wait_until_initialized"); - // Worker can't stop while we have &self to struct with worker_handle - - // Now state must be != Initializing - debug_assert!(!matches!( - &*self.conn_state_receiver.borrow(), - ConnectionState::Initializing - )); - } - - /// Wait for the connection to initialize and get it if succesfylly connected - pub async fn get_connection(&self) -> Result, QueryError> { - self.wait_until_initialized().await; - - match self.connection_state() { - ConnectionState::Connected(conn) => Ok(conn), - ConnectionState::Broken(e) => Err(e), - _ => unreachable!(), - } - } - - pub async fn use_keyspace( - &self, - keyspace_name: VerifiedKeyspaceName, - ) -> Result<(), QueryError> { - let (response_sender, response_receiver) = tokio::sync::oneshot::channel(); - - self.use_keyspace_channel - .send(UseKeyspaceRequest { - keyspace_name, - response_chan: response_sender, - }) - .await - .expect("Bug in ConnectionKeeper::use_keyspace sending"); - // Other end of this channel is in the Worker, can't be dropped while we have &self to _worker_handle - - response_receiver.await.unwrap() // ConnectionKeeperWorker always responds - } -} - -enum RunConnectionRes { - // An error occured during the connection - Error(QueryError), - // ConnectionKeeper was dropped and channels closed, we should stop - ShouldStop, -} - -impl ConnectionKeeperWorker { - pub async fn work(mut self) { - // Reconnect at most every 8 seconds - let reconnect_cooldown = tokio::time::Duration::from_secs(8); - let mut last_reconnect_time; - - loop { - last_reconnect_time = tokio::time::Instant::now(); - - // Connect and wait for error - let current_error: QueryError = match self.run_connection().await { - RunConnectionRes::Error(e) => e, - RunConnectionRes::ShouldStop => return, - }; - - // Mark the connection as broken, wait cooldown and reconnect - if self - .conn_state_sender - .send(ConnectionState::Broken(current_error)) - .is_err() - { - // ConnectionKeeper was dropped, we should stop - return; - } - - let next_reconnect_time = last_reconnect_time - .checked_add(reconnect_cooldown) - .unwrap_or_else(tokio::time::Instant::now); - - tokio::time::sleep_until(next_reconnect_time).await; - 
} - } - - // Opens a new connection and waits until some fatal error occurs - async fn run_connection(&mut self) -> RunConnectionRes { - // Connect to the node - let (connection, mut error_receiver) = match self.open_new_connection().await { - Ok(opened) => opened, - Err(e) => return RunConnectionRes::Error(e), - }; - - // Mark connection as Connected - if self - .conn_state_sender - .send(ConnectionState::Connected(connection.clone())) - .is_err() - { - // If the channel was dropped we should stop - return RunConnectionRes::ShouldStop; - } - - // Notify about new shard info - if let Some(sender) = &self.shard_info_sender { - let new_shard_info: Option = connection.get_shard_info().clone(); - - // Ignore sending error - // If no one wants to get shard_info that's OK - // If lock is poisoned do nothing - if let Ok(sender_locked) = sender.lock() { - let _ = sender_locked.send(new_shard_info); - } - } - - // Use the specified keyspace - if let Some(keyspace_name) = &self.used_keyspace { - let _ = connection.use_keyspace(keyspace_name).await; - // Ignore the error, used_keyspace could be set a long time ago and then deleted - // user gets all errors from session.use_keyspace() - } - - let connection_closed_error = QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "Connection closed", - ))); - - // Wait for events - a use keyspace request or a fatal error - loop { - tokio::select! { - recv_res = self.use_keyspace_channel.recv() => { - match recv_res { - Some(request) => { - self.used_keyspace = Some(request.keyspace_name.clone()); - - // Send USE KEYSPACE request, send result if channel wasn't closed - let res = connection.use_keyspace(&request.keyspace_name).await; - let _ = request.response_chan.send(res); - }, - None => return RunConnectionRes::ShouldStop, // If the channel was dropped we should stop - } - }, - connection_error = &mut error_receiver => { - let error = connection_error.unwrap_or(connection_closed_error); - return RunConnectionRes::Error(error); - } - } - } - } - - async fn open_new_connection(&self) -> Result<(Arc, ErrorReceiver), QueryError> { - let (connection, error_receiver) = match &self.shard_info { - Some(info) => self.open_new_connection_to_shard(info).await?, - None => connection::open_connection(self.address, None, self.config.clone()).await?, - }; - - Ok((Arc::new(connection), error_receiver)) - } - - async fn open_new_connection_to_shard( - &self, - shard_info: &ShardInfo, - ) -> Result<(Connection, ErrorReceiver), QueryError> { - // Create iterator over all possible source ports for this shard - let source_port_iter = shard_info.iter_source_ports_for_shard(shard_info.shard.into()); - - for port in source_port_iter { - let connect_result = - connection::open_connection(self.address, Some(port), self.config.clone()).await; - - match connect_result { - Err(err) if err.is_address_unavailable_for_use() => continue, // If we can't use this port, try the next one - result => return result, - } - } - - // Tried all source ports for that shard, give up - Err(QueryError::IoError(Arc::new(std::io::Error::new( - std::io::ErrorKind::AddrInUse, - "Could not find free source port for shard", - )))) - } -} - -#[cfg(test)] -mod tests { - use super::ConnectionKeeper; - use crate::transport::connection::ConnectionConfig; - use std::net::{SocketAddr, ToSocketAddrs}; - - // Open many connections to a node - // Port collision should occur - // If they are not handled this test will most likely fail - #[tokio::test] - async fn many_connections() { - let 
connections_number = 512; - - let connect_address: SocketAddr = std::env::var("SCYLLA_URI") - .unwrap_or_else(|_| "127.0.0.1:9042".to_string()) - .to_socket_addrs() - .unwrap() - .next() - .unwrap(); - - let connection_config = ConnectionConfig { - compression: None, - tcp_nodelay: true, - #[cfg(feature = "ssl")] - ssl_context: None, - ..Default::default() - }; - - // Get shard info from a single connection, all connections will open to this shard - let conn_keeper = - ConnectionKeeper::new(connect_address, connection_config.clone(), None, None, None); - let shard_info = conn_keeper - .get_connection() - .await - .unwrap() - .get_shard_info() - .clone(); - - // Open the connections - let mut conn_keepers: Vec = Vec::new(); - - for _ in 0..connections_number { - let conn_keeper = ConnectionKeeper::new( - connect_address, - connection_config.clone(), - shard_info.clone(), - None, - None, - ); - - conn_keepers.push(conn_keeper); - } - - // Check that each connection keeper connected succesfully - for conn_keeper in conn_keepers { - conn_keeper.get_connection().await.unwrap(); - } - } -} diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs new file mode 100644 index 0000000000..88a0ca42ff --- /dev/null +++ b/scylla/src/transport/connection_pool.rs @@ -0,0 +1,1063 @@ +use crate::routing::{Shard, ShardCount, Sharder, Token}; +use crate::transport::errors::QueryError; +use crate::transport::{ + connection, + connection::{Connection, ConnectionConfig, ErrorReceiver, VerifiedKeyspaceName}, +}; + +use arc_swap::ArcSwap; +use futures::{future::RemoteHandle, stream::FuturesUnordered, Future, FutureExt, StreamExt}; +use rand::Rng; +use std::convert::TryInto; +use std::io::ErrorKind; +use std::net::{IpAddr, SocketAddr}; +use std::num::NonZeroUsize; +use std::pin::Pin; +use std::sync::{Arc, Weak}; +use std::time::Duration; +use tokio::sync::{mpsc, Notify}; +use tracing::{debug, trace, warn}; + +/// The target size of a per-node connection pool. +#[derive(Debug, Clone)] +pub enum PoolSize { + /// Indicates that the pool should establish given number of connections to the node. + /// + /// If this option is used with a Scylla cluster, it is not guaranteed that connections will be + /// distributed evenly across shards. Use this option if you cannot use the shard-aware port + /// and you suffer from the "connection storm" problems. + PerHost(NonZeroUsize), + + /// Indicates that the pool should establish given number of connections to each shard on the node. + /// + /// Cassandra nodes will be treated as if they have only one shard. + /// + /// The recommended setting for Scylla is one connection per shard - `PerShard(1)`. 
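+    ///
+    /// For illustration only (a sketch, not part of the documented API surface), a pool that
+    /// keeps two connections to every shard could be requested with:
+    ///
+    /// ```rust,ignore
+    /// use std::num::NonZeroUsize;
+    ///
+    /// let pool_size = PoolSize::PerShard(NonZeroUsize::new(2).unwrap());
+    /// ```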
+ PerShard(NonZeroUsize), +} + +impl Default for PoolSize { + fn default() -> Self { + PoolSize::PerShard(NonZeroUsize::new(1).unwrap()) + } +} + +#[derive(Clone)] +pub struct PoolConfig { + pub connection_config: ConnectionConfig, + pub pool_size: PoolSize, + pub can_use_shard_aware_port: bool, +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + connection_config: Default::default(), + pool_size: Default::default(), + can_use_shard_aware_port: true, + } + } +} + +enum MaybePoolConnections { + // The pool is being filled for the first time + Initializing, + + // The pool is empty because either initial filling failed or all connections + // became broken; will be asynchronously refilled + Broken, + + // The pool has some connections which are usable (or will be removed soon) + Ready(PoolConnections), +} + +#[derive(Clone)] +enum PoolConnections { + NotSharded(Vec>), + Sharded { + sharder: Sharder, + connections: Vec>>, + }, +} + +pub struct NodeConnectionPool { + conns: Arc>, + use_keyspace_request_sender: mpsc::Sender, + _refiller_handle: RemoteHandle<()>, + pool_updated_notify: Arc, +} + +impl NodeConnectionPool { + pub fn new( + address: IpAddr, + port: u16, + pool_config: PoolConfig, + current_keyspace: Option, + ) -> Self { + let (use_keyspace_request_sender, use_keyspace_request_receiver) = mpsc::channel(1); + let pool_updated_notify = Arc::new(Notify::new()); + + let refiller = PoolRefiller::new( + address, + port, + pool_config, + current_keyspace, + pool_updated_notify.clone(), + ); + + let conns = refiller.get_shared_connections(); + let (fut, handle) = refiller.run(use_keyspace_request_receiver).remote_handle(); + tokio::spawn(fut); + + Self { + conns, + use_keyspace_request_sender, + _refiller_handle: handle, + pool_updated_notify, + } + } + + pub fn connection_for_token(&self, token: Token) -> Result, QueryError> { + self.with_connections(|pool_conns| match pool_conns { + PoolConnections::NotSharded(conns) => { + Self::choose_random_connection_from_slice(conns).unwrap() + } + PoolConnections::Sharded { + sharder, + connections, + } => { + let shard: u16 = sharder + .shard_of(token) + .try_into() + .expect("Shard number doesn't fit in u16"); + Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + } + }) + } + + pub fn random_connection(&self) -> Result, QueryError> { + self.with_connections(|pool_conns| match pool_conns { + PoolConnections::NotSharded(conns) => { + Self::choose_random_connection_from_slice(conns).unwrap() + } + PoolConnections::Sharded { + sharder, + connections, + } => { + let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get()); + Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + } + }) + } + + // Tries to get a connection to given shard, if it's broken returns any working connection + fn connection_for_shard( + shard: u16, + nr_shards: ShardCount, + shard_conns: &[Vec>], + ) -> Arc { + // Try getting the desired connection + if let Some(conn) = Self::choose_random_connection_from_slice(&shard_conns[shard as usize]) + { + return conn; + } + + // If this fails try getting any other in random order + let mut shards_to_try: Vec = (0..shard).chain(shard + 1..nr_shards.get()).collect(); + + while !shards_to_try.is_empty() { + let idx = rand::thread_rng().gen_range(0..shards_to_try.len()); + let shard = shards_to_try.swap_remove(idx); + + if let Some(conn) = + Self::choose_random_connection_from_slice(&shard_conns[shard as usize]) + { + return conn; + } + } + + 
unreachable!("could not find any connection in supposedly non-empty pool") + } + + pub async fn use_keyspace( + &self, + keyspace_name: VerifiedKeyspaceName, + ) -> Result<(), QueryError> { + let (response_sender, response_receiver) = tokio::sync::oneshot::channel(); + + self.use_keyspace_request_sender + .send(UseKeyspaceRequest { + keyspace_name, + response_sender, + }) + .await + .expect("Bug in ConnectionKeeper::use_keyspace sending"); + // Other end of this channel is in the Refiller, can't be dropped while we have &self to _refiller_handle + + response_receiver.await.unwrap() // NodePoolRefiller always responds + } + + // Waits until the pool becomes initialized. + // The pool is considered initialized either if the first connection has been + // established or after first filling ends, whichever comes first. + pub async fn wait_until_initialized(&self) { + // First, register for the notification + // so that we don't miss it + let notified = self.pool_updated_notify.notified(); + + if let MaybePoolConnections::Initializing = **self.conns.load() { + // If the pool is not initialized yet, wait until we get a notification + notified.await; + } + } + + pub fn get_working_connections(&self) -> Result>, QueryError> { + self.with_connections(|pool_conns| match pool_conns { + PoolConnections::NotSharded(conns) => conns.clone(), + PoolConnections::Sharded { connections, .. } => { + connections.iter().flatten().cloned().collect() + } + }) + } + + fn choose_random_connection_from_slice(v: &[Arc]) -> Option> { + if v.is_empty() { + None + } else if v.len() == 1 { + Some(v[0].clone()) + } else { + let idx = rand::thread_rng().gen_range(0..v.len()); + Some(v[idx].clone()) + } + } + + fn with_connections(&self, f: impl FnOnce(&PoolConnections) -> T) -> Result { + let conns = self.conns.load_full(); + match &*conns { + MaybePoolConnections::Ready(pool_connections) => Ok(f(pool_connections)), + _ => Err(QueryError::IoError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "No connections in the pool", + )))), + } + } +} + +const EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER: usize = 10; + +// TODO: Make it configurable through a policy (issue #184) +const MIN_FILL_BACKOFF: Duration = Duration::from_millis(50); +const MAX_FILL_BACKOFF: Duration = Duration::from_secs(10); +const FILL_BACKOFF_MULTIPLIER: u32 = 2; + +// A simple exponential strategy for pool fill backoffs. +struct RefillDelayStrategy { + current_delay: Duration, +} + +impl RefillDelayStrategy { + fn new() -> Self { + Self { + current_delay: MIN_FILL_BACKOFF, + } + } + + fn get_delay(&self) -> Duration { + self.current_delay + } + + fn on_successful_fill(&mut self) { + self.current_delay = MIN_FILL_BACKOFF; + } + + fn on_fill_error(&mut self) { + self.current_delay = std::cmp::min( + MAX_FILL_BACKOFF, + self.current_delay * FILL_BACKOFF_MULTIPLIER, + ); + } +} + +struct PoolRefiller { + // Following information identify the pool and do not change + address: IpAddr, + regular_port: u16, + pool_config: PoolConfig, + + // Following fields are updated with information from OPTIONS + shard_aware_port: Option, + sharder: Option, + + // `shared_conns` is updated only after `conns` change + shared_conns: Arc>, + conns: Vec>>, + + // Set to true if there was an error since the last refill, + // set to false when refilling starts. + had_error_since_last_refill: bool, + + refill_delay_strategy: RefillDelayStrategy, + + // Receives information about connections becoming ready, i.e. newly connected + // or after its keyspace was correctly set. 
+ // TODO: This should probably be a channel + ready_connections: + FuturesUnordered + Send + 'static>>>, + + // Receives information about breaking connections + connection_errors: + FuturesUnordered + Send + 'static>>>, + + // When connecting, Scylla always assigns the shard which handles the least + // number of connections. If there are some non-shard-aware clients + // connected to the same node, they might cause the shard distribution + // to be heavily biased and Scylla will be very reluctant to assign some shards. + // + // In order to combat this, if the pool is not full and we get a connection + // for a shard which was already filled, we keep those additional connections + // in order to affect how Scylla assigns shards. A similar method is used + // in Scylla's forks of the java and gocql drivers. + // + // The number of those connections is bounded by the number of shards multiplied + // by a constant factor, and are all closed when they exceed this number. + excess_connections: Vec>, + + current_keyspace: Option, + + // Signaled when the connection pool is updated + pool_updated_notify: Arc, +} + +#[derive(Debug)] +struct UseKeyspaceRequest { + keyspace_name: VerifiedKeyspaceName, + response_sender: tokio::sync::oneshot::Sender>, +} + +impl PoolRefiller { + pub fn new( + address: IpAddr, + port: u16, + pool_config: PoolConfig, + current_keyspace: Option, + pool_updated_notify: Arc, + ) -> Self { + // At the beginning, we assume the node does not have any shards + // and assume that the node is a Cassandra node + let conns = vec![Vec::new()]; + let shared_conns = Arc::new(ArcSwap::new(Arc::new(MaybePoolConnections::Initializing))); + + Self { + address, + regular_port: port, + pool_config, + + shard_aware_port: None, + sharder: None, + + shared_conns, + conns, + + had_error_since_last_refill: false, + refill_delay_strategy: RefillDelayStrategy::new(), + + ready_connections: FuturesUnordered::new(), + connection_errors: FuturesUnordered::new(), + + excess_connections: Vec::new(), + + current_keyspace, + + pool_updated_notify, + } + } + + pub fn get_shared_connections(&self) -> Arc> { + self.shared_conns.clone() + } + + // The main loop of the pool refiller + pub async fn run( + mut self, + mut use_keyspace_request_receiver: mpsc::Receiver, + ) { + debug!("[{}] Started asynchronous pool worker", self.address); + + let mut next_refill_time = tokio::time::Instant::now(); + let mut refill_scheduled = true; + + loop { + tokio::select! { + _ = tokio::time::sleep_until(next_refill_time), if refill_scheduled => { + self.had_error_since_last_refill = false; + self.start_filling(); + refill_scheduled = false; + } + + evt = self.ready_connections.select_next_some(), if !self.ready_connections.is_empty() => { + self.handle_ready_connection(evt); + + if self.is_full() { + debug!( + "[{}] Pool is full, clearing {} excess connections", + self.address, + self.excess_connections.len() + ); + self.excess_connections.clear(); + } + } + + evt = self.connection_errors.select_next_some(), if !self.connection_errors.is_empty() => { + if let Some(conn) = evt.connection.upgrade() { + debug!("[{}] Got error for connection {:p}: {:?}", self.address, Arc::as_ptr(&conn), evt.error); + self.remove_connection(conn); + } + } + + req = use_keyspace_request_receiver.recv() => { + if let Some(req) = req { + debug!("[{}] Requested keyspace change: {}", self.address, req.keyspace_name.as_str()); + self.use_keyspace(&req.keyspace_name, req.response_sender); + } else { + // The keyspace request channel is dropped. 
+ // This means that the corresponding pool is dropped. + // We can stop here. + trace!("[{}] Keyspace request channel dropped, stopping asynchronous pool worker", self.address); + return; + } + } + } + + // Schedule refilling here + if !refill_scheduled && self.need_filling() { + // Update shared_conns here even if there are no connections. + // This will signal the waiters in `wait_until_initialized`. + self.update_shared_conns(); + if self.had_error_since_last_refill { + self.refill_delay_strategy.on_fill_error(); + } else { + self.refill_delay_strategy.on_successful_fill(); + } + let delay = self.refill_delay_strategy.get_delay(); + debug!( + "[{}] Scheduling next refill in {} ms", + self.address, + delay.as_millis(), + ); + + next_refill_time = tokio::time::Instant::now() + delay; + refill_scheduled = true; + } + } + } + + fn is_filling(&self) -> bool { + !self.ready_connections.is_empty() + } + + fn is_full(&self) -> bool { + match self.pool_config.pool_size { + PoolSize::PerHost(target) => self.active_connection_count() >= target.get(), + PoolSize::PerShard(target) => { + self.conns.iter().all(|conns| conns.len() >= target.get()) + } + } + } + + fn is_empty(&self) -> bool { + self.conns.iter().all(|conns| conns.is_empty()) + } + + fn need_filling(&self) -> bool { + !self.is_filling() && !self.is_full() + } + + fn can_use_shard_aware_port(&self) -> bool { + self.sharder.is_some() + && self.shard_aware_port.is_some() + && self.pool_config.can_use_shard_aware_port + } + + // Begins opening a number of connections in order to fill the connection pool. + // Futures which open the connections are pushed to the `ready_connections` + // FuturesUnordered structure, and their results are processed in the main loop. + fn start_filling(&mut self) { + if self.is_empty() { + // If the pool is empty, it might mean that the node is not alive. + // It is more likely than not that the next connection attempt will + // fail, so there is no use in opening more than one connection now. + trace!( + "[{}] Will open the first connection to the node", + self.address + ); + self.start_opening_connection(None); + return; + } + + if self.can_use_shard_aware_port() { + // Only use the shard-aware port if we have a PerShard strategy + if let PoolSize::PerShard(target) = self.pool_config.pool_size { + // Try to fill up each shard up to `target` connections + for (shard_id, shard_conns) in self.conns.iter().enumerate() { + let to_open_count = target.get().saturating_sub(shard_conns.len()); + if to_open_count == 0 { + continue; + } + trace!( + "[{}] Will open {} connections to shard {}", + self.address, + to_open_count, + shard_id, + ); + for _ in 0..to_open_count { + self.start_opening_connection(Some(shard_id as Shard)); + } + } + return; + } + } + // Calculate how many more connections we need to open in order + // to achieve the target connection count. + let to_open_count = match self.pool_config.pool_size { + PoolSize::PerHost(target) => { + target.get().saturating_sub(self.active_connection_count()) + } + PoolSize::PerShard(target) => self + .conns + .iter() + .map(|conns| target.get().saturating_sub(conns.len())) + .sum::(), + }; + // When connecting to Scylla through non-shard-aware port, + // Scylla alone will choose shards for us. We hope that + // they will distribute across shards in the way we want, + // but we have no guarantee, so we might have to retry + // connecting later. 
+ trace!( + "[{}] Will open {} non-shard-aware connections", + self.address, + to_open_count, + ); + for _ in 0..to_open_count { + self.start_opening_connection(None); + } + } + + // Handles a newly opened connection and decides what to do with it. + fn handle_ready_connection(&mut self, evt: OpenedConnectionEvent) { + match evt.result { + Err(err) => { + if evt.requested_shard.is_some() { + // If we failed to connect to a shard-aware port, + // fall back to the non-shard-aware port. + // Don't set `had_error_since_last_refill` here; + // the shard-aware port might be unreachable, but + // the regular port might be reachable. If we set + // `had_error_since_last_refill` here, it would cause + // the backoff to increase on each refill. With + // the non-shard aware port, multiple refills are sometimes + // necessary, so increasing the backoff would delay + // filling the pool even if the non-shard-aware port works + // and does not cause any errors. + debug!( + "[{}] Failed to open connection to the shard-aware port: {:?}, will retry with regular port", + self.address, + err, + ); + self.start_opening_connection(None); + } else { + // Encountered an error while connecting to the non-shard-aware + // port. Set the `had_error_since_last_refill` flag so that + // the next refill will be delayed more than this one. + self.had_error_since_last_refill = true; + debug!( + "[{}] Failed to open connection to the non-shard-aware port: {:?}", + self.address, err, + ); + } + } + Ok((connection, error_receiver)) => { + // Update sharding and optionally reshard + let shard_info = connection.get_shard_info().as_ref(); + let sharder = shard_info.map(|s| s.get_sharder()); + let shard_id = shard_info.map_or(0, |s| s.shard as usize); + self.maybe_reshard(sharder); + + // Update the shard-aware port + if self.shard_aware_port != connection.get_shard_aware_port() { + debug!( + "[{}] Updating shard aware port: {:?}", + self.address, + connection.get_shard_aware_port(), + ); + self.shard_aware_port = connection.get_shard_aware_port(); + } + + // Before the connection can be put to the pool, we need + // to make sure that it uses appropriate keyspace + if let Some(keyspace) = &self.current_keyspace { + if evt.keyspace_name.as_ref() != Some(keyspace) { + // Asynchronously start setting keyspace for this + // connection. It will be received on the ready + // connections channel and will travel through + // this logic again, to be finally put into + // the conns. + self.start_setting_keyspace_for_connection( + connection, + error_receiver, + evt.requested_shard, + ); + return; + } + } + + // Decide if the connection can be accepted, according to + // the pool filling strategy + let can_be_accepted = match self.pool_config.pool_size { + PoolSize::PerHost(target) => self.active_connection_count() < target.get(), + PoolSize::PerShard(target) => self.conns[shard_id].len() < target.get(), + }; + + if can_be_accepted { + // Don't complain and just put the connection to the pool. + // If this was a shard-aware port connection which missed + // the right shard, we still want to accept it + // because it fills our pool. 
+ let conn = Arc::new(connection); + trace!( + "[{}] Adding connection {:p} to shard {} pool, now there are {} for the shard, total {}", + self.address, + Arc::as_ptr(&conn), + shard_id, + self.conns[shard_id].len() + 1, + self.active_connection_count() + 1, + ); + + self.connection_errors + .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed()); + self.conns[shard_id].push(conn); + + self.update_shared_conns(); + } else if evt.requested_shard.is_some() { + // This indicates that some shard-aware connections + // missed the target shard (probably due to NAT). + // Because we don't know how address translation + // works here, it's better to leave the task + // of choosing the shard to Scylla. We will retry + // immediately with a non-shard-aware port here. + debug!( + "[{}] Excess shard-aware port connection for shard {}; will retry with non-shard-aware port", + self.address, + shard_id, + ); + + self.start_opening_connection(None); + } else { + // We got unlucky and Scylla didn't distribute + // shards across connections evenly. + // We will retry in the next iteration, + // for now put it into the excess connection + // pool. + let conn = Arc::new(connection); + trace!( + "[{}] Storing excess connection {:p} for shard {}", + self.address, + Arc::as_ptr(&conn), + shard_id, + ); + + self.connection_errors + .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed()); + self.excess_connections.push(conn); + + let excess_connection_limit = self.excess_connection_limit(); + if self.excess_connections.len() > excess_connection_limit { + debug!( + "[{}] Excess connection pool exceeded limit of {} connections - clearing", + self.address, + excess_connection_limit, + ); + self.excess_connections.clear(); + } + } + } + } + } + + // Starts opening a new connection in the background. The result of connecting + // will be available on `ready_connections`. If the shard is specified and + // the shard aware port is available, it will attempt to connect directly + // to the shard using the port. + fn start_opening_connection(&self, shard: Option) { + let cfg = self.pool_config.connection_config.clone(); + let fut = match (self.sharder.clone(), self.shard_aware_port, shard) { + (Some(sharder), Some(port), Some(shard)) => { + let shard_aware_address = (self.address, port).into(); + async move { + let result = open_connection_to_shard_aware_port( + shard_aware_address, + shard, + sharder.clone(), + &cfg, + ) + .await; + OpenedConnectionEvent { + result, + requested_shard: Some(shard), + keyspace_name: None, + } + } + .boxed() + } + _ => { + let non_shard_aware_address = (self.address, self.regular_port).into(); + async move { + let result = + connection::open_connection(non_shard_aware_address, None, cfg).await; + OpenedConnectionEvent { + result, + requested_shard: None, + keyspace_name: None, + } + } + .boxed() + } + }; + self.ready_connections.push(fut); + } + + fn maybe_reshard(&mut self, new_sharder: Option) { + if self.sharder == new_sharder { + return; + } + + debug!( + "[{}] New sharder: {:?}, clearing all connections", + self.address, new_sharder, + ); + + self.sharder = new_sharder.clone(); + + // If the sharder has changed, we can throw away all previous connections. + // All connections to the same live node will have the same sharder, + // so the old ones will become dead very soon anyway. 
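+        // Drop the old per-shard buckets; they are rebuilt below for the new shard count.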
+ self.conns.clear(); + + let shard_count = new_sharder.map_or(1, |s| s.nr_shards.get() as usize); + self.conns.resize_with(shard_count, Vec::new); + + self.excess_connections.clear(); + } + + // Updates `shared_conns` based on `conns`. + fn update_shared_conns(&mut self) { + let new_conns = if !self.has_connections() { + Arc::new(MaybePoolConnections::Broken) + } else { + let new_conns = if let Some(sharder) = self.sharder.as_ref() { + debug_assert_eq!(self.conns.len(), sharder.nr_shards.get() as usize); + PoolConnections::Sharded { + sharder: sharder.clone(), + connections: self.conns.clone(), + } + } else { + debug_assert_eq!(self.conns.len(), 1); + PoolConnections::NotSharded(self.conns[0].clone()) + }; + Arc::new(MaybePoolConnections::Ready(new_conns)) + }; + + // Make the connection list available + self.shared_conns.store(new_conns); + + // Notify potential waiters + self.pool_updated_notify.notify_waiters(); + } + + // Removes given connection from the pool. It looks both into active + // connections and excess connections. + fn remove_connection(&mut self, connection: Arc) { + let ptr = Arc::as_ptr(&connection); + + let maybe_remove_in_vec = |v: &mut Vec>| -> bool { + let maybe_idx = v + .iter() + .enumerate() + .find(|(_, other_conn)| Arc::ptr_eq(&connection, other_conn)) + .map(|(idx, _)| idx); + match maybe_idx { + Some(idx) => { + v.swap_remove(idx); + true + } + None => false, + } + }; + + // First, look it up in the shard bucket + // We might have resharded, so the bucket might not exist anymore + let shard_id = connection + .get_shard_info() + .as_ref() + .map_or(0, |s| s.shard as usize); + if shard_id < self.conns.len() && maybe_remove_in_vec(&mut self.conns[shard_id]) { + trace!( + "[{}] Connection {:p} removed from shard {} pool, now there is {} for the shard, total {}", + self.address, + ptr, + shard_id, + self.conns[shard_id].len(), + self.active_connection_count(), + ); + self.update_shared_conns(); + return; + } + + // If we didn't find it, it might sit in the excess_connections bucket + if maybe_remove_in_vec(&mut self.excess_connections) { + trace!( + "[{}] Connection {:p} removed from excess connection pool", + self.address, + ptr, + ); + return; + } + + trace!( + "[{}] Connection {:p} was already removed", + self.address, + ptr, + ); + } + + // Sets current keyspace for available connections. + // Connections which are being currently opened and future connections + // will have this keyspace set when they appear on `ready_connections`. + // Sends response to the `response_sender` when all current connections + // have their keyspace set. 
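+    // The USE KEYSPACE requests run in a spawned task so the refiller's main loop
+    // is not blocked while the connections switch keyspaces.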
+ fn use_keyspace( + &mut self, + keyspace_name: &VerifiedKeyspaceName, + response_sender: tokio::sync::oneshot::Sender>, + ) { + self.current_keyspace = Some(keyspace_name.clone()); + + let mut conns = self.conns.clone(); + let keyspace_name = keyspace_name.clone(); + let address = self.address; + + let fut = async move { + let mut use_keyspace_futures = Vec::new(); + + for shard_conns in conns.iter_mut() { + for conn in shard_conns.iter_mut() { + let fut = conn.use_keyspace(&keyspace_name); + use_keyspace_futures.push(fut); + } + } + + if use_keyspace_futures.is_empty() { + return Ok(()); + } + + let use_keyspace_results: Vec> = + futures::future::join_all(use_keyspace_futures).await; + + // If there was at least one Ok and the rest were IoErrors we can return Ok + // keyspace name is correct and will be used on broken connection on the next reconnect + + // If there were only IoErrors then return IoError + // If there was an error different than IoError return this error - something is wrong + + let mut was_ok: bool = false; + let mut io_error: Option> = None; + + for result in use_keyspace_results { + match result { + Ok(()) => was_ok = true, + Err(err) => match err { + QueryError::IoError(io_err) => io_error = Some(io_err), + _ => return Err(err), + }, + } + } + + if was_ok { + return Ok(()); + } + + // We can unwrap io_error because use_keyspace_futures must be nonempty + Err(QueryError::IoError(io_error.unwrap())) + }; + + tokio::task::spawn(async move { + let res = fut.await; + match &res { + Ok(()) => debug!("[{}] Successfully changed current keyspace", address), + Err(err) => warn!("[{}] Failed to change keyspace: {:?}", address, err), + } + let _ = response_sender.send(res); + }); + } + + // Requires the keyspace to be set + // Requires that the event is for a successful connection + fn start_setting_keyspace_for_connection( + &mut self, + connection: Connection, + error_receiver: ErrorReceiver, + requested_shard: Option, + ) { + // TODO: There should be a timeout for this + + let keyspace_name = self.current_keyspace.as_ref().cloned().unwrap(); + self.ready_connections.push( + async move { + let result = connection.use_keyspace(&keyspace_name).await; + if let Err(err) = result { + warn!( + "[{}] Failed to set keyspace for new connection: {}", + connection.get_connect_address().ip(), + err, + ); + } + OpenedConnectionEvent { + result: Ok((connection, error_receiver)), + requested_shard, + keyspace_name: Some(keyspace_name), + } + } + .boxed(), + ); + } + + fn has_connections(&self) -> bool { + self.conns.iter().any(|v| !v.is_empty()) + } + + fn active_connection_count(&self) -> usize { + self.conns.iter().map(Vec::len).sum::() + } + + fn excess_connection_limit(&self) -> usize { + match self.pool_config.pool_size { + PoolSize::PerShard(_) => { + EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER + * self + .sharder + .as_ref() + .map_or(1, |s| s.nr_shards.get() as usize) + } + + // In PerHost mode we do not need to keep excess connections + PoolSize::PerHost(_) => 0, + } + } +} + +struct BrokenConnectionEvent { + connection: Weak, + error: QueryError, +} + +async fn wait_for_error( + connection: Weak, + error_receiver: ErrorReceiver, +) -> BrokenConnectionEvent { + BrokenConnectionEvent { + connection, + error: error_receiver.await.unwrap_or_else(|_| { + QueryError::IoError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "Connection broken", + ))) + }), + } +} + +struct OpenedConnectionEvent { + result: Result<(Connection, ErrorReceiver), QueryError>, + requested_shard: Option, + 
keyspace_name: Option, +} + +async fn open_connection_to_shard_aware_port( + address: SocketAddr, + shard: Shard, + sharder: Sharder, + connection_config: &ConnectionConfig, +) -> Result<(Connection, ErrorReceiver), QueryError> { + // Create iterator over all possible source ports for this shard + let source_port_iter = sharder.iter_source_ports_for_shard(shard); + + for port in source_port_iter { + let connect_result = + connection::open_connection(address, Some(port), connection_config.clone()).await; + + match connect_result { + Err(err) if err.is_address_unavailable_for_use() => continue, // If we can't use this port, try the next one + result => return result, + } + } + + // Tried all source ports for that shard, give up + Err(QueryError::IoError(Arc::new(std::io::Error::new( + std::io::ErrorKind::AddrInUse, + "Could not find free source port for shard", + )))) +} + +#[cfg(test)] +mod tests { + use super::open_connection_to_shard_aware_port; + use crate::routing::{ShardCount, Sharder}; + use crate::transport::connection::ConnectionConfig; + use std::net::{SocketAddr, ToSocketAddrs}; + + // Open many connections to a node + // Port collision should occur + // If they are not handled this test will most likely fail + #[tokio::test] + async fn many_connections() { + let connections_number = 512; + + let connect_address: SocketAddr = std::env::var("SCYLLA_URI") + .unwrap_or_else(|_| "127.0.0.1:9042".to_string()) + .to_socket_addrs() + .unwrap() + .next() + .unwrap(); + + let connection_config = ConnectionConfig { + compression: None, + tcp_nodelay: true, + #[cfg(feature = "ssl")] + ssl_context: None, + ..Default::default() + }; + + // This does not have to be the real sharder, + // the test is only about port collisions, not connecting + // to the right shard + let sharder = Sharder::new(ShardCount::new(3).unwrap(), 12); + + // Open the connections + let mut conns = Vec::new(); + + for _ in 0..connections_number { + conns.push(open_connection_to_shard_aware_port( + connect_address, + 0, + sharder.clone(), + &connection_config, + )); + } + + let joined = futures::future::join_all(conns).await; + + // Check that each connection managed to connect successfully + for res in joined { + res.unwrap(); + } + } +} diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 0eff9c9aa5..a45f78fca8 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -1,6 +1,6 @@ mod cluster; pub(crate) mod connection; -mod connection_keeper; +mod connection_pool; pub mod load_balancing; mod node; pub mod retry_policy; diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs index acda5d9217..57c8cb9760 100644 --- a/scylla/src/transport/node.rs +++ b/scylla/src/transport/node.rs @@ -1,20 +1,16 @@ /// Node represents a cluster node along with it's data and connections -use crate::routing::{ShardInfo, Token}; +use crate::routing::Token; +use crate::transport::connection::Connection; use crate::transport::connection::VerifiedKeyspaceName; -use crate::transport::connection::{Connection, ConnectionConfig}; -use crate::transport::connection_keeper::{ConnectionKeeper, ShardInfoSender}; +use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig}; use crate::transport::errors::QueryError; -use futures::future::join_all; -use futures::{future::RemoteHandle, FutureExt}; -use rand::Rng; use std::{ - convert::TryInto, hash::{Hash, Hasher}, net::SocketAddr, sync::{ atomic::{AtomicBool, Ordering}, - Arc, RwLock, + Arc, }, }; @@ -24,40 +20,9 @@ pub struct Node { 
pub datacenter: Option, pub rack: Option, - pub connections: Arc>>, + pool: NodeConnectionPool, down_marker: AtomicBool, - - use_keyspace_channel: tokio::sync::mpsc::Sender, - - _worker_handle: RemoteHandle<()>, -} - -pub enum NodeConnections { - /// Non shard-aware ex. a Cassandra node connection - Single(ConnectionKeeper), - /// Shard aware Scylla node connections - Sharded { - shard_info: ShardInfo, - /// shard_conns always contains shard_info.nr_shards ConnectionKeepers - shard_conns: Vec, - }, -} - -// Works in the background to detect ShardInfo changes and keep node connections updated -struct NodeWorker { - node_conns: Arc>>, - node_addr: SocketAddr, - connection_config: ConnectionConfig, - - shard_info_sender: ShardInfoSender, - shard_info_receiver: tokio::sync::watch::Receiver>, - - // Channel used to receive use keyspace requests - use_keyspace_channel: tokio::sync::mpsc::Receiver, - - // Keyspace send in "USE " when opening each connection - used_keyspace: Option, } #[derive(Debug)] @@ -76,85 +41,32 @@ impl Node { /// `rack` - optional rack name pub fn new( address: SocketAddr, - connection_config: ConnectionConfig, + pool_config: PoolConfig, datacenter: Option, rack: Option, keyspace_name: Option, ) -> Self { - let (shard_info_sender, shard_info_receiver) = tokio::sync::watch::channel(None); - - let shard_info_sender = Arc::new(std::sync::Mutex::new(shard_info_sender)); - - let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32); - - let connections = Arc::new(RwLock::new(Arc::new(NodeConnections::Single( - ConnectionKeeper::new( - address, - connection_config.clone(), - None, - Some(shard_info_sender.clone()), - keyspace_name.clone(), - ), - )))); - - let worker = NodeWorker { - node_conns: connections.clone(), - node_addr: address, - connection_config, - shard_info_sender, - shard_info_receiver, - use_keyspace_channel: use_keyspace_receiver, - used_keyspace: keyspace_name, - }; - - let (fut, worker_handle) = worker.work().remote_handle(); - tokio::spawn(fut); + let pool = + NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name); Node { address, datacenter, rack, - connections, + pool, down_marker: false.into(), - use_keyspace_channel: use_keyspace_sender, - _worker_handle: worker_handle, } } /// Get connection which should be used to connect using given token /// If this connection is broken get any random connection to this Node pub async fn connection_for_token(&self, token: Token) -> Result, QueryError> { - let connections: Arc = self.connections.read().unwrap().clone(); - - match &*connections { - NodeConnections::Single(conn_keeper) => conn_keeper.get_connection().await, - NodeConnections::Sharded { - shard_info, - shard_conns, - } => { - let shard: u16 = shard_info - .shard_of(token) - .try_into() - .expect("Shard number doesn't fit in u16"); - Self::connection_for_shard(shard, shard_info.nr_shards, shard_conns).await - } - } + self.pool.connection_for_token(token) } /// Get random connection pub async fn random_connection(&self) -> Result, QueryError> { - let connections: Arc = self.connections.read().unwrap().clone(); - - match &*connections { - NodeConnections::Single(conn_keeper) => conn_keeper.get_connection().await, - NodeConnections::Sharded { - shard_info, - shard_conns, - } => { - let shard: u16 = rand::thread_rng().gen_range(0..shard_info.nr_shards); - Self::connection_for_shard(shard, shard_info.nr_shards, shard_conns).await - } - } + self.pool.random_connection() } pub fn is_down(&self) -> bool { @@ 
-165,50 +77,19 @@ impl Node { self.down_marker.store(is_down, Ordering::Relaxed); } - // Tries to get a connection to given shard, if it's broken returns any working connection - async fn connection_for_shard( - shard: u16, - nr_shards: u16, - shard_conns: &[ConnectionKeeper], - ) -> Result, QueryError> { - // Try getting the desired connection - let mut last_error: QueryError = match shard_conns[shard as usize].get_connection().await { - Ok(connection) => return Ok(connection), - Err(e) => e, - }; - - // If this fails try getting any other in random order - let mut shards_to_try: Vec = (shard..nr_shards).chain(0..shard).skip(1).collect(); - - while !shards_to_try.is_empty() { - let idx = rand::thread_rng().gen_range(0..shards_to_try.len()); - let shard = shards_to_try.swap_remove(idx); - - match shard_conns[shard as usize].get_connection().await { - Ok(conn) => return Ok(conn), - Err(e) => last_error = e, - } - } - - Err(last_error) - } - pub async fn use_keyspace( &self, keyspace_name: VerifiedKeyspaceName, ) -> Result<(), QueryError> { - let (response_sender, response_receiver) = tokio::sync::oneshot::channel(); + self.pool.use_keyspace(keyspace_name).await + } - self.use_keyspace_channel - .send(UseKeyspaceRequest { - keyspace_name, - response_chan: response_sender, - }) - .await - .expect("Bug in Node::use_keyspace sending"); - // Other end of this channel is in NodeWorker, can't be dropped while we have &self to Node with _worker_handle + pub fn get_working_connections(&self) -> Result>, QueryError> { + self.pool.get_working_connections() + } - response_receiver.await.unwrap() // NodeWorker always responds + pub async fn wait_until_pool_initialized(&self) { + self.pool.wait_until_initialized().await } } @@ -225,154 +106,3 @@ impl Hash for Node { self.address.hash(state); } } - -impl NodeWorker { - pub async fn work(mut self) { - let mut cur_shard_info: Option = self.shard_info_receiver.borrow().clone(); - - loop { - tokio::select! { - // Wait for current shard_info to change - changed_res = self.shard_info_receiver.changed() => { - // We own one sending end of this channel so it can't return None - changed_res.expect("Bug in NodeWorker::work") - // Then go to resharding update - }, - // Wait for a use_keyspace request - recv_res = self.use_keyspace_channel.recv() => { - match recv_res { - Some(request) => { - self.used_keyspace = Some(request.keyspace_name.clone()); - - let node_conns = self.node_conns.read().unwrap().clone(); - let use_keyspace_future = Self::handle_use_keyspace_request(node_conns, request); - tokio::spawn(use_keyspace_future); - }, - None => return, - } - - continue; // Don't go to resharding update, wait for the next event - }, - } - - let new_shard_info: Option = self.shard_info_receiver.borrow().clone(); - - // See if the node has resharded - match (&cur_shard_info, &new_shard_info) { - (Some(cur), Some(new)) => { - if cur.nr_shards == new.nr_shards && cur.msb_ignore == new.msb_ignore { - // Nothing chaged, go back to waiting for a change - continue; - } - } - (None, None) => continue, // Nothing chaged, go back to waiting for a change - _ => {} - } - - cur_shard_info = new_shard_info; - - // We received updated node ShardInfo - // Create new node connections. It will happen rarely so we can probably afford it - // TODO: Maybe save some connections instead of recreating? 
- let new_connections: NodeConnections = match &cur_shard_info { - None => NodeConnections::Single(ConnectionKeeper::new( - self.node_addr, - self.connection_config.clone(), - None, - Some(self.shard_info_sender.clone()), - self.used_keyspace.clone(), - )), - Some(shard_info) => { - let mut connections: Vec = - Vec::with_capacity(shard_info.nr_shards as usize); - - for shard in 0..shard_info.nr_shards { - let mut cur_conn_shard_info = shard_info.clone(); - cur_conn_shard_info.shard = shard; - let cur_conn = ConnectionKeeper::new( - self.node_addr, - self.connection_config.clone(), - Some(cur_conn_shard_info), - Some(self.shard_info_sender.clone()), - self.used_keyspace.clone(), - ); - - connections.push(cur_conn); - } - - NodeConnections::Sharded { - shard_info: shard_info.clone(), - shard_conns: connections, - } - } - }; - - let mut new_connections_to_swap = Arc::new(new_connections); - - // Update node.connections - // Use std::mem::swap to minimalize time spent holding write lock - let mut node_conns_lock = self.node_conns.write().unwrap(); - std::mem::swap(&mut *node_conns_lock, &mut new_connections_to_swap); - drop(node_conns_lock); - } - } - - async fn handle_use_keyspace_request( - node_conns: Arc, - request: UseKeyspaceRequest, - ) { - let result = Self::send_use_keyspace(node_conns, &request.keyspace_name).await; - - // Don't care if nobody wants request result - let _ = request.response_chan.send(result); - } - - async fn send_use_keyspace( - node_conns: Arc, - keyspace_name: &VerifiedKeyspaceName, - ) -> Result<(), QueryError> { - let mut use_keyspace_futures = Vec::new(); - - match &*node_conns { - NodeConnections::Single(conn_keeper) => { - let fut = conn_keeper.use_keyspace(keyspace_name.clone()); - use_keyspace_futures.push(fut); - } - NodeConnections::Sharded { shard_conns, .. 
} => {
-                for conn_keeper in shard_conns {
-                    let fut = conn_keeper.use_keyspace(keyspace_name.clone());
-                    use_keyspace_futures.push(fut);
-                }
-            }
-        }
-
-        let use_keyspace_results: Vec<Result<(), QueryError>> =
-            join_all(use_keyspace_futures).await;
-
-        // If there was at least one Ok and the rest were IoErrors we can return Ok
-        // keyspace name is correct and will be used on broken connection on the next reconnect
-
-        // If there were only IoErrors then return IoError
-        // If there was an error different than IoError return this error - something is wrong
-
-        let mut was_ok: bool = false;
-        let mut io_error: Option<Arc<std::io::Error>> = None;
-
-        for result in use_keyspace_results {
-            match result {
-                Ok(()) => was_ok = true,
-                Err(err) => match err {
-                    QueryError::IoError(io_err) => io_error = Some(io_err),
-                    _ => return Err(err),
-                },
-            }
-        }
-
-        if was_ok {
-            return Ok(());
-        }
-
-        // We can unwrap io_error because use_keyspace_futures must be nonempty
-        Err(QueryError::IoError(io_error.unwrap()))
-    }
-}
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index 0aa52dcc3d..8954f93fda 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -9,19 +9,20 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::net::lookup_host;
 use tokio::time::timeout;
-use tracing::{debug, info, warn};
+use tracing::debug;
 use uuid::Uuid;
 
 use super::connection::QueryResponse;
 use super::errors::{BadQuery, NewSessionError, QueryError};
 use crate::frame::response::cql_to_rust::FromRowError;
-use crate::frame::response::{result, Response};
+use crate::frame::response::result;
 use crate::frame::value::{BatchValues, SerializedValues, ValueList};
 use crate::prepared_statement::{PartitionKeyError, PreparedStatement};
 use crate::query::Query;
 use crate::routing::{murmur3_token, Token};
 use crate::statement::{Consistency, SerialConsistency};
 use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo};
+use crate::transport::connection_pool::PoolConfig;
 use crate::transport::{
     cluster::Cluster,
     connection::{BatchResult, Connection, ConnectionConfig, QueryResult, VerifiedKeyspaceName},
@@ -36,6 +37,8 @@ use crate::transport::{
 use crate::{batch::Batch, statement::StatementConfig};
 use crate::{cql_to_rust::FromRow, transport::speculative_execution};
 
+pub use crate::transport::connection_pool::PoolSize;
+
 #[cfg(feature = "ssl")]
 use openssl::ssl::SslContext;
 
@@ -83,6 +86,14 @@ pub struct SessionConfig {
     pub schema_agreement_interval: Duration,
 
     pub connect_timeout: std::time::Duration,
+
+    /// Size of the per-node connection pool, i.e. how many connections the driver should keep to each node.
+    /// The default is `PerShard(1)`, which is the recommended setting for Scylla clusters.
+    pub connection_pool_size: PoolSize,
+
+    /// If true, prevents the driver from connecting to the shard-aware port, even if the node supports it.
+    /// Generally, this option is best left as default (false).
+ pub disallow_shard_aware_port: bool, /* These configuration options will be added in the future: @@ -127,6 +138,8 @@ impl SessionConfig { auth_username: None, auth_password: None, connect_timeout: std::time::Duration::from_secs(5), + connection_pool_size: Default::default(), + disallow_shard_aware_port: false, } } @@ -188,6 +201,15 @@ impl SessionConfig { } } + /// Creates a PoolConfig which can be used to create NodeConnectionPools + fn get_pool_config(&self) -> PoolConfig { + PoolConfig { + connection_config: self.get_connection_config(), + pool_size: self.connection_pool_size.clone(), + can_use_shard_aware_port: !self.disallow_shard_aware_port, + } + } + /// Makes a config that should be used in Connection fn get_connection_config(&self) -> ConnectionConfig { ConnectionConfig { @@ -291,38 +313,7 @@ impl Session { node_addresses.extend(resolved); - let use_ssl = match () { - #[cfg(not(feature = "ssl"))] - () => false, - #[cfg(feature = "ssl")] - () => config.ssl_context.is_some(), - }; - - let mut shard_aware_addresses: Vec = vec![]; - if let Some(shard_aware_port) = - Self::get_shard_aware_port(node_addresses[0], config.get_connection_config(), use_ssl) - .await - { - info!("Shard-aware port detected: {}", shard_aware_port); - shard_aware_addresses = (&node_addresses) - .iter() - .map(|addr| SocketAddr::new(addr.ip(), shard_aware_port)) - .collect(); - } - - // Start the session - let cluster = if !shard_aware_addresses.is_empty() { - match Cluster::new(&shard_aware_addresses, config.get_connection_config()).await { - Ok(clust) => clust, - Err(e) => { - warn!("Unable to establish connections at detected shard-aware port, falling back to default ports: {}", e); - Cluster::new(&node_addresses, config.get_connection_config()).await? - } - } - } else { - info!("Shard-aware ports not available, falling back to default ports"); - Cluster::new(&node_addresses, config.get_connection_config()).await? - }; + let cluster = Cluster::new(&node_addresses, config.get_pool_config()).await?; let session = Session { cluster, @@ -342,29 +333,6 @@ impl Session { Ok(session) } - async fn get_shard_aware_port( - addr: SocketAddr, - config: ConnectionConfig, - use_ssl: bool, - ) -> Option { - let (probe, _) = Connection::new(addr, None, config).await.ok()?; - let options_result = probe.get_options().await.ok()?; - let options_key = if use_ssl { - "SCYLLA_SHARD_AWARE_PORT_SSL" - } else { - "SCYLLA_SHARD_AWARE_PORT" - }; - match options_result { - Response::Supported(mut supported) => supported - .options - .remove(options_key) - .unwrap_or_else(Vec::new) - .get(0) - .and_then(|p| p.parse::().ok()), - _ => None, - } - } - /// Sends a query to the database and receives a response. 
     /// Returns only a single page of results, to receive multiple pages use [query_iter](Session::query_iter)
     ///
diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs
index b3b176f344..71f0696642 100644
--- a/scylla/src/transport/session_builder.rs
+++ b/scylla/src/transport/session_builder.rs
@@ -5,7 +5,7 @@ use super::load_balancing::LoadBalancingPolicy;
 use super::session::{Session, SessionConfig};
 use super::speculative_execution::SpeculativeExecutionPolicy;
 use super::Compression;
-use crate::transport::retry_policy::RetryPolicy;
+use crate::transport::{connection_pool::PoolSize, retry_policy::RetryPolicy};
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::Duration;
@@ -380,6 +380,70 @@ impl SessionBuilder {
         self.config.connect_timeout = duration;
         self
     }
+
+    /// Sets the per-node connection pool size.
+    /// The default is one connection per shard, which is the recommended setting for Scylla.
+    ///
+    /// # Example
+    /// ```
+    /// # use scylla::{Session, SessionBuilder};
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// use std::num::NonZeroUsize;
+    /// use scylla::transport::session::PoolSize;
+    ///
+    /// // This session will establish 4 connections to each node.
+    /// // For Scylla clusters, this number will be divided across shards
+    /// let session: Session = SessionBuilder::new()
+    ///     .known_node("127.0.0.1:9042")
+    ///     .pool_size(PoolSize::PerHost(NonZeroUsize::new(4).unwrap()))
+    ///     .build()
+    ///     .await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn pool_size(mut self, size: PoolSize) -> Self {
+        self.config.connection_pool_size = size;
+        self
+    }
+
+    /// If true, prevents the driver from connecting to the shard-aware port, even if the node supports it.
+    ///
+    /// _This is a Scylla-specific option_. It has no effect on Cassandra clusters.
+    ///
+    /// By default, connecting to the shard-aware port is __allowed__ and, in general, this setting
+    /// _should not be changed_. The shard-aware port (19042 or 19142) makes the process of
+    /// establishing a connection per shard more robust compared to the regular transport port
+    /// (9042 or 9142). With the shard-aware port, the driver is able to choose which shard
+    /// will be assigned to the connection.
+    ///
+    /// In order to be able to use the shard-aware port effectively, the port needs to be
+    /// reachable and not behind a NAT which changes source ports (the driver uses the source port
+    /// to tell Scylla which shard to assign). However, the driver is designed to behave in a robust
+    /// way if those conditions are not met - if the driver fails to connect to the port or gets
+    /// a connection to the wrong shard, it will re-attempt the connection to the regular transport port.
+    ///
+    /// The only cost of a misconfigured shard-aware port should be a slightly longer reconnection time.
+    /// If it is unacceptable to you or you suspect that it causes you some other problems,
+    /// you can use this option to disable the shard-aware port feature completely.
+    /// However, __you should use it as a last resort__. Before you do that, we strongly recommend
+    /// that you consider fixing the network issues.
+    ///
+    /// # Example
+    /// ```
+    /// # use scylla::{Session, SessionBuilder};
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// let session: Session = SessionBuilder::new()
+    ///     .known_node("127.0.0.1:9042")
+    ///     .disallow_shard_aware_port(true)
+    ///     .build()
+    ///     .await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn disallow_shard_aware_port(mut self, disallow: bool) -> Self {
+        self.config.disallow_shard_aware_port = disallow;
+        self
+    }
 }
 
 /// Creates a [`SessionBuilder`] with default configuration, same as [`SessionBuilder::new`]
diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs
index 756f3eac1c..8de171383c 100644
--- a/scylla/src/transport/topology.rs
+++ b/scylla/src/transport/topology.rs
@@ -1,7 +1,7 @@
 use crate::frame::response::event::Event;
 use crate::routing::Token;
 use crate::transport::connection::{Connection, ConnectionConfig};
-use crate::transport::connection_keeper::ConnectionKeeper;
+use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig, PoolSize};
 use crate::transport::errors::QueryError;
 use crate::transport::session::IntoTypedRows;
 
@@ -9,6 +9,7 @@ use rand::seq::SliceRandom;
 use rand::thread_rng;
 use std::collections::HashMap;
 use std::net::{IpAddr, SocketAddr};
+use std::num::NonZeroUsize;
 use std::str::FromStr;
 use tokio::sync::mpsc;
 use tracing::{debug, error, warn};
@@ -17,7 +18,7 @@ use tracing::{debug, error, warn};
 pub struct TopologyReader {
     connection_config: ConnectionConfig,
     control_connection_address: SocketAddr,
-    control_connection: ConnectionKeeper,
+    control_connection: NodeConnectionPool,
 
     // when control connection fails, TopologyReader tries to connect to one of known_peers
     known_peers: Vec<SocketAddr>,
@@ -74,12 +75,9 @@ impl TopologyReader {
         // - send received events via server_event_sender
         connection_config.event_sender = Some(server_event_sender);
 
-        let control_connection = ConnectionKeeper::new(
+        let control_connection = Self::make_control_connection_pool(
             control_connection_address,
             connection_config.clone(),
-            None,
-            None,
-            None,
         );
 
         TopologyReader {
@@ -122,12 +120,9 @@
             );
 
             self.control_connection_address = *peer;
-            self.control_connection = ConnectionKeeper::new(
+            self.control_connection = Self::make_control_connection_pool(
                 self.control_connection_address,
                 self.connection_config.clone(),
-                None,
-                None,
-                None,
             );
 
             result = self.fetch_topology_info().await;
@@ -164,13 +159,32 @@
             .map(|peer| peer.address)
             .collect();
     }
+
+    fn make_control_connection_pool(
+        addr: SocketAddr,
+        connection_config: ConnectionConfig,
+    ) -> NodeConnectionPool {
+        let pool_config = PoolConfig {
+            connection_config,
+
+            // We want to have only one connection to receive events from
+            pool_size: PoolSize::PerHost(NonZeroUsize::new(1).unwrap()),
+
+            // The shard-aware port won't be used with PerHost pool size anyway,
+            // so explicitly disable it here
+            can_use_shard_aware_port: false,
+        };
+
+        NodeConnectionPool::new(addr.ip(), addr.port(), pool_config, None)
+    }
 }
 
 async fn query_topology_info(
-    conn_keeper: &ConnectionKeeper,
+    pool: &NodeConnectionPool,
     connect_port: u16,
 ) -> Result<TopologyInfo, QueryError> {
-    let conn: &Connection = &*conn_keeper.get_connection().await?;
+    pool.wait_until_initialized().await;
+    let conn: &Connection = &*pool.random_connection()?;
 
     let peers_query = query_peers(conn, connect_port);
     let keyspaces_query = query_keyspaces(conn);
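Editor's note: the sketch below shows how the two new `SessionBuilder` options introduced in this patch can be combined. It is illustrative only and not part of the diff; it assumes that `PoolSize::PerShard` takes a `NonZeroUsize`, mirroring the `PerHost` variant used in the `pool_size` doc example above.

```
// Illustrative only: exercises the new pool configuration surface from this patch.
// Assumption: PoolSize::PerShard(NonZeroUsize) exists, analogous to PerHost above.
use std::num::NonZeroUsize;

use scylla::transport::session::PoolSize;
use scylla::{Session, SessionBuilder};

async fn build_session() -> Result<(), Box<dyn std::error::Error>> {
    let session: Session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        // Keep two connections per shard instead of the default PerShard(1).
        .pool_size(PoolSize::PerShard(NonZeroUsize::new(2).unwrap()))
        // Leave the shard-aware port enabled (the default); disable only as a last resort.
        .disallow_shard_aware_port(false)
        .build()
        .await?;
    let _ = session;
    Ok(())
}
```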
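For reviewers, the retry loop in `open_connection_to_shard_aware_port` relies on one invariant: every source port yielded by `iter_source_ports_for_shard` must map back to the requested shard when Scylla computes `port % nr_shards`. Below is a small self-contained check of that invariant, again not part of the patch; it assumes the `Sharder` type from `routing.rs` is publicly reachable as `scylla::routing`, and the shard parameters are arbitrary.

```
// Illustrative check, not part of the patch.
use scylla::routing::{ShardCount, Sharder};

fn main() {
    // Arbitrary parameters; any nr_shards / msb_ignore pair should satisfy the invariant.
    let sharder = Sharder::new(ShardCount::new(3).unwrap(), 12);

    for shard in 0..3u32 {
        // Sample a few ports from the iterator used by open_connection_to_shard_aware_port.
        for port in sharder.iter_source_ports_for_shard(shard).take(16) {
            assert_eq!(sharder.shard_of_source_port(port), shard);
        }
    }
}
```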