diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 7ff6fb51c..2313402bc 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -731,31 +731,7 @@ impl ClusterWorker { let use_keyspace_results: Vec> = join_all(use_keyspace_futures).await; - // If there was at least one Ok and the rest were IoErrors we can return Ok - // keyspace name is correct and will be used on broken connection on the next reconnect - - // If there were only IoErrors then return IoError - // If there was an error different than IoError return this error - something is wrong - - let mut was_ok: bool = false; - let mut io_error: Option> = None; - - for result in use_keyspace_results { - match result { - Ok(()) => was_ok = true, - Err(err) => match err { - QueryError::IoError(io_err) => io_error = Some(io_err), - _ => return Err(err), - }, - } - } - - if was_ok { - return Ok(()); - } - - // We can unwrap io_error because use_keyspace_futures must be nonempty - Err(QueryError::IoError(io_error.unwrap())) + use_keyspace_result(use_keyspace_results.into_iter()) } async fn perform_refresh(&mut self) -> Result<(), QueryError> { @@ -788,3 +764,39 @@ impl ClusterWorker { self.cluster_data.store(new_cluster_data); } } + +/// Returns a result of use_keyspace operation, based on the query results +/// returned from given node/connection. +/// +/// This function assumes that `use_keyspace_results` iterator is NON-EMPTY! +pub(crate) fn use_keyspace_result( + use_keyspace_results: impl Iterator>, +) -> Result<(), QueryError> { + // If there was at least one Ok and the rest were broken connection errors we can return Ok + // keyspace name is correct and will be used on broken connection on the next reconnect + + // If there were only broken connection errors then return broken connection error. + // If there was an error different than broken connection error return this error - something is wrong + + let mut was_ok: bool = false; + let mut broken_conn_error: Option = None; + + for result in use_keyspace_results { + match result { + Ok(()) => was_ok = true, + Err(err) => match err { + QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => { + broken_conn_error = Some(err) + } + _ => return Err(err), + }, + } + } + + if was_ok { + return Ok(()); + } + + // We can unwrap conn_broken_error because use_keyspace_results must be nonempty + Err(broken_conn_error.unwrap()) +} diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index e97fa85b7..9e7a6b575 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -1108,31 +1108,7 @@ impl PoolRefiller { .await .map_err(|_| QueryError::TimeoutError)?; - // If there was at least one Ok and the rest were IoErrors we can return Ok - // keyspace name is correct and will be used on broken connection on the next reconnect - - // If there were only IoErrors then return IoError - // If there was an error different than IoError return this error - something is wrong - - let mut was_ok: bool = false; - let mut io_error: Option> = None; - - for result in use_keyspace_results { - match result { - Ok(()) => was_ok = true, - Err(err) => match err { - QueryError::IoError(io_err) => io_error = Some(io_err), - _ => return Err(err), - }, - } - } - - if was_ok { - return Ok(()); - } - - // We can unwrap io_error because use_keyspace_futures must be nonempty - Err(QueryError::IoError(io_error.unwrap())) + super::cluster::use_keyspace_result(use_keyspace_results.into_iter()) }; tokio::task::spawn(async move { diff --git a/scylla/src/transport/downgrading_consistency_retry_policy.rs b/scylla/src/transport/downgrading_consistency_retry_policy.rs index 7f2f326f1..abe55caed 100644 --- a/scylla/src/transport/downgrading_consistency_retry_policy.rs +++ b/scylla/src/transport/downgrading_consistency_retry_policy.rs @@ -94,7 +94,8 @@ impl RetrySession for DowngradingConsistencyRetrySession { match query_info.error { // Basic errors - there are some problems on this node // Retry on a different one if possible - QueryError::IoError(_) + QueryError::BrokenConnection(_) + | QueryError::ConnectionPoolError(_) | QueryError::DbError(DbError::Overloaded, _) | QueryError::DbError(DbError::ServerError, _) | QueryError::DbError(DbError::TruncateError, _) => { @@ -181,12 +182,10 @@ impl RetrySession for DowngradingConsistencyRetrySession { #[cfg(test)] mod tests { - use std::{io::ErrorKind, sync::Arc}; - use bytes::Bytes; use crate::test_utils::setup_tracing; - use crate::transport::errors::BadQuery; + use crate::transport::errors::{BadQuery, BrokenConnectionErrorKind, ConnectionPoolError}; use super::*; @@ -328,7 +327,10 @@ mod tests { QueryError::DbError(DbError::Overloaded, String::new()), QueryError::DbError(DbError::TruncateError, String::new()), QueryError::DbError(DbError::ServerError, String::new()), - QueryError::IoError(Arc::new(std::io::Error::new(ErrorKind::Other, "test"))), + QueryError::BrokenConnection( + BrokenConnectionErrorKind::TooManyOrphanedStreamIds(5).into(), + ), + QueryError::ConnectionPoolError(ConnectionPoolError::Initializing), ]; for &cl in CONSISTENCY_LEVELS { diff --git a/scylla/src/transport/errors.rs b/scylla/src/transport/errors.rs index b24c03757..2ab1314d3 100644 --- a/scylla/src/transport/errors.rs +++ b/scylla/src/transport/errors.rs @@ -52,10 +52,6 @@ pub enum QueryError { #[error("Failed to deserialize ERROR response: {0}")] CqlErrorParseError(#[from] CqlErrorParseError), - /// Input/Output error has occurred, connection broken etc. - #[error("IO Error: {0}")] - IoError(Arc), - /// Selected node's connection pool is in invalid state. #[error("No connections in the pool: {0}")] ConnectionPoolError(#[from] ConnectionPoolError), @@ -154,7 +150,6 @@ impl From for NewSessionError { QueryError::BadQuery(e) => NewSessionError::BadQuery(e), QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e), QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e), - QueryError::IoError(e) => NewSessionError::IoError(e), QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e), QueryError::ProtocolError(m) => NewSessionError::ProtocolError(m), QueryError::InvalidMessage(m) => NewSessionError::InvalidMessage(m), @@ -207,10 +202,6 @@ pub enum NewSessionError { #[error("Failed to deserialize ERROR response: {0}")] CqlErrorParseError(#[from] CqlErrorParseError), - /// Input/Output error has occurred, connection broken etc. - #[error("IO Error: {0}")] - IoError(Arc), - /// Selected node's connection pool is in invalid state. #[error("No connections in the pool: {0}")] ConnectionPoolError(#[from] ConnectionPoolError), diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 8f717a0df..51cc9fe6a 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2855,7 +2855,6 @@ mod latency_awareness { | QueryError::CqlResultParseError(_) | QueryError::CqlErrorParseError(_) | QueryError::InvalidMessage(_) - | QueryError::IoError(_) | QueryError::ProtocolError(_) | QueryError::TimeoutError | QueryError::RequestTimeout(_) => true, diff --git a/scylla/src/transport/retry_policy.rs b/scylla/src/transport/retry_policy.rs index 75b3193c2..686f7643a 100644 --- a/scylla/src/transport/retry_policy.rs +++ b/scylla/src/transport/retry_policy.rs @@ -142,7 +142,8 @@ impl RetrySession for DefaultRetrySession { match query_info.error { // Basic errors - there are some problems on this node // Retry on a different one if possible - QueryError::IoError(_) + QueryError::BrokenConnection(_) + | QueryError::ConnectionPoolError(_) | QueryError::DbError(DbError::Overloaded, _) | QueryError::DbError(DbError::ServerError, _) | QueryError::DbError(DbError::TruncateError, _) => { @@ -221,11 +222,11 @@ mod tests { use super::{DefaultRetryPolicy, QueryInfo, RetryDecision, RetryPolicy}; use crate::statement::Consistency; use crate::test_utils::setup_tracing; - use crate::transport::errors::{BadQuery, QueryError}; + use crate::transport::errors::{ + BadQuery, BrokenConnectionErrorKind, ConnectionPoolError, QueryError, + }; use crate::transport::errors::{DbError, WriteType}; use bytes::Bytes; - use std::io::ErrorKind; - use std::sync::Arc; fn make_query_info(error: &QueryError, is_idempotent: bool) -> QueryInfo<'_> { QueryInfo { @@ -323,7 +324,10 @@ mod tests { QueryError::DbError(DbError::Overloaded, String::new()), QueryError::DbError(DbError::TruncateError, String::new()), QueryError::DbError(DbError::ServerError, String::new()), - QueryError::IoError(Arc::new(std::io::Error::new(ErrorKind::Other, "test"))), + QueryError::BrokenConnection( + BrokenConnectionErrorKind::TooManyOrphanedStreamIds(5).into(), + ), + QueryError::ConnectionPoolError(ConnectionPoolError::Initializing), ]; for error in idempotent_next_errors { diff --git a/scylla/src/transport/speculative_execution.rs b/scylla/src/transport/speculative_execution.rs index 7a09d860d..65389b5eb 100644 --- a/scylla/src/transport/speculative_execution.rs +++ b/scylla/src/transport/speculative_execution.rs @@ -85,7 +85,8 @@ impl SpeculativeExecutionPolicy for PercentileSpeculativeExecutionPolicy { fn can_be_ignored(result: &Result) -> bool { match result { Ok(_) => false, - Err(QueryError::IoError(_)) => true, + Err(QueryError::BrokenConnection(_)) => true, + Err(QueryError::ConnectionPoolError(_)) => true, Err(QueryError::TimeoutError) => true, _ => false, }