Skip to content

Commit

Permalink
Merge pull request #261 from piodul/better-connection-pooling
Browse files Browse the repository at this point in the history
Better connection management
  • Loading branch information
psarna authored Nov 3, 2021
2 parents 6590c11 + 427a935 commit 15f8faf
Show file tree
Hide file tree
Showing 10 changed files with 1,294 additions and 764 deletions.
70 changes: 45 additions & 25 deletions scylla/src/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rand::Rng;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::num::Wrapping;
use std::num::{NonZeroU16, Wrapping};
use thiserror::Error;

#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
Expand All @@ -19,11 +19,18 @@ pub struct Token {
}

pub type Shard = u32;
pub type ShardCount = NonZeroU16;

#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ShardInfo {
pub shard: u16,
pub nr_shards: u16,
pub nr_shards: ShardCount,
pub msb_ignore: u8,
}

#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Sharder {
pub nr_shards: ShardCount,
pub msb_ignore: u8,
}

Expand All @@ -41,59 +48,69 @@ pub fn murmur3_token(pk: Bytes) -> Token {
}

impl ShardInfo {
pub fn new(shard: u16, nr_shards: u16, msb_ignore: u8) -> Self {
assert!(nr_shards > 0);
pub fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
ShardInfo {
shard,
nr_shards,
msb_ignore,
}
}

pub fn get_sharder(&self) -> Sharder {
Sharder::new(self.nr_shards, self.msb_ignore)
}
}

impl Sharder {
pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
Sharder {
nr_shards,
msb_ignore,
}
}

pub fn shard_of(&self, token: Token) -> Shard {
let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
biased_token <<= self.msb_ignore;
(((biased_token as u128) * (self.nr_shards as u128)) >> 64) as Shard
(((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
}

/// If we connect to Scylla using Scylla's shard aware port, then Scylla assigns a shard to the
/// connection based on the source port. This calculates the assigned shard.
pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
(source_port % self.nr_shards) as Shard
(source_port % self.nr_shards.get()) as Shard
}

/// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
assert!(shard < self.nr_shards as u32);
rand::thread_rng().gen_range((49152 + self.nr_shards - 1)..(65535 - self.nr_shards + 1))
/ self.nr_shards
* self.nr_shards
assert!(shard < self.nr_shards.get() as u32);
rand::thread_rng()
.gen_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1))
/ self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16
}

/// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
/// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
/// Stops once all possibile ports have been returned
pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
assert!(shard < self.nr_shards as u32);
assert!(shard < self.nr_shards.get() as u32);

// Randomly choose a port to start at
let starting_port = self.draw_source_port_for_shard(shard);

// Choose smallest available port number to begin at after wrapping
// apply the formula from draw_source_port_for_shard for lowest possible gen_range result
let first_valid_port =
(49152 + self.nr_shards - 1) / self.nr_shards * self.nr_shards + shard as u16;
let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16;

let before_wrap = (starting_port..=65535).step_by(self.nr_shards.into());
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.into());
let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into());
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());

before_wrap.chain(after_wrap)
}

pub fn get_nr_shards(&self) -> u16 {
self.nr_shards
}
}

