diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index fa96b2880c..26195009a3 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -1,13 +1,19 @@ //! This module contains various errors which can be returned by `scylla::Session` -use crate::frame::frame_errors::{CqlResponseParseError, FrameError, ParseError}; +use crate::frame::frame_errors::{ + CqlAuthChallengeParseError, CqlAuthSuccessParseError, CqlAuthenticateParseError, + CqlErrorParseError, CqlEventParseError, CqlResponseParseError, CqlResultParseError, + CqlSupportedParseError, FrameError, ParseError, +}; use crate::frame::protocol_features::ProtocolFeatures; use crate::frame::value::SerializeValuesError; use crate::types::deserialize::{DeserializationError, TypeCheckError}; use crate::types::serialize::SerializationError; use crate::Consistency; use bytes::Bytes; +use std::error::Error; use std::io::ErrorKind; +use std::net::IpAddr; use std::sync::Arc; use thiserror::Error; @@ -22,14 +28,22 @@ pub enum QueryError { #[error(transparent)] BadQuery(#[from] BadQuery), - /// Failed to deserialize a CQL response from the server. + /// Received a RESULT server response, but failed to deserialize it. #[error(transparent)] - CqlResponseParseError(#[from] CqlResponseParseError), + CqlResultParseError(#[from] CqlResultParseError), + + /// Received an ERROR server response, but failed to deserialize it. + #[error("Failed to deserialize ERROR response: {0}")] + CqlErrorParseError(#[from] CqlErrorParseError), /// Input/Output error has occurred, connection broken etc. #[error("IO Error: {0}")] IoError(Arc), + /// Selected node's connection pool is in invalid state. + #[error("No connections in the pool: {0}")] + ConnectionPoolError(#[from] ConnectionPoolError), + /// Unexpected message received #[error("Protocol Error: {0}")] ProtocolError(&'static str), @@ -45,16 +59,44 @@ pub enum QueryError { #[error("Too many orphaned stream ids: {0}")] TooManyOrphanedStreamIds(u16), + #[error(transparent)] + BrokenConnection(#[from] BrokenConnectionError), + #[error("Unable to allocate stream id")] UnableToAllocStreamId, /// Client timeout occurred before any response arrived #[error("Request timeout: {0}")] RequestTimeout(String), +} - /// Address translation failed - #[error("Address translation failed: {0}")] - TranslationError(#[from] TranslationError), +/// An error type that occurred when executing one of: +/// - QUERY +/// - PREPARE +/// - EXECUTE +/// - BATCH +/// +/// requests. +#[derive(Error, Debug)] +pub enum UserRequestError { + #[error("Database returned an error: {0}, Error message: {1}")] + DbError(DbError, String), + #[error(transparent)] + CqlResultParseError(#[from] CqlResultParseError), + #[error("Failed to deserialize ERROR response: {0}")] + CqlErrorParseError(#[from] CqlErrorParseError), + #[error( + "Received unexpected response from the server: {0}. Expected RESULT or ERROR response." + )] + UnexpectedResponse(CqlResponseKind), + #[error(transparent)] + BrokenConnectionError(#[from] BrokenConnectionError), + #[error(transparent)] + FrameError(#[from] FrameError), + #[error("Unable to allocate stream id")] + UnableToAllocStreamId, + #[error("Prepared statement Id changed, md5 sum should stay the same")] + RepreparedIdChanged, } /// An error sent from the database in response to a query @@ -338,6 +380,37 @@ pub enum WriteType { Other(String), } +/// Possible requests sent by the client. +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum CqlRequestKind { + Startup, + AuthResponse, + Options, + Query, + Prepare, + Execute, + Batch, + Register, +} + +impl std::fmt::Display for CqlRequestKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let kind_str = match self { + CqlRequestKind::Startup => "STARTUP", + CqlRequestKind::AuthResponse => "AUTH_RESPONSE", + CqlRequestKind::Options => "OPTIONS", + CqlRequestKind::Query => "QUERY", + CqlRequestKind::Prepare => "PREPARE", + CqlRequestKind::Execute => "EXECUTE", + CqlRequestKind::Batch => "BATCH", + CqlRequestKind::Register => "REGISTER", + }; + + f.write_str(kind_str) + } +} + /// Error caused by caller creating an invalid query #[derive(Error, Debug, Clone)] #[error("Invalid query passed to Session")] @@ -366,6 +439,37 @@ pub enum BadQuery { Other(String), } +/// Possible CQL responses received from the server +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum CqlResponseKind { + Error, + Ready, + Authenticate, + Supported, + Result, + Event, + AuthChallenge, + AuthSuccess, +} + +impl std::fmt::Display for CqlResponseKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let kind_str = match self { + CqlResponseKind::Error => "ERROR", + CqlResponseKind::Ready => "READY", + CqlResponseKind::Authenticate => "AUTHENTICATE", + CqlResponseKind::Supported => "SUPPORTED", + CqlResponseKind::Result => "RESULT", + CqlResponseKind::Event => "EVENT", + CqlResponseKind::AuthChallenge => "AUTH_CHALLENGE", + CqlResponseKind::AuthSuccess => "AUTH_SUCCESS", + }; + + f.write_str(kind_str) + } +} + /// Error that occurred during session creation #[derive(Error, Debug, Clone)] pub enum NewSessionError { @@ -386,14 +490,22 @@ pub enum NewSessionError { #[error(transparent)] BadQuery(#[from] BadQuery), - /// Failed to deserialize a CQL response from the server. + /// Received a RESULT server response, but failed to deserialize it. #[error(transparent)] - CqlResponseParseError(#[from] CqlResponseParseError), + CqlResultParseError(#[from] CqlResultParseError), + + /// Received an ERROR server response, but failed to deserialize it. + #[error("Failed to deserialize ERROR response: {0}")] + CqlErrorParseError(#[from] CqlErrorParseError), /// Input/Output error has occurred, connection broken etc. #[error("IO Error: {0}")] IoError(Arc), + /// Selected node's connection pool is in invalid state. + #[error("No connections in the pool: {0}")] + ConnectionPoolError(#[from] ConnectionPoolError), + /// Unexpected message received #[error("Protocol Error: {0}")] ProtocolError(&'static str), @@ -409,6 +521,9 @@ pub enum NewSessionError { #[error("Too many orphaned stream ids: {0}")] TooManyOrphanedStreamIds(u16), + #[error(transparent)] + BrokenConnection(#[from] BrokenConnectionError), + #[error("Unable to allocate stream id")] UnableToAllocStreamId, @@ -416,10 +531,6 @@ pub enum NewSessionError { /// during `Session` creation. #[error("Client timeout: {0}")] RequestTimeout(String), - - /// Address translation failed - #[error("Address translation failed: {0}")] - TranslationError(#[from] TranslationError), } /// Invalid keyspace name given to `Session::use_keyspace()` @@ -438,15 +549,246 @@ pub enum BadKeyspaceName { IllegalCharacter(String, char), } -impl std::fmt::Display for WriteType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) +// FIXME: this should be moved to scylla crate. +/// An error that appeared on a connection level. +/// It indicated that connection can no longer be used +/// and should be dropped. +#[derive(Error, Debug, Clone)] +#[non_exhaustive] +pub enum ConnectionError { + #[error("Connect timeout elapsed")] + ConnectTimeout, + #[error(transparent)] + IoError(Arc), + #[error("Could not find free source port for shard {0}")] + NoSourcePortForShard(u32), + #[error("Address translation failed: {0}")] + TranslationError(#[from] TranslationError), + #[error(transparent)] + BrokenConnection(#[from] BrokenConnectionError), + #[error(transparent)] + ConnectionSetupRequestError(#[from] ConnectionSetupRequestError), +} + +impl From for ConnectionError { + fn from(value: std::io::Error) -> Self { + ConnectionError::IoError(Arc::new(value)) + } +} + +impl ConnectionError { + /// Checks if this error indicates that a chosen source port/address cannot be bound. + /// This is caused by one of the following: + /// - The source address is already used by another socket, + /// - The source address is reserved and the process does not have sufficient privileges to use it. + pub fn is_address_unavailable_for_use(&self) -> bool { + if let ConnectionError::IoError(io_error) = self { + match io_error.kind() { + ErrorKind::AddrInUse | ErrorKind::PermissionDenied => return true, + _ => {} + } + } + + false + } +} + +/// An error that occurred during connection setup request execution. +/// It indicates that request needed to initiate a connection failed. +#[derive(Error, Debug, Clone)] +#[error("Failed to perform a connection setup request. Request: {request_kind}, reason: {error}")] +pub struct ConnectionSetupRequestError { + request_kind: CqlRequestKind, + error: ConnectionSetupRequestErrorKind, +} + +type AuthError = String; + +#[derive(Error, Debug, Clone)] +#[non_exhaustive] +pub enum ConnectionSetupRequestErrorKind { + // TODO: Make FrameError clonable. + #[error(transparent)] + FrameError(Arc), + #[error("Unable to allocate stream id")] + UnableToAllocStreamId, + #[error(transparent)] + BrokenConnection(#[from] BrokenConnectionError), + #[error("Database returned an error: {0}, Error message: {1}")] + DbError(DbError, String), + #[error("Received unexpected response from the server: {0}")] + UnexpectedResponse(CqlResponseKind), + #[error("Failed to deserialize SUPPORTED response: {0}")] + CqlSupportedParseError(#[from] CqlSupportedParseError), + #[error("Failed to deserialize AUTHENTICATE response: {0}")] + CqlAuthenticateParseError(#[from] CqlAuthenticateParseError), + #[error("Failed to deserialize AUTH_SUCCESS response: {0}")] + CqlAuthSuccessParseError(#[from] CqlAuthSuccessParseError), + #[error("Failed to deserialize AUTH_CHALLENGE response: {0}")] + CqlAuthChallengeParseError(#[from] CqlAuthChallengeParseError), + #[error("Failed to deserialize ERROR response: {0}")] + CqlErrorParseError(#[from] CqlErrorParseError), + #[error("Failed to start client's auth session: {0}")] + StartAuthSessionError(AuthError), + #[error("Failed to evaluate auth challenge on client side: {0}")] + AuthChallengeEvaluationError(AuthError), + #[error("Failed to finish auth challenge on client side: {0}")] + AuthFinishError(AuthError), + #[error("Authentication is required. You can use SessionBuilder::user(\"user\", \"pass\") to provide credentials or SessionBuilder::authenticator_provider to provide custom authenticator")] + MissingAuthentication, +} + +impl From for ConnectionSetupRequestErrorKind { + fn from(value: FrameError) -> Self { + ConnectionSetupRequestErrorKind::FrameError(Arc::new(value)) + } +} + +impl ConnectionSetupRequestError { + pub fn new(request_kind: CqlRequestKind, error: ConnectionSetupRequestErrorKind) -> Self { + ConnectionSetupRequestError { + request_kind, + error, + } + } +} + +/// An error that occurred when selecting a node connection +/// to perform a request on. +#[derive(Error, Debug, Clone)] +#[non_exhaustive] +pub enum ConnectionPoolError { + #[error("The pool is broken; Last connection failed with: {last_connection_error}")] + Broken { + last_connection_error: ConnectionError, + }, + #[error("Pool is still being initialized")] + Initializing, + #[error("The node has been disabled by a host filter")] + NodeDisabledByHostFilter, +} + +#[derive(Error, Debug, Clone)] +#[error("Connection broken, reason: {0}")] +pub struct BrokenConnectionError(Arc); + +impl BrokenConnectionError { + pub fn get_inner(&self) -> &Arc { + &self.0 + } +} + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum BrokenConnectionErrorKind { + #[error("Timed out while waiting for response to keepalive request on connection to node {0}")] + KeepaliveTimeout(IpAddr), + #[error("Failed to execute keepalive query: {0}")] + KeepaliveQueryError(RequestError), + #[error("Failed to deserialize frame: {0}")] + FrameError(FrameError), + #[error("Failed to handle server event: {0}")] + CqlEventHandlingError(#[from] CqlEventHandlingError), + #[error("Received a server frame with unexpected stream id: {0}")] + UnexpectedStreamId(i16), + #[error("Failed to write data: {0}")] + WriteError(std::io::Error), + #[error("Too many orphaned stream ids: {0}")] + TooManyOrphanedStreamIds(u16), + #[error( + "Failed to send/receive data needed to perform a request via tokio channel. + It implies that other half of the channel has been dropped. + The connection was already broken for some other reason." + )] + ChannelError, +} + +/// Failed to handle a CQL event received on a stream -1. +/// Possible error kinds are: +/// - failed to deserialize response's frame header +/// - failed to deserialize CQL event response +/// - received invalid server response +/// - failed to send an event info via channel (connection is probably broken) +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum CqlEventHandlingError { + #[error("Failed to deserialize EVENT response: {0}")] + CqlEventParseError(#[from] CqlEventParseError), + #[error("Received unexpected server response on stream -1: {0}. Expected EVENT response")] + UnexpectedResponse(CqlResponseKind), + #[error("Failed to deserialize a header of frame received on stream -1: {0}")] + FrameError(#[from] FrameError), + #[error("Failed to send event info via channel. The channel is probably closed, which is caused by connection being broken")] + SendError, +} + +/// An error type returned from Connection::parse_response. +/// This is driver's internal type. +#[derive(Error, Debug)] +pub enum ResponseParseError { + #[error(transparent)] + FrameError(#[from] FrameError), + #[error(transparent)] + CqlResponseParseError(#[from] CqlResponseParseError), +} + +/// An error that occurred when performing a request. +/// +/// Possible error kinds: +/// - Connection is broken +/// - Response's frame header deserialization error +/// - CQL response (frame body) deserialization error +/// - Driver was unable to allocate a stream id for a request +/// +/// This error type is only destined to narrow the return error type +/// of some functions that would previously return [`crate::errors::QueryError`]. +#[derive(Error, Debug)] +pub enum RequestError { + #[error(transparent)] + FrameError(#[from] FrameError), + #[error(transparent)] + CqlResponseParseError(#[from] CqlResponseParseError), + #[error(transparent)] + BrokenConnection(#[from] BrokenConnectionError), + #[error("Unable to allocate a stream id")] + UnableToAllocStreamId, +} + +impl From for UserRequestError { + fn from(value: RequestError) -> Self { + match value { + RequestError::FrameError(e) => e.into(), + RequestError::CqlResponseParseError(e) => match e { + // Only possible responses are RESULT and ERROR. If we failed parsing + // other response, treat it as unexpected response. + CqlResponseParseError::CqlErrorParseError(e) => e.into(), + CqlResponseParseError::CqlResultParseError(e) => e.into(), + _ => UserRequestError::UnexpectedResponse(e.to_response_kind()), + }, + RequestError::BrokenConnection(e) => e.into(), + RequestError::UnableToAllocStreamId => UserRequestError::UnableToAllocStreamId, + } } } -impl From for QueryError { - fn from(io_error: std::io::Error) -> QueryError { - QueryError::IoError(Arc::new(io_error)) +impl From for BrokenConnectionError { + fn from(value: BrokenConnectionErrorKind) -> Self { + BrokenConnectionError(Arc::new(value)) + } +} + +impl From for RequestError { + fn from(value: ResponseParseError) -> Self { + match value { + ResponseParseError::FrameError(e) => e.into(), + ResponseParseError::CqlResponseParseError(e) => e.into(), + } + } +} + +impl std::fmt::Display for WriteType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) } } @@ -492,6 +834,26 @@ impl From for QueryError { } } +impl From for QueryError { + fn from(value: UserRequestError) -> Self { + match value { + UserRequestError::DbError(err, msg) => QueryError::DbError(err, msg), + UserRequestError::CqlResultParseError(e) => e.into(), + UserRequestError::CqlErrorParseError(e) => e.into(), + UserRequestError::BrokenConnectionError(e) => e.into(), + UserRequestError::UnexpectedResponse(_) => { + // FIXME: make it typed. It needs to wait for ProtocolError refactor. + QueryError::ProtocolError("Received unexpected response from the server. Expected RESULT or ERROR response.") + } + UserRequestError::FrameError(e) => e.into(), + UserRequestError::UnableToAllocStreamId => QueryError::UnableToAllocStreamId, + UserRequestError::RepreparedIdChanged => QueryError::ProtocolError( + "Prepared statement Id changed, md5 sum should stay the same", + ), + } + } +} + impl From for NewSessionError { fn from(io_error: std::io::Error) -> NewSessionError { NewSessionError::IoError(Arc::new(io_error)) @@ -503,17 +865,19 @@ impl From for NewSessionError { match query_error { QueryError::DbError(e, msg) => NewSessionError::DbError(e, msg), QueryError::BadQuery(e) => NewSessionError::BadQuery(e), - QueryError::CqlResponseParseError(e) => NewSessionError::CqlResponseParseError(e), + QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e), + QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e), QueryError::IoError(e) => NewSessionError::IoError(e), + QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e), QueryError::ProtocolError(m) => NewSessionError::ProtocolError(m), QueryError::InvalidMessage(m) => NewSessionError::InvalidMessage(m), QueryError::TimeoutError => NewSessionError::TimeoutError, QueryError::TooManyOrphanedStreamIds(ids) => { NewSessionError::TooManyOrphanedStreamIds(ids) } + QueryError::BrokenConnection(e) => NewSessionError::BrokenConnection(e), QueryError::UnableToAllocStreamId => NewSessionError::UnableToAllocStreamId, QueryError::RequestTimeout(msg) => NewSessionError::RequestTimeout(msg), - QueryError::TranslationError(e) => NewSessionError::TranslationError(e), } } } @@ -524,23 +888,6 @@ impl From for QueryError { } } -impl QueryError { - /// Checks if this error indicates that a chosen source port/address cannot be bound. - /// This is caused by one of the following: - /// - The source address is already used by another socket, - /// - The source address is reserved and the process does not have sufficient privileges to use it. - pub fn is_address_unavailable_for_use(&self) -> bool { - if let QueryError::IoError(io_error) = self { - match io_error.kind() { - ErrorKind::AddrInUse | ErrorKind::PermissionDenied => return true, - _ => {} - } - } - - false - } -} - impl From for OperationType { fn from(operation_type: u8) -> OperationType { match operation_type { diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 155491ef91..d5e9fa859d 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use super::TryFromPrimitiveError; use crate::cql_to_rust::CqlTypeError; +use crate::errors::CqlResponseKind; use crate::frame::value::SerializeValuesError; use crate::types::deserialize::{DeserializationError, TypeCheckError}; use crate::types::serialize::SerializationError; @@ -78,6 +79,20 @@ pub enum CqlResponseParseError { CqlResultParseError(#[from] CqlResultParseError), } +impl CqlResponseParseError { + pub fn to_response_kind(&self) -> CqlResponseKind { + match self { + CqlResponseParseError::CqlErrorParseError(_) => CqlResponseKind::Error, + CqlResponseParseError::CqlAuthChallengeParseError(_) => CqlResponseKind::AuthChallenge, + CqlResponseParseError::CqlAuthSuccessParseError(_) => CqlResponseKind::AuthSuccess, + CqlResponseParseError::CqlAuthenticateParseError(_) => CqlResponseKind::Authenticate, + CqlResponseParseError::CqlSupportedParseError(_) => CqlResponseKind::Supported, + CqlResponseParseError::CqlEventParseError(_) => CqlResponseKind::Event, + CqlResponseParseError::CqlResultParseError(_) => CqlResponseKind::Result, + } + } +} + /// An error type returned when deserialization of ERROR response fails. #[non_exhaustive] #[derive(Error, Debug, Clone)] @@ -134,15 +149,15 @@ pub enum CqlResultParseError { ResultIdParseError(LowLevelDeserializationError), #[error("Unknown RESULT response id: {0}")] UnknownResultId(i32), - #[error("'Set_keyspace' response deserialization failed: {0}")] + #[error("RESULT:Set_keyspace response deserialization failed: {0}")] SetKeyspaceParseError(#[from] SetKeyspaceParseError), // This is an error returned during deserialization of // `RESULT::Schema_change` response, and not `EVENT` response. - #[error("'Schema_change' response deserialization failed: {0}")] + #[error("RESULT:Schema_change response deserialization failed: {0}")] SchemaChangeParseError(#[from] SchemaChangeEventParseError), - #[error("'Prepared' response deserialization failed: {0}")] + #[error("RESULT:Prepared response deserialization failed: {0}")] PreparedParseError(#[from] PreparedParseError), - #[error("'Rows' response deserialization failed: {0}")] + #[error("RESULT:Rows response deserialization failed: {0}")] RowsParseError(#[from] RowsParseError), } diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index d084eb71c9..516f7d24e9 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -10,7 +10,7 @@ use std::sync::Arc; pub use error::Error; pub use supported::Supported; -use crate::errors::QueryError; +use crate::errors::{CqlResponseKind, UserRequestError}; use crate::frame::protocol_features::ProtocolFeatures; use crate::frame::response::result::ResultMetadata; use crate::frame::TryFromPrimitiveError; @@ -64,6 +64,19 @@ pub enum Response { } impl Response { + pub fn to_response_kind(&self) -> CqlResponseKind { + match self { + Response::Error(_) => CqlResponseKind::Error, + Response::Ready => CqlResponseKind::Ready, + Response::Result(_) => CqlResponseKind::Result, + Response::Authenticate(_) => CqlResponseKind::Authenticate, + Response::AuthSuccess(_) => CqlResponseKind::AuthSuccess, + Response::AuthChallenge(_) => CqlResponseKind::AuthChallenge, + Response::Supported(_) => CqlResponseKind::Supported, + Response::Event(_) => CqlResponseKind::Event, + } + } + pub fn deserialize( features: &ProtocolFeatures, opcode: ResponseOpcode, @@ -93,9 +106,11 @@ impl Response { Ok(response) } - pub fn into_non_error_response(self) -> Result { - Ok(match self { - Response::Error(err) => return Err(QueryError::from(err)), + pub fn into_non_error_response(self) -> Result { + let non_error_response = match self { + Response::Error(error::Error { error, reason }) => { + return Err(UserRequestError::DbError(error, reason)) + } Response::Ready => NonErrorResponse::Ready, Response::Result(res) => NonErrorResponse::Result(res), Response::Authenticate(auth) => NonErrorResponse::Authenticate(auth), @@ -103,7 +118,9 @@ impl Response { Response::AuthChallenge(auth_chal) => NonErrorResponse::AuthChallenge(auth_chal), Response::Supported(sup) => NonErrorResponse::Supported(sup), Response::Event(eve) => NonErrorResponse::Event(eve), - }) + }; + + Ok(non_error_response) } } @@ -118,3 +135,17 @@ pub enum NonErrorResponse { Supported(Supported), Event(event::Event), } + +impl NonErrorResponse { + pub fn to_response_kind(&self) -> CqlResponseKind { + match self { + NonErrorResponse::Ready => CqlResponseKind::Ready, + NonErrorResponse::Result(_) => CqlResponseKind::Result, + NonErrorResponse::Authenticate(_) => CqlResponseKind::Authenticate, + NonErrorResponse::AuthSuccess(_) => CqlResponseKind::AuthSuccess, + NonErrorResponse::AuthChallenge(_) => CqlResponseKind::AuthChallenge, + NonErrorResponse::Supported(_) => CqlResponseKind::Supported, + NonErrorResponse::Event(_) => CqlResponseKind::Event, + } + } +} diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index 456f08d386..9709f682eb 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -1,9 +1,15 @@ use bytes::Bytes; use futures::{future::RemoteHandle, FutureExt}; -use scylla_cql::errors::TranslationError; +use scylla_cql::errors::{ + BrokenConnectionError, BrokenConnectionErrorKind, ConnectionError, ConnectionSetupRequestError, + ConnectionSetupRequestErrorKind, CqlEventHandlingError, CqlRequestKind, RequestError, + ResponseParseError, TranslationError, UserRequestError, +}; +use scylla_cql::frame::frame_errors::CqlResponseParseError; use scylla_cql::frame::request::options::{self, Options}; use scylla_cql::frame::response::result::{ResultMetadata, TableSpec}; use scylla_cql::frame::response::Error; +use scylla_cql::frame::response::{self, error}; use scylla_cql::frame::types::SerialConsistency; use scylla_cql::types::serialize::batch::{BatchValues, BatchValuesIterator}; use scylla_cql::types::serialize::raw_batch::RawBatchValuesAdapter; @@ -31,7 +37,6 @@ use crate::authentication::AuthenticatorProvider; use scylla_cql::frame::response::authenticate::Authenticate; use std::collections::{BTreeSet, HashMap, HashSet}; use std::convert::TryFrom; -use std::io::ErrorKind; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -112,7 +117,7 @@ impl RouterHandle { request: &impl SerializableRequest, compression: Option, tracing: bool, - ) -> Result { + ) -> Result { let serialized_request = SerializedRequest::make(request, compression, tracing)?; let request_id = self.allocate_request_id(); @@ -133,18 +138,12 @@ impl RouterHandle { response_handler, }) .await - .map_err(|_| { - QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "Connection broken", - ))) + .map_err(|_| -> BrokenConnectionError { + BrokenConnectionErrorKind::ChannelError.into() })?; - let task_response = receiver.await.map_err(|_| { - QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "Connection broken", - ))) + let task_response = receiver.await.map_err(|_| -> BrokenConnectionError { + BrokenConnectionErrorKind::ChannelError.into() })?; // Response was successfully received, so it's time to disable @@ -165,7 +164,7 @@ pub(crate) struct ConnectionFeatures { type RequestId = u64; struct ResponseHandler { - response_sender: oneshot::Sender>, + response_sender: oneshot::Sender>, request_id: RequestId, } @@ -229,7 +228,9 @@ pub(crate) struct NonErrorQueryResponse { } impl QueryResponse { - pub(crate) fn into_non_error_query_response(self) -> Result { + pub(crate) fn into_non_error_query_response( + self, + ) -> Result { Ok(NonErrorQueryResponse { response: self.response.into_non_error_response()?, tracing_id: self.tracing_id, @@ -239,7 +240,7 @@ impl QueryResponse { pub(crate) fn into_query_result_and_paging_state( self, - ) -> Result<(QueryResult, PagingStateResponse), QueryError> { + ) -> Result<(QueryResult, PagingStateResponse), UserRequestError> { self.into_non_error_query_response()? .into_query_result_and_paging_state() } @@ -266,7 +267,7 @@ impl NonErrorQueryResponse { pub(crate) fn into_query_result_and_paging_state( self, - ) -> Result<(QueryResult, PagingStateResponse), QueryError> { + ) -> Result<(QueryResult, PagingStateResponse), UserRequestError> { let (rows, paging_state, metadata, serialized_size) = match self.response { NonErrorResponse::Result(result::Result::Rows(rs)) => ( Some(rs.rows), @@ -276,8 +277,8 @@ impl NonErrorQueryResponse { ), NonErrorResponse::Result(_) => (None, PagingStateResponse::NoMorePages, None, 0), _ => { - return Err(QueryError::ProtocolError( - "Unexpected server response, expected Result or Error", + return Err(UserRequestError::UnexpectedResponse( + self.response.to_response_kind(), )) } }; @@ -307,6 +308,17 @@ impl NonErrorQueryResponse { Ok(result) } } + +pub(crate) enum NonErrorStartupResponse { + Ready, + Authenticate(response::authenticate::Authenticate), +} + +pub(crate) enum NonErrorAuthResponse { + AuthChallenge(response::authenticate::AuthChallenge), + AuthSuccess(response::authenticate::AuthSuccess), +} + #[cfg(feature = "ssl")] mod ssl_config { use openssl::{ @@ -625,7 +637,7 @@ impl ConnectionConfig { } // Used to listen for fatal error in connection -pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver; +pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver; impl Connection { // Returns new connection and ErrorReceiver which can be used to wait for a fatal error @@ -635,7 +647,7 @@ impl Connection { addr: SocketAddr, source_port: Option, config: ConnectionConfig, - ) -> Result<(Self, ErrorReceiver), QueryError> { + ) -> Result<(Self, ErrorReceiver), ConnectionError> { let stream_connector = match source_port { Some(p) => { tokio::time::timeout(config.connect_timeout, connect_with_source_port(addr, p)) @@ -646,7 +658,7 @@ impl Connection { let stream = match stream_connector { Ok(stream) => stream?, Err(_) => { - return Err(QueryError::TimeoutError); + return Err(ConnectionError::ConnectTimeout); } }; stream.set_nodelay(config.tcp_nodelay)?; @@ -744,21 +756,103 @@ impl Connection { pub(crate) async fn startup( &self, options: HashMap, Cow<'_, str>>, - ) -> Result { - Ok(self + ) -> Result { + let err = |kind: ConnectionSetupRequestErrorKind| { + ConnectionSetupRequestError::new(CqlRequestKind::Startup, kind) + }; + + let req_result = self .send_request(&request::Startup { options }, false, false, None) - .await? - .response) + .await; + + // Extract the response to STARTUP request and tidy up the errors. + let response = match req_result { + Ok(r) => match r.response { + Response::Ready => NonErrorStartupResponse::Ready, + Response::Authenticate(auth) => NonErrorStartupResponse::Authenticate(auth), + Response::Error(Error { error, reason }) => { + return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason))) + } + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + r.response.to_response_kind(), + ))) + } + }, + Err(e) => match e { + RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlResponseParseError(e) => match e { + // Parsing of READY response cannot fail, since its body is empty. + // Remaining valid responses are AUTHENTICATE and ERROR. + CqlResponseParseError::CqlAuthenticateParseError(e) => { + return Err(err(e.into())) + } + CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())), + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + e.to_response_kind(), + ))) + } + }, + RequestError::BrokenConnection(e) => return Err(err(e.into())), + RequestError::UnableToAllocStreamId => { + return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId)) + } + }, + }; + + Ok(response) } - pub(crate) async fn get_options(&self) -> Result { - Ok(self + pub(crate) async fn get_options( + &self, + ) -> Result { + let err = |kind: ConnectionSetupRequestErrorKind| { + ConnectionSetupRequestError::new(CqlRequestKind::Options, kind) + }; + + let req_result = self .send_request(&request::Options {}, false, false, None) - .await? - .response) + .await; + + // Extract the supported options and tidy up the errors. + let supported = match req_result { + Ok(r) => match r.response { + Response::Supported(supported) => supported, + Response::Error(Error { error, reason }) => { + return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason))) + } + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + r.response.to_response_kind(), + ))) + } + }, + Err(e) => match e { + RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlResponseParseError(e) => match e { + CqlResponseParseError::CqlSupportedParseError(e) => return Err(err(e.into())), + CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())), + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + e.to_response_kind(), + ))) + } + }, + RequestError::BrokenConnection(e) => return Err(err(e.into())), + RequestError::UnableToAllocStreamId => { + return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId)) + } + }, + }; + + Ok(supported) } - pub(crate) async fn prepare(&self, query: &Query) -> Result { + pub(crate) async fn prepare( + &self, + query: &Query, + ) -> Result { let query_response = self .send_request( &request::Prepare { @@ -771,7 +865,9 @@ impl Connection { .await?; let mut prepared_statement = match query_response.response { - Response::Error(err) => return Err(err.into()), + Response::Error(error::Error { error, reason }) => { + return Err(UserRequestError::DbError(error, reason)) + } Response::Result(result::Result::Prepared(p)) => PreparedStatement::new( p.id, self.features @@ -784,8 +880,8 @@ impl Connection { query.config.clone(), ), _ => { - return Err(QueryError::ProtocolError( - "PREPARE: Unexpected server response", + return Err(UserRequestError::UnexpectedResponse( + query_response.response.to_response_kind(), )) } }; @@ -800,15 +896,13 @@ impl Connection { &self, query: impl Into, previous_prepared: &PreparedStatement, - ) -> Result<(), QueryError> { + ) -> Result<(), UserRequestError> { let reprepare_query: Query = query.into(); let reprepared = self.prepare(&reprepare_query).await?; // Reprepared statement should keep its id - it's the md5 sum // of statement contents if reprepared.get_id() != previous_prepared.get_id() { - Err(QueryError::ProtocolError( - "Prepared statement Id changed, md5 sum should stay the same", - )) + Err(UserRequestError::RepreparedIdChanged) } else { Ok(()) } @@ -817,9 +911,57 @@ impl Connection { pub(crate) async fn authenticate_response( &self, response: Option>, - ) -> Result { - self.send_request(&request::AuthResponse { response }, false, false, None) - .await + ) -> Result { + let err = |kind: ConnectionSetupRequestErrorKind| { + ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind) + }; + + let req_result = self + .send_request(&request::AuthResponse { response }, false, false, None) + .await; + + // Extract non-error response to AUTH_RESPONSE request and tidy up errors. + let response = match req_result { + Ok(r) => match r.response { + Response::AuthSuccess(auth_success) => { + NonErrorAuthResponse::AuthSuccess(auth_success) + } + Response::AuthChallenge(auth_challenge) => { + NonErrorAuthResponse::AuthChallenge(auth_challenge) + } + Response::Error(Error { error, reason }) => { + return Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason))) + } + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + r.response.to_response_kind(), + ))) + } + }, + Err(e) => match e { + RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlResponseParseError(e) => match e { + CqlResponseParseError::CqlAuthSuccessParseError(e) => { + return Err(err(e.into())) + } + CqlResponseParseError::CqlAuthChallengeParseError(e) => { + return Err(err(e.into())) + } + CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())), + _ => { + return Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + e.to_response_kind(), + ))) + } + }, + RequestError::BrokenConnection(e) => return Err(err(e.into())), + RequestError::UnableToAllocStreamId => { + return Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId)) + } + }, + }; + + Ok(response) } #[allow(dead_code)] @@ -827,7 +969,7 @@ impl Connection { &self, query: impl Into, paging_state: PagingState, - ) -> Result<(QueryResult, PagingStateResponse), QueryError> { + ) -> Result<(QueryResult, PagingStateResponse), UserRequestError> { let query: Query = query.into(); // This method is used only for driver internal queries, so no need to consult execution profile here. @@ -852,7 +994,7 @@ impl Connection { paging_state: PagingState, consistency: Consistency, serial_consistency: Option, - ) -> Result<(QueryResult, PagingStateResponse), QueryError> { + ) -> Result<(QueryResult, PagingStateResponse), UserRequestError> { let query: Query = query.into(); let page_size = query.get_validated_page_size(); @@ -877,6 +1019,7 @@ impl Connection { self.query_raw_unpaged(&query, PagingState::start()) .await + .map_err(Into::into) .and_then(QueryResponse::into_query_result) } @@ -884,7 +1027,7 @@ impl Connection { &self, query: &Query, paging_state: PagingState, - ) -> Result { + ) -> Result { // This method is used only for driver internal queries, so no need to consult execution profile here. self.query_raw_with_consistency( query, @@ -905,7 +1048,7 @@ impl Connection { serial_consistency: Option, page_size: Option, paging_state: PagingState, - ) -> Result { + ) -> Result { let query_frame = query::Query { contents: Cow::Borrowed(&query.contents), parameters: query::QueryParameters { @@ -919,8 +1062,11 @@ impl Connection { }, }; - self.send_request(&query_frame, true, query.config.tracing, None) - .await + let response = self + .send_request(&query_frame, true, query.config.tracing, None) + .await?; + + Ok(response) } #[allow(dead_code)] @@ -932,6 +1078,7 @@ impl Connection { // This method is used only for driver internal queries, so no need to consult execution profile here. self.execute_raw_unpaged(prepared, values, PagingState::start()) .await + .map_err(Into::into) .and_then(QueryResponse::into_query_result) } @@ -941,7 +1088,7 @@ impl Connection { prepared: &PreparedStatement, values: SerializedValues, paging_state: PagingState, - ) -> Result { + ) -> Result { // This method is used only for driver internal queries, so no need to consult execution profile here. self.execute_raw_with_consistency( prepared, @@ -964,7 +1111,7 @@ impl Connection { serial_consistency: Option, page_size: Option, paging_state: PagingState, - ) -> Result { + ) -> Result { let execute_frame = execute::Execute { id: prepared_statement.get_id().to_owned(), parameters: query::QueryParameters { @@ -1117,7 +1264,8 @@ impl Connection { loop { let query_response = self .send_request(&batch_frame, true, batch.config.tracing, None) - .await?; + .await + .map_err(UserRequestError::from)?; return match query_response.response { Response::Error(err) => match err.error { @@ -1231,21 +1379,40 @@ impl Connection { async fn register( &self, event_types_to_register_for: Vec, - ) -> Result<(), QueryError> { + ) -> Result<(), ConnectionSetupRequestError> { + let err = |kind: ConnectionSetupRequestErrorKind| { + ConnectionSetupRequestError::new(CqlRequestKind::Register, kind) + }; + let register_frame = register::Register { event_types_to_register_for, }; - match self - .send_request(®ister_frame, true, false, None) - .await? - .response - { - Response::Ready => Ok(()), - Response::Error(err) => Err(err.into()), - _ => Err(QueryError::ProtocolError( - "Unexpected response to REGISTER message", - )), + // Extract the response and tidy up the errors. + match self.send_request(®ister_frame, true, false, None).await { + Ok(r) => match r.response { + Response::Ready => Ok(()), + Response::Error(Error { error, reason }) => { + Err(err(ConnectionSetupRequestErrorKind::DbError(error, reason))) + } + _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + r.response.to_response_kind(), + ))), + }, + Err(e) => match e { + RequestError::FrameError(e) => Err(err(e.into())), + RequestError::CqlResponseParseError(e) => match e { + // Parsing the READY response cannot fail. Only remaining valid response is ERROR. + CqlResponseParseError::CqlErrorParseError(e) => Err(err(e.into())), + _ => Err(err(ConnectionSetupRequestErrorKind::UnexpectedResponse( + e.to_response_kind(), + ))), + }, + RequestError::BrokenConnection(e) => Err(err(e.into())), + RequestError::UnableToAllocStreamId => { + Err(err(ConnectionSetupRequestErrorKind::UnableToAllocStreamId)) + } + }, } } @@ -1274,7 +1441,7 @@ impl Connection { compress: bool, tracing: bool, cached_metadata: Option<&Arc>, - ) -> Result { + ) -> Result { let compression = if compress { self.config.compression } else { @@ -1286,12 +1453,14 @@ impl Connection { .send_request(request, compression, tracing) .await?; - Self::parse_response( + let response = Self::parse_response( task_response, self.config.compression, &self.features.protocol_features, cached_metadata, - ) + )?; + + Ok(response) } fn parse_response( @@ -1299,7 +1468,7 @@ impl Connection { compression: Option, features: &ProtocolFeatures, cached_metadata: Option<&Arc>, - ) -> Result { + ) -> Result { let body_with_ext = frame::parse_response_body_extensions( task_response.params.flags, compression, @@ -1332,7 +1501,7 @@ impl Connection { config: ConnectionConfig, stream: TcpStream, receiver: mpsc::Receiver, - error_sender: tokio::sync::oneshot::Sender, + error_sender: tokio::sync::oneshot::Sender, orphan_notification_receiver: mpsc::UnboundedReceiver, router_handle: Arc, node_address: IpAddr, @@ -1375,7 +1544,7 @@ impl Connection { config: ConnectionConfig, stream: (impl AsyncRead + AsyncWrite), receiver: mpsc::Receiver, - error_sender: tokio::sync::oneshot::Sender, + error_sender: tokio::sync::oneshot::Sender, orphan_notification_receiver: mpsc::UnboundedReceiver, router_handle: Arc, node_address: IpAddr, @@ -1419,7 +1588,7 @@ impl Connection { let result = futures::try_join!(r, w, o, k); - let error: QueryError = match result { + let error: BrokenConnectionError = match result { Ok(_) => return, // Connection was dropped, we can return Err(err) => err, }; @@ -1430,20 +1599,22 @@ impl Connection { for (_, handler) in response_handlers { // Ignore sending error, request was dropped - let _ = handler.response_sender.send(Err(error.clone())); + let _ = handler.response_sender.send(Err(error.clone().into())); } // If someone is listening for connection errors notify them - let _ = error_sender.send(error); + let _ = error_sender.send(error.into()); } async fn reader( mut read_half: (impl AsyncRead + Unpin), handler_map: &StdMutex, config: ConnectionConfig, - ) -> Result<(), QueryError> { + ) -> Result<(), BrokenConnectionError> { loop { - let (params, opcode, body) = frame::read_response_frame(&mut read_half).await?; + let (params, opcode, body) = frame::read_response_frame(&mut read_half) + .await + .map_err(BrokenConnectionErrorKind::FrameError)?; let response = TaskResponse { params, opcode, @@ -1459,7 +1630,9 @@ impl Connection { } Ordering::Equal => { if let Some(event_sender) = config.event_sender.as_ref() { - Self::handle_event(response, config.compression, event_sender).await?; + Self::handle_event(response, config.compression, event_sender) + .await + .map_err(BrokenConnectionErrorKind::CqlEventHandlingError)? } continue; } @@ -1488,9 +1661,7 @@ impl Connection { "Received response with unexpected StreamId {}", params.stream ); - return Err(QueryError::ProtocolError( - "Received response with unexpected StreamId", - )); + return Err(BrokenConnectionErrorKind::UnexpectedStreamId(params.stream).into()); } Orphaned => { // Do nothing, handler was freed because this stream_id has @@ -1513,7 +1684,7 @@ impl Connection { error!("Could not allocate stream id"); let _ = response_handler .response_sender - .send(Err(QueryError::UnableToAllocStreamId)); + .send(Err(RequestError::UnableToAllocStreamId)); None } } @@ -1524,7 +1695,7 @@ impl Connection { handler_map: &StdMutex, mut task_receiver: mpsc::Receiver, enable_write_coalescing: bool, - ) -> Result<(), QueryError> { + ) -> Result<(), BrokenConnectionError> { // When the Connection object is dropped, the sender half // of the channel will be dropped, this task will return an error // and the whole worker will be stopped @@ -1537,7 +1708,10 @@ impl Connection { let req_data: &[u8] = req.get_data(); total_sent += req_data.len(); num_requests += 1; - write_half.write_all(req_data).await?; + write_half + .write_all(req_data) + .await + .map_err(BrokenConnectionErrorKind::WriteError)?; task = match task_receiver.try_recv() { Ok(t) => t, Err(_) if enable_write_coalescing => { @@ -1554,7 +1728,10 @@ impl Connection { } } trace!("Sending {} requests; {} bytes", num_requests, total_sent); - write_half.flush().await?; + write_half + .flush() + .await + .map_err(BrokenConnectionErrorKind::WriteError)?; } Ok(()) @@ -1567,7 +1744,7 @@ impl Connection { async fn orphaner( handler_map: &StdMutex, mut orphan_receiver: mpsc::UnboundedReceiver, - ) -> Result<(), QueryError> { + ) -> Result<(), BrokenConnectionError> { let mut interval = tokio::time::interval(OLD_AGE_ORPHAN_THRESHOLD); loop { tokio::select! { @@ -1581,7 +1758,7 @@ impl Connection { "Too many old orphaned stream ids: {}", old_orphan_count, ); - return Err(QueryError::TooManyOrphanedStreamIds(old_orphan_count as u16)) + return Err(BrokenConnectionErrorKind::TooManyOrphanedStreamIds(old_orphan_count as u16).into()) } } Some(request_id) = orphan_receiver.recv() => { @@ -1604,12 +1781,15 @@ impl Connection { keepalive_interval: Option, keepalive_timeout: Option, node_address: IpAddr, // This address is only used to enrich the log messages - ) -> Result<(), QueryError> { - async fn issue_keepalive_query(router_handle: &RouterHandle) -> Result<(), QueryError> { + ) -> Result<(), BrokenConnectionError> { + async fn issue_keepalive_query( + router_handle: &RouterHandle, + ) -> Result<(), BrokenConnectionError> { router_handle .send_request(&Options, None, false) .await .map(|_| ()) + .map_err(|q_err| BrokenConnectionErrorKind::KeepaliveQueryError(q_err).into()) } if let Some(keepalive_interval) = keepalive_interval { @@ -1631,13 +1811,9 @@ impl Connection { "Timed out while waiting for response to keepalive request on connection to node {}", node_address ); - return Err(QueryError::IoError(Arc::new(std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "Timed out while waiting for response to keepalive request on connection to node {}", - node_address - ) - )))); + return Err( + BrokenConnectionErrorKind::KeepaliveTimeout(node_address).into() + ); } } } else { @@ -1661,7 +1837,7 @@ impl Connection { task_response: TaskResponse, compression: Option, event_sender: &mpsc::Sender, - ) -> Result<(), QueryError> { + ) -> Result<(), CqlEventHandlingError> { // Protocol features are negotiated during connection handshake. // However, the router is already created and sent to a different tokio // task before the handshake begins, therefore it's hard to cleanly @@ -1673,21 +1849,34 @@ impl Connection { // future implementers. let features = ProtocolFeatures::default(); // TODO: Use the right features - let response = Self::parse_response(task_response, compression, &features, None)?.response; - let event = match response { - Response::Event(e) => e, - _ => { - warn!("Expected to receive Event response, got {:?}", response); - return Ok(()); - } + let event = match Self::parse_response(task_response, compression, &features, None) { + Ok(r) => match r.response { + Response::Event(event) => event, + _ => { + error!("Expected to receive Event response, got {:?}", r.response); + return Err(CqlEventHandlingError::UnexpectedResponse( + r.response.to_response_kind(), + )); + } + }, + Err(e) => match e { + ResponseParseError::FrameError(e) => return Err(e.into()), + ResponseParseError::CqlResponseParseError(e) => match e { + CqlResponseParseError::CqlEventParseError(e) => return Err(e.into()), + // Received a response other than EVENT, but failed to deserialize it. + _ => { + return Err(CqlEventHandlingError::UnexpectedResponse( + e.to_response_kind(), + )) + } + }, + }, }; - event_sender.send(event).await.map_err(|_| { - QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "Connection broken", - ))) - }) + event_sender + .send(event) + .await + .map_err(|_| CqlEventHandlingError::SendError) } pub(crate) fn get_shard_info(&self) -> &Option { @@ -1781,7 +1970,7 @@ pub(crate) async fn open_connection( endpoint: UntranslatedEndpoint, source_port: Option, config: &ConnectionConfig, -) -> Result<(Connection, ErrorReceiver), QueryError> { +) -> Result<(Connection, ErrorReceiver), ConnectionError> { /* Translate the address, if applicable. */ let addr = maybe_translated_addr(endpoint, config.address_translator.as_deref()).await?; @@ -1792,23 +1981,13 @@ pub(crate) async fn open_connection( /* Perform OPTIONS/SUPPORTED/STARTUP handshake. */ // Get OPTIONS SUPPORTED by the cluster. - let options_result = connection.get_options().await?; + let mut supported = connection.get_options().await?; let shard_aware_port_key = match config.is_ssl() { true => options::SCYLLA_SHARD_AWARE_PORT_SSL, false => options::SCYLLA_SHARD_AWARE_PORT, }; - let mut supported = match options_result { - Response::Supported(supported) => supported, - Response::Error(Error { error, reason }) => return Err(QueryError::DbError(error, reason)), - _ => { - return Err(QueryError::ProtocolError( - "Wrong response to OPTIONS message was received", - )); - } - }; - // If this is ScyllaDB that we connected to, we received sharding information. let shard_info = ShardInfo::try_from(&supported.options).ok(); let supported_compression = supported @@ -1868,18 +2047,12 @@ pub(crate) async fn open_connection( } /* Send the STARTUP frame with all the requested options. */ - let result = connection.startup(options).await?; - match result { - Response::Ready => {} - Response::Authenticate(authenticate) => { + let startup_result = connection.startup(options).await?; + match startup_result { + NonErrorStartupResponse::Ready => {} + NonErrorStartupResponse::Authenticate(authenticate) => { perform_authenticate(&mut connection, &authenticate).await?; } - Response::Error(Error { error, reason }) => return Err(QueryError::DbError(error, reason)), - _ => { - return Err(QueryError::ProtocolError( - "Unexpected response to STARTUP message", - )) - } } /* If this is a control connection, REGISTER to receive all event types. */ @@ -1898,7 +2071,11 @@ pub(crate) async fn open_connection( async fn perform_authenticate( connection: &mut Connection, authenticate: &Authenticate, -) -> Result<(), QueryError> { +) -> Result<(), ConnectionSetupRequestError> { + let err = |kind: ConnectionSetupRequestErrorKind| { + ConnectionSetupRequestError::new(CqlRequestKind::AuthResponse, kind) + }; + let authenticator = &authenticate.authenticator_name as &str; match connection.config.authenticator { @@ -1906,43 +2083,35 @@ async fn perform_authenticate( let (mut response, mut auth_session) = authenticator_provider .start_authentication_session(authenticator) .await - .map_err(QueryError::InvalidMessage)?; + .map_err(|e| err(ConnectionSetupRequestErrorKind::StartAuthSessionError(e)))?; loop { - match connection - .authenticate_response(response) - .await?.response - { - Response::AuthChallenge(challenge) => { + match connection.authenticate_response(response).await? { + NonErrorAuthResponse::AuthChallenge(challenge) => { response = auth_session - .evaluate_challenge( - challenge.authenticate_message.as_deref(), - ) + .evaluate_challenge(challenge.authenticate_message.as_deref()) .await - .map_err(QueryError::InvalidMessage)?; + .map_err(|e| { + err( + ConnectionSetupRequestErrorKind::AuthChallengeEvaluationError( + e, + ), + ) + })?; } - Response::AuthSuccess(success) => { + NonErrorAuthResponse::AuthSuccess(success) => { auth_session .success(success.success_message.as_deref()) .await - .map_err(QueryError::InvalidMessage)?; + .map_err(|e| { + err(ConnectionSetupRequestErrorKind::AuthFinishError(e)) + })?; break; } - Response::Error(err) => { - return Err(err.into()); - } - _ => { - return Err(QueryError::ProtocolError( - "Unexpected response to Authenticate Response message", - )) - } } } - }, - None => return Err(QueryError::InvalidMessage( - "Authentication is required. You can use SessionBuilder::user(\"user\", \"pass\") to provide credentials \ - or SessionBuilder::authenticator_provider to provide custom authenticator".to_string(), - )), + } + None => return Err(err(ConnectionSetupRequestErrorKind::MissingAuthentication)), } Ok(()) @@ -2192,7 +2361,6 @@ impl VerifiedKeyspaceName { #[cfg(test)] mod tests { use assert_matches::assert_matches; - use scylla_cql::errors::QueryError; use scylla_cql::frame::protocol_features::{ LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION, }; @@ -2541,6 +2709,8 @@ mod tests { #[ntest::timeout(20000)] #[cfg(not(scylla_cloud_tests))] async fn connection_is_closed_on_no_response_to_keepalives() { + use scylla_cql::errors::BrokenConnectionErrorKind; + setup_tracing(); let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042); @@ -2602,7 +2772,13 @@ mod tests { // Wait until keepaliver gots impatient and terminates router. // Then, the error from keepaliver will be propagated to the error receiver. let err = error_receiver.await.unwrap(); - assert_matches!(err, QueryError::IoError(_)); + let err_inner: &BrokenConnectionErrorKind = match err { + crate::transport::connection::ConnectionError::BrokenConnection(ref e) => { + e.get_inner().downcast_ref().unwrap() + } + _ => panic!("Bad error type. Expected keepalive timeout."), + }; + assert_matches!(err_inner, BrokenConnectionErrorKind::KeepaliveTimeout(_)); // As the router is invalidated, all further queries should immediately // return error. diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index 849ef8cb8d..54921aee1e 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -19,8 +19,8 @@ use super::NodeAddr; use arc_swap::ArcSwap; use futures::{future::RemoteHandle, stream::FuturesUnordered, Future, FutureExt, StreamExt}; use rand::Rng; +use scylla_cql::errors::{BrokenConnectionErrorKind, ConnectionError, ConnectionPoolError}; use std::convert::TryInto; -use std::io::ErrorKind; use std::num::NonZeroUsize; use std::pin::Pin; use std::sync::{Arc, RwLock, Weak}; @@ -79,7 +79,7 @@ enum MaybePoolConnections { // The pool is empty because either initial filling failed or all connections // became broken; will be asynchronously refilled. Contains an error // from the last connection attempt. - Broken(QueryError), + Broken(ConnectionError), // The pool has some connections which are usable (or will be removed soon) Ready(PoolConnections), @@ -234,7 +234,10 @@ impl NodeConnectionPool { .unwrap_or(None) } - pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result, QueryError> { + pub(crate) fn connection_for_shard( + &self, + shard: Shard, + ) -> Result, ConnectionPoolError> { trace!(shard = shard, "Selecting connection for shard"); self.with_connections(|pool_conns| match pool_conns { PoolConnections::NotSharded(conns) => { @@ -257,7 +260,7 @@ impl NodeConnectionPool { }) } - pub(crate) fn random_connection(&self) -> Result, QueryError> { + pub(crate) fn random_connection(&self) -> Result, ConnectionPoolError> { trace!("Selecting random connection"); self.with_connections(|pool_conns| match pool_conns { PoolConnections::NotSharded(conns) => { @@ -341,7 +344,9 @@ impl NodeConnectionPool { } } - pub(crate) fn get_working_connections(&self) -> Result>, QueryError> { + pub(crate) fn get_working_connections( + &self, + ) -> Result>, ConnectionPoolError> { self.with_connections(|pool_conns| match pool_conns { PoolConnections::NotSharded(conns) => conns.clone(), PoolConnections::Sharded { connections, .. } => { @@ -370,25 +375,17 @@ impl NodeConnectionPool { } } - fn with_connections(&self, f: impl FnOnce(&PoolConnections) -> T) -> Result { + fn with_connections( + &self, + f: impl FnOnce(&PoolConnections) -> T, + ) -> Result { let conns = self.conns.load_full(); match &*conns { MaybePoolConnections::Ready(pool_connections) => Ok(f(pool_connections)), - MaybePoolConnections::Broken(err) => { - Err(QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - format!( - "No connections in the pool; last connection failed with: {}", - err - ), - )))) - } - MaybePoolConnections::Initializing => { - Err(QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "No connections in the pool, pool is still being initialized", - )))) - } + MaybePoolConnections::Broken(err) => Err(ConnectionPoolError::Broken { + last_connection_error: err.clone(), + }), + MaybePoolConnections::Initializing => Err(ConnectionPoolError::Initializing), } } } @@ -989,7 +986,7 @@ impl PoolRefiller { // Updates `shared_conns` based on `conns`. // `last_error` must not be `None` if there is a possibility of the pool // being empty. - fn update_shared_conns(&mut self, last_error: Option) { + fn update_shared_conns(&mut self, last_error: Option) { let new_conns = if !self.has_connections() { Arc::new(MaybePoolConnections::Broken(last_error.unwrap())) } else { @@ -1015,7 +1012,7 @@ impl PoolRefiller { // Removes given connection from the pool. It looks both into active // connections and excess connections. - fn remove_connection(&mut self, connection: Arc, last_error: QueryError) { + fn remove_connection(&mut self, connection: Arc, last_error: ConnectionError) { let ptr = Arc::as_ptr(&connection); let maybe_remove_in_vec = |v: &mut Vec>| -> bool { @@ -1204,7 +1201,7 @@ impl PoolRefiller { struct BrokenConnectionEvent { connection: Weak, - error: QueryError, + error: ConnectionError, } async fn wait_for_error( @@ -1214,16 +1211,13 @@ async fn wait_for_error( BrokenConnectionEvent { connection, error: error_receiver.await.unwrap_or_else(|_| { - QueryError::IoError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "Connection broken", - ))) + ConnectionError::BrokenConnection(BrokenConnectionErrorKind::ChannelError.into()) }), } } struct OpenedConnectionEvent { - result: Result<(Connection, ErrorReceiver), QueryError>, + result: Result<(Connection, ErrorReceiver), ConnectionError>, requested_shard: Option, keyspace_name: Option, } @@ -1233,7 +1227,7 @@ async fn open_connection_to_shard_aware_port( shard: Shard, sharder: Sharder, connection_config: &ConnectionConfig, -) -> Result<(Connection, ErrorReceiver), QueryError> { +) -> Result<(Connection, ErrorReceiver), ConnectionError> { // Create iterator over all possible source ports for this shard let source_port_iter = sharder.iter_source_ports_for_shard(shard); @@ -1248,10 +1242,7 @@ async fn open_connection_to_shard_aware_port( } // Tried all source ports for that shard, give up - Err(QueryError::IoError(Arc::new(std::io::Error::new( - std::io::ErrorKind::AddrInUse, - "Could not find free source port for shard", - )))) + Err(ConnectionError::NoSourcePortForShard(shard)) } #[cfg(test)] diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index cb5a8141c8..6dfcb622b6 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use futures::Stream; +use scylla_cql::errors::UserRequestError; use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::types::serialize::row::SerializedValues; use std::result::Result; @@ -479,7 +480,7 @@ struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> { sender: ProvingSender>, // Closure used to perform a single page query - // AsyncFn(Arc, Option>) -> Result + // AsyncFn(Arc, Option>) -> Result page_query: QueryFunc, statement_info: RoutingInfo<'a>, @@ -502,7 +503,7 @@ struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> { impl RowIteratorWorker<'_, QueryFunc, SpanCreator> where QueryFunc: Fn(Arc, Consistency, PagingState) -> QueryFut, - QueryFut: Future>, + QueryFut: Future>, SpanCreator: Fn() -> RequestSpan, { // Contract: this function MUST send at least one item through self.sender @@ -535,7 +536,7 @@ where error = %e, "Choosing connection failed" ); - last_error = e; + last_error = e.into(); // Broken connection doesn't count as a failed query, don't log in metrics continue 'nodes_in_plan; } @@ -705,6 +706,7 @@ where Ok(ControlFlow::Continue(())) } Err(err) => { + let err = err.into(); self.metrics.inc_failed_paged_queries(); self.execution_profile .load_balancing_policy @@ -826,7 +828,7 @@ struct SingleConnectionRowIteratorWorker { impl SingleConnectionRowIteratorWorker where Fetcher: Fn(PagingState) -> FetchFut + Send + Sync, - FetchFut: Future> + Send, + FetchFut: Future> + Send, { async fn work(mut self) -> PageSendAttemptedProof { match self.do_work().await { diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 962cbd2440..35d1f926b3 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2837,18 +2837,20 @@ mod latency_awareness { match error { // "fast" errors, i.e. ones that are returned quickly after the query begins QueryError::BadQuery(_) + | QueryError::BrokenConnection(_) + | QueryError::ConnectionPoolError(_) | QueryError::TooManyOrphanedStreamIds(_) | QueryError::UnableToAllocStreamId | QueryError::DbError(DbError::IsBootstrapping, _) | QueryError::DbError(DbError::Unavailable { .. }, _) | QueryError::DbError(DbError::Unprepared { .. }, _) - | QueryError::TranslationError(_) | QueryError::DbError(DbError::Overloaded { .. }, _) | QueryError::DbError(DbError::RateLimitReached { .. }, _) => false, // "slow" errors, i.e. ones that are returned after considerable time of query being run QueryError::DbError(_, _) - | QueryError::CqlResponseParseError(_) + | QueryError::CqlResultParseError(_) + | QueryError::CqlErrorParseError(_) | QueryError::InvalidMessage(_) | QueryError::IoError(_) | QueryError::ProtocolError(_) diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs index 02ca247bc5..79895d65c0 100644 --- a/scylla/src/transport/node.rs +++ b/scylla/src/transport/node.rs @@ -1,3 +1,4 @@ +use scylla_cql::errors::ConnectionPoolError; use tokio::net::lookup_host; use tracing::warn; use uuid::Uuid; @@ -157,7 +158,7 @@ impl Node { pub(crate) async fn connection_for_shard( &self, shard: Shard, - ) -> Result, QueryError> { + ) -> Result, ConnectionPoolError> { self.get_pool()?.connection_for_shard(shard) } @@ -186,7 +187,9 @@ impl Node { Ok(()) } - pub(crate) fn get_working_connections(&self) -> Result>, QueryError> { + pub(crate) fn get_working_connections( + &self, + ) -> Result>, ConnectionPoolError> { self.get_pool()?.get_working_connections() } @@ -196,14 +199,10 @@ impl Node { } } - fn get_pool(&self) -> Result<&NodeConnectionPool, QueryError> { - self.pool.as_ref().ok_or_else(|| { - QueryError::IoError(Arc::new(std::io::Error::new( - std::io::ErrorKind::Other, - "No connections in the pool: the node has been disabled \ - by the host filter", - ))) - }) + fn get_pool(&self) -> Result<&NodeConnectionPool, ConnectionPoolError> { + self.pool + .as_ref() + .ok_or(ConnectionPoolError::NodeDisabledByHostFilter) } } diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 52a2f9cbe4..ea11174d44 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -76,7 +76,7 @@ pub use crate::transport::connection_pool::PoolSize; use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; -use scylla_cql::errors::BadQuery; +use scylla_cql::errors::{BadQuery, UserRequestError}; pub(crate) const TABLET_CHANNEL_SIZE: usize = 8192; @@ -774,6 +774,7 @@ impl Session { ) .await .and_then(QueryResponse::into_non_error_query_response) + .map_err(Into::into) } else { let prepared = connection.prepare(query_ref).await?; let serialized = prepared.serialize_values(values_ref)?; @@ -789,6 +790,7 @@ impl Session { ) .await .and_then(QueryResponse::into_non_error_query_response) + .map_err(Into::into) } } }, @@ -975,7 +977,7 @@ impl Session { // Safety: there is at least one node in the cluster, and `Cluster::iter_working_connections()` // returns either an error or an iterator with at least one connection, so there will be at least one result. - let first_ok: Result = + let first_ok: Result = results.by_ref().find_or_first(Result::is_ok).unwrap(); let mut prepared: PreparedStatement = first_ok?; @@ -1227,6 +1229,7 @@ impl Session { ) .await .and_then(QueryResponse::into_non_error_query_response) + .map_err(Into::into) } }, &span, @@ -1873,7 +1876,7 @@ impl Session { error = %e, "Choosing connection failed" ); - last_error = Some(e); + last_error = Some(e.into()); // Broken connection doesn't count as a failed query, don't log in metrics continue 'nodes_in_plan; }