From f263046f12b43d034dbe093daa330bea25bc4a18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Fri, 4 Oct 2024 13:29:53 +0200 Subject: [PATCH] use_keyspace: don't use wildcard '_' in QueryError match Since last time, during error refactor I introduced a silent bug to the code (https://github.com/scylladb/scylla-rust-driver/pull/1075), I'd like to prevent that from happening in the future. This is why we replace a `_` match with explicit error variants when deciding if error received after `USE KEYSPACE` should be ignored. --- scylla/src/transport/cluster.rs | 9 +++++---- scylla/src/transport/errors.rs | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 2313402bc1..9bf9bc52ff 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -784,12 +784,13 @@ pub(crate) fn use_keyspace_result( for result in use_keyspace_results { match result { Ok(()) => was_ok = true, - Err(err) => match err { - QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => { + Err(err) => { + if err.is_connection_broken() { broken_conn_error = Some(err) + } else { + return Err(err); } - _ => return Err(err), - }, + } } } diff --git a/scylla/src/transport/errors.rs b/scylla/src/transport/errors.rs index f6b31b6c95..83ee173bd1 100644 --- a/scylla/src/transport/errors.rs +++ b/scylla/src/transport/errors.rs @@ -100,6 +100,29 @@ pub enum QueryError { RequestTimeout(String), } +impl QueryError { + pub(crate) fn is_connection_broken(&self) -> bool { + match self { + // Error variants that imply that some connection error appeared before/during exeuction. + QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) => true, + + // Other errors. + QueryError::DbError(_, _) + | QueryError::BadQuery(_) + | QueryError::CqlRequestSerialization(_) + | QueryError::BodyExtensionsParseError(_) + | QueryError::EmptyPlan + | QueryError::CqlResultParseError(_) + | QueryError::CqlErrorParseError(_) + | QueryError::MetadataError(_) + | QueryError::ProtocolError(_) + | QueryError::TimeoutError + | QueryError::UnableToAllocStreamId + | QueryError::RequestTimeout(_) => false, + } + } +} + impl From for QueryError { fn from(serialized_err: SerializeValuesError) -> QueryError { QueryError::BadQuery(BadQuery::SerializeValuesError(serialized_err))