#[derive(Error, Debug)]
Expand All @@ -102,6 +119,8 @@ pub enum ShardingError {
MissingShardInfoParameter,
#[error("ShardInfo parameters missing after unwraping")]
MissingUnwrapedShardInfoParameter,
#[error("ShardInfo contains an invalid number of shards (0)")]
ZeroShards,
#[error("ParseIntError encountered while getting ShardInfo")]
ParseIntError(#[from] std::num::ParseIntError),
}
Expand All @@ -123,6 +142,7 @@ impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
}
let shard = shard_entry.unwrap().first().unwrap().parse::<u16>()?;
let nr_shards = nr_shards_entry.unwrap().first().unwrap().parse::<u16>()?;
let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?;
let msb_ignore = msb_ignore_entry.unwrap().first().unwrap().parse::<u8>()?;
Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
}
Expand Down Expand Up @@ -223,22 +243,22 @@ fn fmix(mut k: Wrapping<i64>) -> Wrapping<i64> {

#[cfg(test)]
mod tests {
use super::ShardInfo;
use super::Token;
use super::{ShardCount, Sharder};
use std::collections::HashSet;

#[test]
fn test_shard_of() {
/* Test values taken from the gocql driver. */
let shard_info = ShardInfo::new(0, 4, 12);
let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
assert_eq!(
shard_info.shard_of(Token {
sharder.shard_of(Token {
value: -9219783007514621794
}),
3
);
assert_eq!(
shard_info.shard_of(Token {
sharder.shard_of(Token {
value: 9222582454147032830
}),
3
Expand All @@ -251,7 +271,7 @@ mod tests {
let max_port_num = 65535;
let min_port_num = (49152 + nr_shards - 1) / nr_shards * nr_shards;

let shard_info = ShardInfo::new(0, nr_shards, 12);
let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

// Test for each shard
for shard in 0..nr_shards {
Expand All @@ -265,7 +285,7 @@ mod tests {
let possible_ports_number: usize =
((max_port_num - lowest_port) / nr_shards + 1).into();

let port_iter = shard_info.iter_source_ports_for_shard(shard.into());
let port_iter = sharder.iter_source_ports_for_shard(shard.into());

let mut returned_ports: HashSet<u16> = HashSet::new();
for port in port_iter {
Expand Down
51 changes: 23 additions & 28 deletions scylla/src/transport/cluster.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::frame::response::event::{Event, StatusChangeEvent};
/// Cluster manages up to date information and connections to database nodes
use crate::routing::Token;
use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
use crate::transport::connection::{Connection, VerifiedKeyspaceName};
use crate::transport::connection_pool::PoolConfig;
use crate::transport::errors::QueryError;
use crate::transport::node::{Node, NodeConnections};
use crate::transport::node::Node;
use crate::transport::topology::{Keyspace, TopologyInfo, TopologyReader};

use arc_swap::ArcSwap;
Expand Down Expand Up @@ -50,7 +51,7 @@ struct ClusterWorker {

// Cluster connections
topology_reader: TopologyReader,
connection_config: ConnectionConfig,
pool_config: PoolConfig,

// To listen for refresh requests
refresh_channel: tokio::sync::mpsc::Receiver<RefreshRequest>,
Expand Down Expand Up @@ -79,7 +80,7 @@ struct UseKeyspaceRequest {
impl Cluster {
pub async fn new(
initial_peers: &[SocketAddr],
connection_config: ConnectionConfig,
pool_config: PoolConfig,
) -> Result<Cluster, QueryError> {
let cluster_data = Arc::new(ArcSwap::from(Arc::new(ClusterData {
known_peers: HashMap::new(),
Expand All @@ -98,10 +99,10 @@ impl Cluster {

topology_reader: TopologyReader::new(
initial_peers,
connection_config.clone(),
pool_config.connection_config.clone(),
server_events_sender,
),
connection_config,
pool_config,

refresh_channel: refresh_receiver,
server_events_channel: server_events_receiver,
Expand Down Expand Up @@ -173,26 +174,10 @@ impl Cluster {

let mut last_error: Option<QueryError> = None;

// Takes result of ConnectionKeeper::get_connection() and pushes it onto result list or sets last_error
let mut push_to_result = |get_conn_res: Result<Arc<Connection>, QueryError>| {
match get_conn_res {
Ok(conn) => result.push(conn),
Err(e) => last_error = Some(e),
};
};

for node in peers.values() {
let connections: Arc<NodeConnections> = node.connections.read().unwrap().clone();

match &*connections {
NodeConnections::Single(conn_keeper) => {
push_to_result(conn_keeper.get_connection().await)
}
NodeConnections::Sharded { shard_conns, .. } => {
for conn_keeper in shard_conns {
push_to_result(conn_keeper.get_connection().await);
}
}
match node.get_working_connections() {
Ok(conns) => result.extend(conns),
Err(e) => last_error = Some(e),
}
}

Expand Down Expand Up @@ -226,11 +211,17 @@ impl ClusterData {
}
}

pub async fn wait_until_all_pools_are_initialized(&self) {
for node in self.all_nodes.iter() {
node.wait_until_pool_initialized().await;
}
}

/// Creates new ClusterData using information about topology held in `info`.
/// Uses provided `known_peers` hashmap to recycle nodes if possible.
pub fn new(
info: TopologyInfo,
connection_config: &ConnectionConfig,
pool_config: &PoolConfig,
known_peers: &HashMap<SocketAddr, Arc<Node>>,
used_keyspace: &Option<VerifiedKeyspaceName>,
) -> Self {
Expand All @@ -251,7 +242,7 @@ impl ClusterData {
}
_ => Arc::new(Node::new(
peer.address,
connection_config.clone(),
pool_config.clone(),
peer.datacenter,
peer.rack,
used_keyspace.clone(),
Expand Down Expand Up @@ -442,11 +433,15 @@ impl ClusterWorker {

let new_cluster_data = Arc::new(ClusterData::new(
topo_info,
&self.connection_config,
&self.pool_config,
&cluster_data.known_peers,
&self.used_keyspace,
));

new_cluster_data
.wait_until_all_pools_are_initialized()
.await;

self.update_cluster_data(new_cluster_data);

Ok(())
Expand Down
35 changes: 28 additions & 7 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub struct Connection {
source_port: u16,
shard_info: Option<ShardInfo>,
config: ConnectionConfig,
is_shard_aware: bool,
shard_aware_port: Option<u16>,
}

type ResponseHandler = oneshot::Sender<Result<TaskResponse, QueryError>>;
Expand Down Expand Up @@ -184,6 +184,18 @@ impl Default for ConnectionConfig {
}
}

impl ConnectionConfig {
#[cfg(feature = "ssl")]
pub fn is_ssl(&self) -> bool {
self.ssl_context.is_some()
}

#[cfg(not(feature = "ssl"))]
pub fn is_ssl(&self) -> bool {
false
}
}

// Used to listen for fatal error in connection
pub type ErrorReceiver = tokio::sync::oneshot::Receiver<QueryError>;

Expand Down Expand Up @@ -225,7 +237,7 @@ impl Connection {
connect_address: addr,
shard_info: None,
config,
is_shard_aware: false,
shard_aware_port: None,
};

Ok((connection, error_receiver))
Expand Down Expand Up @@ -772,7 +784,11 @@ impl Connection {
/// Are we connected to Scylla's shard aware port?
// TODO: couple this with shard_info?
pub fn get_is_shard_aware(&self) -> bool {
self.is_shard_aware
Some(self.connect_address.port()) == self.shard_aware_port
}

pub fn get_shard_aware_port(&self) -> Option<u16> {
self.shard_aware_port
}

pub fn get_source_port(&self) -> u16 {
Expand All @@ -783,8 +799,8 @@ impl Connection {
self.shard_info = shard_info
}

fn set_is_shard_aware(&mut self, is_shard_aware: bool) {
self.is_shard_aware = is_shard_aware;
fn set_shard_aware_port(&mut self, shard_aware_port: Option<u16>) {
self.shard_aware_port = shard_aware_port;
}

pub fn get_connect_address(&self) -> SocketAddr {
Expand Down Expand Up @@ -818,6 +834,11 @@ pub async fn open_named_connection(

let options_result = connection.get_options().await?;

let shard_aware_port_key = match config.is_ssl() {
true => "SCYLLA_SHARD_AWARE_PORT_SSL",
false => "SCYLLA_SHARD_AWARE_PORT",
};

let (shard_info, supported_compression, shard_aware_port) = match options_result {
Response::Supported(mut supported) => {
let shard_info = ShardInfo::try_from(&supported.options).ok();
Expand All @@ -827,7 +848,7 @@ pub async fn open_named_connection(
.unwrap_or_else(Vec::new);
let shard_aware_port = supported
.options
.remove("SCYLLA_SHARD_AWARE_PORT")
.remove(shard_aware_port_key)
.unwrap_or_else(Vec::new)
.into_iter()
.next()
Expand All @@ -837,7 +858,7 @@ pub async fn open_named_connection(
_ => (None, Vec::new(), None),
};
connection.set_shard_info(shard_info);
connection.set_is_shard_aware(Some(addr.port()) == shard_aware_port);
connection.set_shard_aware_port(shard_aware_port);

let mut options = HashMap::new();
options.insert("CQL_VERSION".to_string(), "4.0.0".to_string()); // FIXME: hardcoded values
Expand Down
Loading

0 comments on commit 15f8faf

Please sign in to comment.