diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 044a9e198c..74be0c9cd6 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -1,61 +1,129 @@ +use std::error::Error; use std::sync::Arc; +pub use super::request::{ + auth_response::AuthResponseSerializationError, + batch::{BatchSerializationError, BatchStatementSerializationError}, + execute::ExecuteSerializationError, + prepare::PrepareSerializationError, + query::{QueryParametersSerializationError, QuerySerializationError}, + register::RegisterSerializationError, + startup::StartupSerializationError, +}; + use super::response::CqlResponseKind; use super::TryFromPrimitiveError; -use crate::cql_to_rust::CqlTypeError; -use crate::frame::value::SerializeValuesError; -use crate::types::deserialize::{DeserializationError, TypeCheckError}; -use crate::types::serialize::SerializationError; +use crate::types::deserialize::DeserializationError; use thiserror::Error; -#[derive(Error, Debug)] -pub enum FrameError { - #[error(transparent)] - Parse(#[from] ParseError), +/// An error returned by `parse_response_body_extensions`. +/// +/// It represents an error that occurred during deserialization of +/// frame body extensions. These extensions include tracing id, +/// warnings and custom payload. +/// +/// Possible error kinds: +/// - failed to decompress frame body (decompression is required for further deserialization) +/// - failed to deserialize tracing id (body ext.) +/// - failed to deserialize warnings list (body ext.) +/// - failed to deserialize custom payload map (body ext.) +#[derive(Error, Debug, Clone)] +#[non_exhaustive] +pub enum FrameBodyExtensionsParseError { + /// Frame is compressed, but no compression was negotiated for the connection. #[error("Frame is compressed, but no compression negotiated for connection.")] NoCompressionNegotiated, + + /// Failed to deserialize frame trace id. + #[error("Malformed trace id: {0}")] + TraceIdParse(LowLevelDeserializationError), + + /// Failed to deserialize warnings attached to frame. + #[error("Malformed warnings list: {0}")] + WarningsListParse(LowLevelDeserializationError), + + /// Failed to deserialize frame's custom payload. + #[error("Malformed custom payload map: {0}")] + CustomPayloadMapParse(LowLevelDeserializationError), + + /// Failed to decompress frame body (snap). + #[error("Snap decompression error: {0}")] + SnapDecompressError(Arc), + + /// Failed to decompress frame body (lz4). + #[error("Error decompressing lz4 data {0}")] + Lz4DecompressError(Arc), +} + +/// An error that occurred during frame header deserialization. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum FrameHeaderParseError { + /// Failed to read the frame header from the socket. + #[error("Failed to read the frame header: {0}")] + HeaderIoError(std::io::Error), + + /// Received a frame marked as coming from a client. #[error("Received frame marked as coming from a client")] FrameFromClient, + + // FIXME: this should not belong here. User always expects a frame from server. + // This variant is only used in scylla-proxy - need to investigate it later. #[error("Received frame marked as coming from the server")] FrameFromServer, + + /// Received a frame with unsupported version. #[error("Received a frame from version {0}, but only 4 is supported")] VersionNotSupported(u8), + + /// Received unknown response opcode. + #[error("Unrecognized response opcode {0}")] + UnknownResponseOpcode(#[from] TryFromPrimitiveError), + + /// Failed to read frame body from the socket. + #[error("Failed to read a chunk of response body. Expected {0} more bytes, error: {1}")] + BodyChunkIoError(usize, std::io::Error), + + /// Connection was closed before whole frame was read. #[error("Connection was closed before body was read: missing {0} out of {1}")] ConnectionClosed(usize, usize), - #[error("Frame decompression failed.")] - FrameDecompression, - #[error("Frame compression failed.")] - FrameCompression, - #[error(transparent)] - StdIoError(#[from] std::io::Error), - #[error("Unrecognized opcode{0}")] - TryFromPrimitiveError(#[from] TryFromPrimitiveError), - #[error("Error compressing lz4 data {0}")] - Lz4CompressError(#[from] lz4_flex::block::CompressError), - #[error("Error decompressing lz4 data {0}")] - Lz4DecompressError(#[from] lz4_flex::block::DecompressError), } -#[derive(Error, Debug)] -pub enum ParseError { - #[error("Low-level deserialization failed: {0}")] - LowLevelDeserializationError(#[from] LowLevelDeserializationError), - #[error("Could not serialize frame: {0}")] - BadDataToSerialize(String), - #[error("Could not deserialize frame: {0}")] - BadIncomingData(String), - #[error(transparent)] - DeserializationError(#[from] DeserializationError), - #[error(transparent)] - DeserializationTypeCheckError(#[from] TypeCheckError), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - SerializeValuesError(#[from] SerializeValuesError), - #[error(transparent)] - SerializationError(#[from] SerializationError), - #[error(transparent)] - CqlTypeError(#[from] CqlTypeError), +/// An error that occurred during CQL request serialization. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum CqlRequestSerializationError { + /// Failed to serialize STARTUP request. + #[error("Failed to serialize STARTUP request: {0}")] + StartupSerialization(#[from] StartupSerializationError), + + /// Failed to serialize REGISTER request. + #[error("Failed to serialize REGISTER request: {0}")] + RegisterSerialization(#[from] RegisterSerializationError), + + /// Failed to serialize AUTH_RESPONSE request. + #[error("Failed to serialize AUTH_RESPONSE request: {0}")] + AuthResponseSerialization(#[from] AuthResponseSerializationError), + + /// Failed to serialize BATCH request. + #[error("Failed to serialize BATCH request: {0}")] + BatchSerialization(#[from] BatchSerializationError), + + /// Failed to serialize PREPARE request. + #[error("Failed to serialize PREPARE request: {0}")] + PrepareSerialization(#[from] PrepareSerializationError), + + /// Failed to serialize EXECUTE request. + #[error("Failed to serialize EXECUTE request: {0}")] + ExecuteSerialization(#[from] ExecuteSerializationError), + + /// Failed to serialize QUERY request. + #[error("Failed to serialize QUERY request: {0}")] + QuerySerialization(#[from] QuerySerializationError), + + /// Request body compression failed. + #[error("Snap compression error: {0}")] + SnapCompressError(Arc), } /// An error type returned when deserialization of CQL diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index 715ba43984..4f8dc8bebd 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -9,13 +9,16 @@ pub mod value; #[cfg(test)] mod value_tests; -use crate::frame::frame_errors::FrameError; use bytes::{Buf, BufMut, Bytes}; +use frame_errors::{ + CqlRequestSerializationError, FrameBodyExtensionsParseError, FrameHeaderParseError, +}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; use uuid::Uuid; use std::fmt::Display; +use std::sync::Arc; use std::{collections::HashMap, convert::TryFrom}; use request::SerializableRequest; @@ -72,7 +75,7 @@ impl SerializedRequest { req: &R, compression: Option, tracing: bool, - ) -> Result { + ) -> Result { let mut flags = 0; let mut data = vec![0; HEADER_SIZE]; @@ -128,19 +131,22 @@ impl Default for FrameParams { pub async fn read_response_frame( reader: &mut (impl AsyncRead + Unpin), -) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameError> { +) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameHeaderParseError> { let mut raw_header = [0u8; HEADER_SIZE]; - reader.read_exact(&mut raw_header[..]).await?; + reader + .read_exact(&mut raw_header[..]) + .await + .map_err(FrameHeaderParseError::HeaderIoError)?; let mut buf = &raw_header[..]; // TODO: Validate version let version = buf.get_u8(); if version & 0x80 != 0x80 { - return Err(FrameError::FrameFromClient); + return Err(FrameHeaderParseError::FrameFromClient); } if version & 0x7F != 0x04 { - return Err(FrameError::VersionNotSupported(version & 0x7f)); + return Err(FrameHeaderParseError::VersionNotSupported(version & 0x7f)); } let flags = buf.get_u8(); @@ -159,10 +165,12 @@ pub async fn read_response_frame( let mut raw_body = Vec::with_capacity(length).limit(length); while raw_body.has_remaining_mut() { - let n = reader.read_buf(&mut raw_body).await?; + let n = reader.read_buf(&mut raw_body).await.map_err(|err| { + FrameHeaderParseError::BodyChunkIoError(raw_body.remaining_mut(), err) + })?; if n == 0 { // EOF, too early - return Err(FrameError::ConnectionClosed( + return Err(FrameHeaderParseError::ConnectionClosed( raw_body.remaining_mut(), length, )); @@ -183,18 +191,19 @@ pub fn parse_response_body_extensions( flags: u8, compression: Option, mut body: Bytes, -) -> Result { +) -> Result { if flags & FLAG_COMPRESSION != 0 { if let Some(compression) = compression { body = decompress(&body, compression)?.into(); } else { - return Err(FrameError::NoCompressionNegotiated); + return Err(FrameBodyExtensionsParseError::NoCompressionNegotiated); } } let trace_id = if flags & FLAG_TRACING != 0 { let buf = &mut &*body; - let trace_id = types::read_uuid(buf).map_err(frame_errors::ParseError::from)?; + let trace_id = + types::read_uuid(buf).map_err(FrameBodyExtensionsParseError::TraceIdParse)?; body.advance(16); Some(trace_id) } else { @@ -204,7 +213,8 @@ pub fn parse_response_body_extensions( let warnings = if flags & FLAG_WARNING != 0 { let body_len = body.len(); let buf = &mut &*body; - let warnings = types::read_string_list(buf).map_err(frame_errors::ParseError::from)?; + let warnings = types::read_string_list(buf) + .map_err(FrameBodyExtensionsParseError::WarningsListParse)?; let buf_len = buf.len(); body.advance(body_len - buf_len); warnings @@ -215,7 +225,8 @@ pub fn parse_response_body_extensions( let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 { let body_len = body.len(); let buf = &mut &*body; - let payload_map = types::read_bytes_map(buf).map_err(frame_errors::ParseError::from)?; + let payload_map = types::read_bytes_map(buf) + .map_err(FrameBodyExtensionsParseError::CustomPayloadMapParse)?; let buf_len = buf.len(); body.advance(body_len - buf_len); Some(payload_map) @@ -235,7 +246,7 @@ fn compress_append( uncomp_body: &[u8], compression: Compression, out: &mut Vec, -) -> Result<(), FrameError> { +) -> Result<(), CqlRequestSerializationError> { match compression { Compression::Lz4 => { let uncomp_len = uncomp_body.len() as u32; @@ -250,23 +261,27 @@ fn compress_append( out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0); let compressed_size = snap::raw::Encoder::new() .compress(uncomp_body, &mut out[old_size..]) - .map_err(|_| FrameError::FrameCompression)?; + .map_err(|err| CqlRequestSerializationError::SnapCompressError(Arc::new(err)))?; out.truncate(old_size + compressed_size); Ok(()) } } } -fn decompress(mut comp_body: &[u8], compression: Compression) -> Result, FrameError> { +fn decompress( + mut comp_body: &[u8], + compression: Compression, +) -> Result, FrameBodyExtensionsParseError> { match compression { Compression::Lz4 => { let uncomp_len = comp_body.get_u32() as usize; - let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)?; + let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len) + .map_err(|err| FrameBodyExtensionsParseError::Lz4DecompressError(Arc::new(err)))?; Ok(uncomp_body) } Compression::Snappy => snap::raw::Decoder::new() .decompress_vec(comp_body) - .map_err(|_| FrameError::FrameDecompression), + .map_err(|err| FrameBodyExtensionsParseError::SnapDecompressError(Arc::new(err))), } } diff --git a/scylla-cql/src/frame/request/auth_response.rs b/scylla-cql/src/frame/request/auth_response.rs index 83d718ee59..c03bebaaf9 100644 --- a/scylla-cql/src/frame/request/auth_response.rs +++ b/scylla-cql/src/frame/request/auth_response.rs @@ -1,4 +1,8 @@ -use crate::frame::frame_errors::ParseError; +use std::num::TryFromIntError; + +use thiserror::Error; + +use crate::frame::frame_errors::CqlRequestSerializationError; use crate::frame::request::{RequestOpcode, SerializableRequest}; use crate::frame::types::write_bytes_opt; @@ -11,7 +15,17 @@ pub struct AuthResponse { impl SerializableRequest for AuthResponse { const OPCODE: RequestOpcode = RequestOpcode::AuthResponse; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { - Ok(write_bytes_opt(self.response.as_ref(), buf)?) + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + Ok(write_bytes_opt(self.response.as_ref(), buf) + .map_err(AuthResponseSerializationError::ResponseSerialization)?) } } + +/// An error type returned when serialization of AUTH_RESPONSE request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum AuthResponseSerializationError { + /// Maximum response's body length exceeded. + #[error("AUTH_RESPONSE body bytes length too big: {0}")] + ResponseSerialization(TryFromIntError), +} diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 7f7895e1de..e193fbbfda 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -1,12 +1,12 @@ use bytes::{Buf, BufMut}; -use std::{borrow::Cow, convert::TryInto}; +use std::{borrow::Cow, convert::TryInto, num::TryFromIntError}; +use thiserror::Error; use crate::{ frame::{ - frame_errors::ParseError, + frame_errors::CqlRequestSerializationError, request::{RequestOpcode, SerializableRequest}, types::{self, SerialConsistency}, - value::SerializeValuesError, }, types::serialize::{ raw_batch::{RawBatchValues, RawBatchValuesIterator}, @@ -15,7 +15,7 @@ use crate::{ }, }; -use super::DeserializableRequest; +use super::{DeserializableRequest, RequestDeserializationError}; // Batch flags const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0x10; @@ -46,16 +46,12 @@ pub enum BatchType { Counter = 2, } +#[derive(Debug, Error)] +#[error("Malformed batch type: {value}")] pub struct BatchTypeParseError { value: u8, } -impl From for ParseError { - fn from(err: BatchTypeParseError) -> Self { - Self::BadIncomingData(format!("Bad BatchType value: {}", err.value)) - } -} - impl TryFrom for BatchType { type Error = BatchTypeParseError; @@ -75,31 +71,40 @@ pub enum BatchStatement<'a> { Prepared { id: Cow<'a, [u8]> }, } -impl SerializableRequest for Batch<'_, Statement, Values> +impl Batch<'_, Statement, Values> where for<'s> BatchStatement<'s>: From<&'s Statement>, Statement: Clone, Values: RawBatchValues, { - const OPCODE: RequestOpcode = RequestOpcode::Batch; - - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { + fn do_serialize(&self, buf: &mut Vec) -> Result<(), BatchSerializationError> { // Serializing type of batch buf.put_u8(self.batch_type as u8); // Serializing queries - types::write_short(self.statements.len().try_into()?, buf); - - let counts_mismatch_err = |n_values: usize, n_statements: usize| { - ParseError::BadDataToSerialize(format!( - "Length of provided values must be equal to number of batch statements \ - (got {n_values} values, {n_statements} statements)" - )) + types::write_short( + self.statements + .len() + .try_into() + .map_err(|_| BatchSerializationError::TooManyStatements(self.statements.len()))?, + buf, + ); + + let counts_mismatch_err = |n_value_lists: usize, n_statements: usize| { + BatchSerializationError::ValuesAndStatementsLengthMismatch { + n_value_lists, + n_statements, + } }; let mut n_serialized_statements = 0usize; let mut value_lists = self.values.batch_values_iter(); for (idx, statement) in self.statements.iter().enumerate() { - BatchStatement::from(statement).serialize(buf)?; + BatchStatement::from(statement) + .serialize(buf) + .map_err(|err| BatchSerializationError::StatementSerialization { + statement_idx: idx, + error: err, + })?; // Reserve two bytes for length let length_pos = buf.len(); @@ -107,12 +112,23 @@ where let mut row_writer = RowWriter::new(buf); value_lists .serialize_next(&mut row_writer) - .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))??; + .ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))? + .map_err(|err: SerializationError| { + BatchSerializationError::StatementSerialization { + statement_idx: idx, + error: BatchStatementSerializationError::ValuesSerialiation(err), + } + })?; // Go back and put the length let count: u16 = match row_writer.value_count().try_into() { Ok(n) => n, Err(_) => { - return Err(SerializationError::new(SerializeValuesError::TooManyValues).into()) + return Err(BatchSerializationError::StatementSerialization { + statement_idx: idx, + error: BatchStatementSerializationError::TooManyValues( + row_writer.value_count(), + ), + }) } }; buf[length_pos..length_pos + 2].copy_from_slice(&count.to_be_bytes()); @@ -129,11 +145,10 @@ where if n_serialized_statements != self.statements.len() { // We want to check this to avoid propagating an invalid construction of self.statements_count as a // hard-to-debug silent fail - return Err(ParseError::BadDataToSerialize(format!( - "Invalid Batch constructed: not as many statements serialized as announced \ - (batch.statement_count: {announced_statement_count}, {n_serialized_statements}", - announced_statement_count = self.statements.len() - ))); + return Err(BatchSerializationError::BadBatchConstructed { + n_announced_statements: self.statements.len(), + n_serialized_statements, + }); } // Serializing consistency @@ -161,8 +176,22 @@ where } } +impl SerializableRequest for Batch<'_, Statement, Values> +where + for<'s> BatchStatement<'s>: From<&'s Statement>, + Statement: Clone, + Values: RawBatchValues, +{ + const OPCODE: RequestOpcode = RequestOpcode::Batch; + + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + self.do_serialize(buf)?; + Ok(()) + } +} + impl BatchStatement<'_> { - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize(buf: &mut &[u8]) -> Result { let kind = buf.get_u8(); match kind { 0 => { @@ -173,24 +202,25 @@ impl BatchStatement<'_> { let id = types::read_short_bytes(buf)?.to_vec().into(); Ok(BatchStatement::Prepared { id }) } - _ => Err(ParseError::BadIncomingData(format!( - "Unexpected batch statement kind: {}", - kind - ))), + _ => Err(RequestDeserializationError::UnexpectedBatchStatementKind( + kind, + )), } } } impl BatchStatement<'_> { - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { + fn serialize(&self, buf: &mut impl BufMut) -> Result<(), BatchStatementSerializationError> { match self { Self::Query { text } => { buf.put_u8(0); - types::write_long_string(text, buf)?; + types::write_long_string(text, buf) + .map_err(BatchStatementSerializationError::StatementStringSerialization)?; } Self::Prepared { id } => { buf.put_u8(1); - types::write_short_bytes(id, buf)?; + types::write_short_bytes(id, buf) + .map_err(BatchStatementSerializationError::StatementIdSerialization)?; } } @@ -208,7 +238,7 @@ impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> { } impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec> { - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize(buf: &mut &[u8]) -> Result { let batch_type = buf.get_u8().try_into()?; let statements_count: usize = types::read_short(buf)?.into(); @@ -221,17 +251,16 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec, ParseError>>()?; + .collect::, RequestDeserializationError>>()?; let consistency = types::read_consistency(buf)?; let flags = buf.get_u8(); let unknown_flags = flags & (!ALL_FLAGS); if unknown_flags != 0 { - return Err(ParseError::BadIncomingData(format!( - "Specified flags are not recognised: {:02x}", - unknown_flags - ))); + return Err(RequestDeserializationError::UnknownFlags { + flags: unknown_flags, + }); } let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0; let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0; @@ -242,10 +271,9 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Ok(serial_consistency), - Err(_) => Err(ParseError::BadIncomingData(format!( - "Expected SerialConsistency, got regular Consistency {}", - consistency - ))), + Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency( + consistency, + )), }, ) .transpose()?; @@ -267,3 +295,55 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec { @@ -17,21 +23,37 @@ pub struct Execute<'a> { impl SerializableRequest for Execute<'_> { const OPCODE: RequestOpcode = RequestOpcode::Execute; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { // Serializing statement id - types::write_short_bytes(&self.id[..], buf)?; + types::write_short_bytes(&self.id[..], buf) + .map_err(ExecuteSerializationError::StatementIdSerialization)?; // Serializing params - self.parameters.serialize(buf)?; + self.parameters + .serialize(buf) + .map_err(ExecuteSerializationError::QueryParametersSerialization)?; Ok(()) } } impl<'e> DeserializableRequest for Execute<'e> { - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize(buf: &mut &[u8]) -> Result { let id = types::read_short_bytes(buf)?.to_vec().into(); let parameters = QueryParameters::deserialize(buf)?; Ok(Self { id, parameters }) } } + +/// An error type returned when serialization of EXECUTE request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum ExecuteSerializationError { + /// Failed to serialize query parameters. + #[error("Malformed query parameters: {0}")] + QueryParametersSerialization(QueryParametersSerializationError), + + /// Failed to serialize prepared statement id. + #[error("Malformed statement id: {0}")] + StatementIdSerialization(TryFromIntError), +} diff --git a/scylla-cql/src/frame/request/mod.rs b/scylla-cql/src/frame/request/mod.rs index 5285b5937c..feef653b97 100644 --- a/scylla-cql/src/frame/request/mod.rs +++ b/scylla-cql/src/frame/request/mod.rs @@ -7,8 +7,11 @@ pub mod query; pub mod register; pub mod startup; +use batch::BatchTypeParseError; +use thiserror::Error; + use crate::types::serialize::row::SerializedValues; -use crate::{frame::frame_errors::ParseError, Consistency}; +use crate::Consistency; use bytes::Bytes; pub use auth_response::AuthResponse; @@ -21,6 +24,7 @@ pub use startup::Startup; use self::batch::BatchStatement; +use super::frame_errors::{CqlRequestSerializationError, LowLevelDeserializationError}; use super::types::SerialConsistency; use super::TryFromPrimitiveError; @@ -92,9 +96,9 @@ impl TryFrom for RequestOpcode { pub trait SerializableRequest { const OPCODE: RequestOpcode; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError>; + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError>; - fn to_bytes(&self) -> Result { + fn to_bytes(&self) -> Result { let mut v = Vec::new(); self.serialize(&mut v)?; Ok(v.into()) @@ -104,7 +108,29 @@ pub trait SerializableRequest { /// Not intended for driver's direct usage (as driver has no interest in deserialising CQL requests), /// but very useful for testing (e.g. asserting that the sent requests have proper parameters set). pub trait DeserializableRequest: SerializableRequest + Sized { - fn deserialize(buf: &mut &[u8]) -> Result; + fn deserialize(buf: &mut &[u8]) -> Result; +} + +/// An error type returned by [`DeserializableRequest::deserialize`]. +/// This is not intended for driver's direct usage. It's a testing utility, +/// mainly used by `scylla-proxy` crate. +#[doc(hidden)] +#[derive(Debug, Error)] +pub enum RequestDeserializationError { + #[error("Low level deser error: {0}")] + LowLevelDeserialization(#[from] LowLevelDeserializationError), + #[error("Io error: {0}")] + IoError(#[from] std::io::Error), + #[error("Specified flags are not recognised: {:02x}", flags)] + UnknownFlags { flags: u8 }, + #[error("Named values in frame are currently unsupported")] + NamedValuesUnsupported, + #[error("Expected SerialConsistency, got regular Consistency: {0}")] + ExpectedSerialConsistency(Consistency), + #[error(transparent)] + BatchTypeParse(#[from] BatchTypeParseError), + #[error("Unexpected batch statement kind: {0}")] + UnexpectedBatchStatementKind(u8), } #[non_exhaustive] // TODO: add remaining request types @@ -115,7 +141,10 @@ pub enum Request<'r> { } impl<'r> Request<'r> { - pub fn deserialize(buf: &mut &[u8], opcode: RequestOpcode) -> Result { + pub fn deserialize( + buf: &mut &[u8], + opcode: RequestOpcode, + ) -> Result { match opcode { RequestOpcode::Query => Query::deserialize(buf).map(Self::Query), RequestOpcode::Execute => Execute::deserialize(buf).map(Self::Execute), diff --git a/scylla-cql/src/frame/request/options.rs b/scylla-cql/src/frame/request/options.rs index 6ea6517ce6..ce3ba4b9e0 100644 --- a/scylla-cql/src/frame/request/options.rs +++ b/scylla-cql/src/frame/request/options.rs @@ -1,4 +1,4 @@ -use crate::frame::frame_errors::ParseError; +use crate::frame::frame_errors::CqlRequestSerializationError; use crate::frame::request::{RequestOpcode, SerializableRequest}; @@ -7,7 +7,7 @@ pub struct Options; impl SerializableRequest for Options { const OPCODE: RequestOpcode = RequestOpcode::Options; - fn serialize(&self, _buf: &mut Vec) -> Result<(), ParseError> { + fn serialize(&self, _buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { Ok(()) } } diff --git a/scylla-cql/src/frame/request/prepare.rs b/scylla-cql/src/frame/request/prepare.rs index c30e25727a..5d209263e7 100644 --- a/scylla-cql/src/frame/request/prepare.rs +++ b/scylla-cql/src/frame/request/prepare.rs @@ -1,4 +1,8 @@ -use crate::frame::frame_errors::ParseError; +use std::num::TryFromIntError; + +use thiserror::Error; + +use crate::frame::frame_errors::CqlRequestSerializationError; use crate::{ frame::request::{RequestOpcode, SerializableRequest}, @@ -12,8 +16,18 @@ pub struct Prepare<'a> { impl<'a> SerializableRequest for Prepare<'a> { const OPCODE: RequestOpcode = RequestOpcode::Prepare; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { - types::write_long_string(self.query, buf)?; + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + types::write_long_string(self.query, buf) + .map_err(PrepareSerializationError::StatementStringSerialization)?; Ok(()) } } + +/// An error type returned when serialization of PREPARE request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum PrepareSerializationError { + /// Failed to serialize the CQL statement string. + #[error("Failed to serialize statement contents: {0}")] + StatementStringSerialization(TryFromIntError), +} diff --git a/scylla-cql/src/frame/request/query.rs b/scylla-cql/src/frame/request/query.rs index 2794a2b5d9..8567cbd419 100644 --- a/scylla-cql/src/frame/request/query.rs +++ b/scylla-cql/src/frame/request/query.rs @@ -1,17 +1,18 @@ -use std::{borrow::Cow, ops::ControlFlow, sync::Arc}; +use std::{borrow::Cow, num::TryFromIntError, ops::ControlFlow, sync::Arc}; use crate::{ - frame::{frame_errors::ParseError, types::SerialConsistency}, + frame::{frame_errors::CqlRequestSerializationError, types::SerialConsistency}, types::serialize::row::SerializedValues, }; use bytes::{Buf, BufMut}; +use thiserror::Error; use crate::{ frame::request::{RequestOpcode, SerializableRequest}, frame::types, }; -use super::DeserializableRequest; +use super::{DeserializableRequest, RequestDeserializationError}; // Query flags const FLAG_VALUES: u8 = 0x01; @@ -38,15 +39,18 @@ pub struct Query<'q> { impl SerializableRequest for Query<'_> { const OPCODE: RequestOpcode = RequestOpcode::Query; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { - types::write_long_string(&self.contents, buf)?; - self.parameters.serialize(buf)?; + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + types::write_long_string(&self.contents, buf) + .map_err(QuerySerializationError::StatementStringSerialization)?; + self.parameters + .serialize(buf) + .map_err(QuerySerializationError::QueryParametersSerialization)?; Ok(()) } } impl<'q> DeserializableRequest for Query<'q> { - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize(buf: &mut &[u8]) -> Result { let contents = Cow::Owned(types::read_long_string(buf)?.to_owned()); let parameters = QueryParameters::deserialize(buf)?; @@ -83,7 +87,10 @@ impl Default for QueryParameters<'_> { } impl QueryParameters<'_> { - pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { + pub fn serialize( + &self, + buf: &mut impl BufMut, + ) -> Result<(), QueryParametersSerializationError> { types::write_consistency(self.consistency, buf); let paging_state_bytes = self.paging_state.as_bytes_slice(); @@ -140,16 +147,15 @@ impl QueryParameters<'_> { } impl<'q> QueryParameters<'q> { - pub fn deserialize(buf: &mut &[u8]) -> Result { + pub fn deserialize(buf: &mut &[u8]) -> Result { let consistency = types::read_consistency(buf)?; let flags = buf.get_u8(); let unknown_flags = flags & (!ALL_FLAGS); if unknown_flags != 0 { - return Err(ParseError::BadIncomingData(format!( - "Specified flags are not recognised: {:02x}", - unknown_flags - ))); + return Err(RequestDeserializationError::UnknownFlags { + flags: unknown_flags, + }); } let values_flag = (flags & FLAG_VALUES) != 0; let skip_metadata = (flags & FLAG_SKIP_METADATA) != 0; @@ -160,9 +166,7 @@ impl<'q> QueryParameters<'q> { let values_have_names_flag = (flags & FLAG_WITH_NAMES_FOR_VALUES) != 0; if values_have_names_flag { - return Err(ParseError::BadIncomingData( - "Named values in frame are currently unsupported".to_string(), - )); + return Err(RequestDeserializationError::NamedValuesUnsupported); } let values = Cow::Owned(if values_flag { @@ -183,10 +187,9 @@ impl<'q> QueryParameters<'q> { .map( |consistency| match SerialConsistency::try_from(consistency) { Ok(serial_consistency) => Ok(serial_consistency), - Err(_) => Err(ParseError::BadIncomingData(format!( - "Expected SerialConsistency, got regular Consistency {}", - consistency - ))), + Err(_) => Err(RequestDeserializationError::ExpectedSerialConsistency( + consistency, + )), }, ) .transpose()?; @@ -290,3 +293,25 @@ impl Default for PagingState { Self::start() } } + +/// An error type returned when serialization of QUERY request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum QuerySerializationError { + /// Failed to serialize query parameters. + #[error("Invalid query parameters: {0}")] + QueryParametersSerialization(QueryParametersSerializationError), + + /// Failed to serialize the CQL statement string. + #[error("Failed to serialize a statement content: {0}")] + StatementStringSerialization(TryFromIntError), +} + +/// An error type returned when serialization of query parameters fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum QueryParametersSerializationError { + /// Failed to serialize paging state. + #[error("Malformed paging state: {0}")] + BadPagingState(#[from] TryFromIntError), +} diff --git a/scylla-cql/src/frame/request/register.rs b/scylla-cql/src/frame/request/register.rs index c29c821964..9abcfde05c 100644 --- a/scylla-cql/src/frame/request/register.rs +++ b/scylla-cql/src/frame/request/register.rs @@ -1,5 +1,9 @@ +use std::num::TryFromIntError; + +use thiserror::Error; + use crate::frame::{ - frame_errors::ParseError, + frame_errors::CqlRequestSerializationError, request::{RequestOpcode, SerializableRequest}, server_event_type::EventType, types, @@ -12,14 +16,24 @@ pub struct Register { impl SerializableRequest for Register { const OPCODE: RequestOpcode = RequestOpcode::Register; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { let event_types_list = self .event_types_to_register_for .iter() .map(|event| event.to_string()) .collect::>(); - types::write_string_list(&event_types_list, buf)?; + types::write_string_list(&event_types_list, buf) + .map_err(RegisterSerializationError::EventTypesSerialization)?; Ok(()) } } + +/// An error type returned when serialization of REGISTER request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum RegisterSerializationError { + /// Failed to serialize event types list. + #[error("Failed to serialize event types list: {0}")] + EventTypesSerialization(TryFromIntError), +} diff --git a/scylla-cql/src/frame/request/startup.rs b/scylla-cql/src/frame/request/startup.rs index 6759d0cfce..cab84dc398 100644 --- a/scylla-cql/src/frame/request/startup.rs +++ b/scylla-cql/src/frame/request/startup.rs @@ -1,6 +1,8 @@ -use crate::frame::frame_errors::ParseError; +use thiserror::Error; -use std::{borrow::Cow, collections::HashMap}; +use crate::frame::frame_errors::CqlRequestSerializationError; + +use std::{borrow::Cow, collections::HashMap, num::TryFromIntError}; use crate::{ frame::request::{RequestOpcode, SerializableRequest}, @@ -14,8 +16,18 @@ pub struct Startup<'a> { impl SerializableRequest for Startup<'_> { const OPCODE: RequestOpcode = RequestOpcode::Startup; - fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { - types::write_string_map(&self.options, buf)?; + fn serialize(&self, buf: &mut Vec) -> Result<(), CqlRequestSerializationError> { + types::write_string_map(&self.options, buf) + .map_err(StartupSerializationError::OptionsSerialization)?; Ok(()) } } + +/// An error type returned when serialization of STARTUP request fails. +#[non_exhaustive] +#[derive(Error, Debug, Clone)] +pub enum StartupSerializationError { + /// Failed to serialize startup options. + #[error("Malformed startup options: {0}")] + OptionsSerialization(TryFromIntError), +} diff --git a/scylla-cql/src/frame/response/cql_to_rust.rs b/scylla-cql/src/frame/response/cql_to_rust.rs index 435d45efcf..ec18ce0958 100644 --- a/scylla-cql/src/frame/response/cql_to_rust.rs +++ b/scylla-cql/src/frame/response/cql_to_rust.rs @@ -16,12 +16,6 @@ pub enum FromRowError { WrongRowSize { expected: usize, actual: usize }, } -#[derive(Error, Clone, Debug, PartialEq, Eq)] -pub enum CqlTypeError { - #[error("Invalid number of set elements: {0}")] - InvalidNumberOfElements(i32), -} - /// This trait defines a way to convert CqlValue or `Option` into some rust type // We can't use From trait because impl From> for String {...} // is forbidden since neither From nor String are defined in this crate diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 0cc791be31..e73347039b 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -1,7 +1,6 @@ //! CQL binary protocol in-wire types. use super::frame_errors::LowLevelDeserializationError; -use super::frame_errors::ParseError; use super::TryFromPrimitiveError; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; @@ -126,24 +125,6 @@ impl std::fmt::Display for SerialConsistency { } } -impl From for ParseError { - fn from(_err: std::num::TryFromIntError) -> Self { - ParseError::BadIncomingData("Integer conversion out of range".to_string()) - } -} - -impl From for ParseError { - fn from(_err: std::str::Utf8Error) -> Self { - ParseError::BadIncomingData("UTF8 serialization failed".to_string()) - } -} - -impl From for ParseError { - fn from(_err: std::array::TryFromSliceError) -> Self { - ParseError::BadIncomingData("array try from slice failed".to_string()) - } -} - #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum RawValue<'a> { Null, diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index a9c368d195..c479b64d01 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -1,4 +1,3 @@ -use crate::frame::frame_errors::ParseError; use crate::frame::types; use bytes::BufMut; use std::borrow::Cow; @@ -799,26 +798,6 @@ impl LegacySerializedValues { self.serialized_values.len() } - /// Creates value list from the request frame - pub fn new_from_frame(buf: &mut &[u8], contains_names: bool) -> Result { - let values_num = types::read_short(buf)?; - let values_beg = *buf; - for _ in 0..values_num { - if contains_names { - let _name = types::read_string(buf)?; - } - let _serialized = types::read_bytes_opt(buf)?; - } - - let values_len_in_buf = values_beg.len() - buf.len(); - let values_in_frame = &values_beg[0..values_len_in_buf]; - Ok(LegacySerializedValues { - serialized_values: values_in_frame.to_vec(), - values_num, - contains_names, - }) - } - pub fn iter_name_value_pairs(&self) -> impl Iterator, RawValue)> { let mut buf = &self.serialized_values[..]; (0..self.values_num).map(move |_| { diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 2d8f0713e8..0f4fb64c5a 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -46,30 +46,31 @@ //! //! ```rust //! # use scylla_cql::frame::response::result::ColumnType; -//! # use scylla_cql::frame::frame_errors::ParseError; //! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; //! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! use thiserror::Error; //! struct MyVec(Vec); +//! #[derive(Debug, Error)] +//! enum MyDeserError { +//! #[error("Expected bytes")] +//! ExpectedBytes, +//! #[error("Expected non-null")] +//! ExpectedNonNull, +//! } //! impl<'frame> DeserializeValue<'frame> for MyVec { //! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { //! if let ColumnType::Blob = typ { //! return Ok(()); //! } -//! Err(TypeCheckError::new( -//! ParseError::BadIncomingData("Expected bytes".to_owned()) -//! )) +//! Err(TypeCheckError::new(MyDeserError::ExpectedBytes)) //! } //! //! fn deserialize( //! _typ: &'frame ColumnType, //! v: Option>, //! ) -> Result { -//! v.ok_or_else(|| { -//! DeserializationError::new( -//! ParseError::BadIncomingData("Expected non-null value".to_owned()) -//! ) -//! }) -//! .map(|v| Self(v.as_slice().to_vec())) +//! v.ok_or_else(|| DeserializationError::new(MyDeserError::ExpectedNonNull)) +//! .map(|v| Self(v.as_slice().to_vec())) //! } //! } //! ``` @@ -85,11 +86,18 @@ //! For example: //! //! ```rust -//! # use scylla_cql::frame::frame_errors::ParseError; //! # use scylla_cql::frame::response::result::ColumnType; //! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; //! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! use thiserror::Error; //! struct MySlice<'a>(&'a [u8]); +//! #[derive(Debug, Error)] +//! enum MyDeserError { +//! #[error("Expected bytes")] +//! ExpectedBytes, +//! #[error("Expected non-null")] +//! ExpectedNonNull, +//! } //! impl<'a, 'frame> DeserializeValue<'frame> for MySlice<'a> //! where //! 'frame: 'a, @@ -98,21 +106,15 @@ //! if let ColumnType::Blob = typ { //! return Ok(()); //! } -//! Err(TypeCheckError::new( -//! ParseError::BadIncomingData("Expected bytes".to_owned()) -//! )) +//! Err(TypeCheckError::new(MyDeserError::ExpectedBytes)) //! } //! //! fn deserialize( //! _typ: &'frame ColumnType, //! v: Option>, //! ) -> Result { -//! v.ok_or_else(|| { -//! DeserializationError::new( -//! ParseError::BadIncomingData("Expected non-null value".to_owned()) -//! ) -//! }) -//! .map(|v| Self(v.as_slice())) +//! v.ok_or_else(|| DeserializationError::new(MyDeserError::ExpectedNonNull)) +//! .map(|v| Self(v.as_slice())) //! } //! } //! ``` @@ -135,32 +137,36 @@ //! Example: //! //! ```rust -//! # use scylla_cql::frame::frame_errors::ParseError; //! # use scylla_cql::frame::response::result::ColumnType; //! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; //! # use scylla_cql::types::deserialize::value::DeserializeValue; //! # use bytes::Bytes; +//! use thiserror::Error; //! struct MyBytes(Bytes); +//! #[derive(Debug, Error)] +//! enum MyDeserError { +//! #[error("Expected bytes")] +//! ExpectedBytes, +//! #[error("Expected non-null")] +//! ExpectedNonNull, +//! } //! impl<'frame> DeserializeValue<'frame> for MyBytes { //! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { //! if let ColumnType::Blob = typ { //! return Ok(()); //! } -//! Err(TypeCheckError::new(ParseError::BadIncomingData("Expected bytes".to_owned()))) +//! Err(TypeCheckError::new(MyDeserError::ExpectedBytes)) //! } //! //! fn deserialize( //! _typ: &'frame ColumnType, //! v: Option>, //! ) -> Result { -//! v.ok_or_else(|| { -//! DeserializationError::new(ParseError::BadIncomingData("Expected non-null value".to_owned())) -//! }) -//! .map(|v| Self(v.to_bytes())) +//! v.ok_or_else(|| DeserializationError::new(MyDeserError::ExpectedNonNull)) +//! .map(|v| Self(v.to_bytes())) //! } //! } //! ``` -// TODO: in the above module docstring, stop abusing ParseError once errors are refactored. pub mod frame_slice; pub mod result; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 236451d59b..6d0ea9c30e 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -9,7 +9,7 @@ use std::{collections::HashMap, sync::Arc}; use bytes::BufMut; use thiserror::Error; -use crate::frame::frame_errors::ParseError; +use crate::frame::request::RequestDeserializationError; use crate::frame::response::result::ColumnType; use crate::frame::response::result::PreparedMetadata; use crate::frame::types; @@ -807,7 +807,8 @@ impl SerializedValues { } /// Creates value list from the request frame - pub(crate) fn new_from_frame(buf: &mut &[u8]) -> Result { + /// This is used only for testing - request deserialization. + pub(crate) fn new_from_frame(buf: &mut &[u8]) -> Result { let values_num = types::read_short(buf)?; let values_beg = *buf; for _ in 0..values_num { diff --git a/scylla-proxy/src/errors.rs b/scylla-proxy/src/errors.rs index 72079ae5b6..fa1cc47d83 100644 --- a/scylla-proxy/src/errors.rs +++ b/scylla-proxy/src/errors.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use scylla_cql::frame::frame_errors::{FrameError, ParseError}; +use scylla_cql::frame::frame_errors::{FrameHeaderParseError, LowLevelDeserializationError}; use thiserror::Error; #[derive(Debug, Error)] @@ -20,9 +20,9 @@ pub enum DoorkeeperError { #[error("Could not send Options frame for obtaining shards number: {0}")] ObtainingShardNumber(std::io::Error), #[error("Could not send read Supported frame for obtaining shards number: {0}")] - ObtainingShardNumberFrame(FrameError), + ObtainingShardNumberFrame(FrameHeaderParseError), #[error("Could not read Supported options: {0}")] - ObtainingShardNumberParseOptions(ParseError), + ObtainingShardNumberParseOptions(LowLevelDeserializationError), #[error("ShardInfo parameters missing")] ObtainingShardNumberNoShardInfo, #[error("Could not parse shard number: {0}")] diff --git a/scylla-proxy/src/frame.rs b/scylla-proxy/src/frame.rs index 5cabc07146..435a164863 100644 --- a/scylla-proxy/src/frame.rs +++ b/scylla-proxy/src/frame.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use scylla_cql::frame::frame_errors::{FrameError, ParseError}; +use scylla_cql::frame::frame_errors::FrameHeaderParseError; use scylla_cql::frame::protocol_features::ProtocolFeatures; -use scylla_cql::frame::request::Request; pub use scylla_cql::frame::request::RequestOpcode; +use scylla_cql::frame::request::{Request, RequestDeserializationError}; pub use scylla_cql::frame::response::ResponseOpcode; use scylla_cql::frame::{response::error::DbError, types}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -69,7 +69,7 @@ impl RequestFrame { .await } - pub fn deserialize(&self) -> Result { + pub fn deserialize(&self) -> Result { Request::deserialize(&mut &self.body[..], self.opcode) } } @@ -87,7 +87,7 @@ impl ResponseFrame { request_params: FrameParams, error: DbError, msg: Option<&str>, - ) -> Result { + ) -> Result { let msg = msg.unwrap_or("Proxy-triggered error."); let len_bytes = (msg.len() as u16).to_be_bytes(); // string len is a short in CQL protocol let code_bytes = error.code(&ProtocolFeatures::default()).to_be_bytes(); // TODO: configurable features @@ -111,7 +111,7 @@ impl ResponseFrame { pub fn forged_supported( request_params: FrameParams, options: &HashMap>, - ) -> Result { + ) -> Result { let mut buf = BytesMut::new(); types::write_string_multimap(options, &mut buf)?; @@ -144,7 +144,10 @@ impl ResponseFrame { } } -fn serialize_error_specific_fields(buf: &mut BytesMut, error: DbError) -> Result<(), ParseError> { +fn serialize_error_specific_fields( + buf: &mut BytesMut, + error: DbError, +) -> Result<(), std::num::TryFromIntError> { match error { DbError::Unavailable { consistency, @@ -250,17 +253,20 @@ pub(crate) async fn write_frame( pub(crate) async fn read_frame( reader: &mut (impl AsyncRead + Unpin), frame_type: FrameType, -) -> Result<(FrameParams, FrameOpcode, Bytes), FrameError> { +) -> Result<(FrameParams, FrameOpcode, Bytes), FrameHeaderParseError> { let mut raw_header = [0u8; HEADER_SIZE]; - reader.read_exact(&mut raw_header[..]).await?; + reader + .read_exact(&mut raw_header[..]) + .await + .map_err(FrameHeaderParseError::HeaderIoError)?; let mut buf = &raw_header[..]; let version = buf.get_u8(); { let (err, valid_direction, direction_str) = match frame_type { - FrameType::Request => (FrameError::FrameFromServer, 0x00, "request"), - FrameType::Response => (FrameError::FrameFromClient, 0x80, "response"), + FrameType::Request => (FrameHeaderParseError::FrameFromServer, 0x00, "request"), + FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"), }; if version & 0x80 != valid_direction { return Err(err); @@ -285,10 +291,12 @@ pub(crate) async fn read_frame( let opcode = match frame_type { FrameType::Request => FrameOpcode::Request( - RequestOpcode::try_from(buf.get_u8()).map_err(|_| FrameError::FrameFromServer)?, + RequestOpcode::try_from(buf.get_u8()) + .map_err(|_| FrameHeaderParseError::FrameFromServer)?, ), FrameType::Response => FrameOpcode::Response( - ResponseOpcode::try_from(buf.get_u8()).map_err(|_| FrameError::FrameFromClient)?, + ResponseOpcode::try_from(buf.get_u8()) + .map_err(|_| FrameHeaderParseError::FrameFromClient)?, ), }; @@ -297,10 +305,16 @@ pub(crate) async fn read_frame( let mut body = Vec::with_capacity(length).limit(length); while body.has_remaining_mut() { - let n = reader.read_buf(&mut body).await?; + let n = reader + .read_buf(&mut body) + .await + .map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?; if n == 0 { // EOF, too early - return Err(FrameError::ConnectionClosed(body.remaining_mut(), length)); + return Err(FrameHeaderParseError::ConnectionClosed( + body.remaining_mut(), + length, + )); } } @@ -309,7 +323,7 @@ pub(crate) async fn read_frame( pub(crate) async fn read_request_frame( reader: &mut (impl AsyncRead + Unpin), -) -> Result { +) -> Result { read_frame(reader, FrameType::Request) .await .map(|(params, opcode, body)| RequestFrame { @@ -324,7 +338,7 @@ pub(crate) async fn read_request_frame( pub(crate) async fn read_response_frame( reader: &mut (impl AsyncRead + Unpin), -) -> Result { +) -> Result { read_frame(reader, FrameType::Response) .await .map(|(params, opcode, body)| ResponseFrame { diff --git a/scylla-proxy/src/proxy.rs b/scylla-proxy/src/proxy.rs index 64e0b4f317..6c5b5e62f1 100644 --- a/scylla-proxy/src/proxy.rs +++ b/scylla-proxy/src/proxy.rs @@ -5,7 +5,6 @@ use crate::frame::{ }; use crate::{RequestOpcode, TargetShard}; use bytes::Bytes; -use scylla_cql::frame::frame_errors::ParseError; use scylla_cql::frame::types::read_string_multimap; use std::collections::HashMap; use std::fmt::Display; @@ -806,7 +805,6 @@ impl Doorkeeper { .map_err(DoorkeeperError::ObtainingShardNumberFrame)?; let options = read_string_multimap(&mut supported_frame.body.as_ref()) - .map_err(ParseError::from) .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?; Ok(options) @@ -1305,7 +1303,7 @@ mod tests { use bytes::{BufMut, BytesMut}; use futures::future::{join, join3}; use rand::RngCore; - use scylla_cql::frame::frame_errors::FrameError; + use scylla_cql::frame::frame_errors::FrameHeaderParseError; use scylla_cql::frame::types::write_string_multimap; use std::collections::HashMap; use std::mem; @@ -1724,7 +1722,7 @@ mod tests { params: FrameParams, opcode: FrameOpcode, body: &Bytes, - ) -> Result { + ) -> Result { let (send_res, recv_res) = join( write_frame(params, opcode, &body.clone(), driver), read_request_frame(node), @@ -1839,7 +1837,7 @@ mod tests { params: FrameParams, opcode: FrameOpcode, body: &Bytes, - ) -> Result { + ) -> Result { let (send_res, recv_res) = join( write_frame(params, opcode, &body.clone(), driver), read_request_frame(node), diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index ac9a203eb8..44c59878fa 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -780,7 +780,8 @@ impl Connection { } }, Err(e) => match e { - RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlRequestSerialization(e) => return Err(err(e.into())), + RequestError::BodyExtensionsParseError(e) => return Err(err(e.into())), RequestError::CqlResponseParseError(e) => match e { // Parsing of READY response cannot fail, since its body is empty. // Remaining valid responses are AUTHENTICATE and ERROR. @@ -829,7 +830,8 @@ impl Connection { } }, Err(e) => match e { - RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlRequestSerialization(e) => return Err(err(e.into())), + RequestError::BodyExtensionsParseError(e) => return Err(err(e.into())), RequestError::CqlResponseParseError(e) => match e { CqlResponseParseError::CqlSupportedParseError(e) => return Err(err(e.into())), CqlResponseParseError::CqlErrorParseError(e) => return Err(err(e.into())), @@ -939,7 +941,8 @@ impl Connection { } }, Err(e) => match e { - RequestError::FrameError(e) => return Err(err(e.into())), + RequestError::CqlRequestSerialization(e) => return Err(err(e.into())), + RequestError::BodyExtensionsParseError(e) => return Err(err(e.into())), RequestError::CqlResponseParseError(e) => match e { CqlResponseParseError::CqlAuthSuccessParseError(e) => { return Err(err(e.into())) @@ -1400,7 +1403,8 @@ impl Connection { ))), }, Err(e) => match e { - RequestError::FrameError(e) => Err(err(e.into())), + RequestError::CqlRequestSerialization(e) => Err(err(e.into())), + RequestError::BodyExtensionsParseError(e) => Err(err(e.into())), RequestError::CqlResponseParseError(e) => match e { // Parsing the READY response cannot fail. Only remaining valid response is ERROR. CqlResponseParseError::CqlErrorParseError(e) => Err(err(e.into())), @@ -1614,7 +1618,7 @@ impl Connection { loop { let (params, opcode, body) = frame::read_response_frame(&mut read_half) .await - .map_err(BrokenConnectionErrorKind::FrameError)?; + .map_err(BrokenConnectionErrorKind::FrameHeaderParseError)?; let response = TaskResponse { params, opcode, @@ -1860,7 +1864,7 @@ impl Connection { } }, Err(e) => match e { - ResponseParseError::FrameError(e) => return Err(e.into()), + ResponseParseError::BodyExtensionsParseError(e) => return Err(e.into()), ResponseParseError::CqlResponseParseError(e) => match e { CqlResponseParseError::CqlEventParseError(e) => return Err(e.into()), // Received a response other than EVENT, but failed to deserialize it. diff --git a/scylla/src/transport/errors.rs b/scylla/src/transport/errors.rs index 2ab1314d3c..d00e68de2d 100644 --- a/scylla/src/transport/errors.rs +++ b/scylla/src/transport/errors.rs @@ -15,17 +15,15 @@ use scylla_cql::{ frame::{ frame_errors::{ CqlAuthChallengeParseError, CqlAuthSuccessParseError, CqlAuthenticateParseError, - CqlErrorParseError, CqlEventParseError, CqlResponseParseError, CqlResultParseError, - CqlSupportedParseError, FrameError, ParseError, + CqlErrorParseError, CqlEventParseError, CqlRequestSerializationError, + CqlResponseParseError, CqlResultParseError, CqlSupportedParseError, + FrameBodyExtensionsParseError, FrameHeaderParseError, }, request::CqlRequestKind, response::CqlResponseKind, value::SerializeValuesError, }, - types::{ - deserialize::{DeserializationError, TypeCheckError}, - serialize::SerializationError, - }, + types::serialize::SerializationError, }; use thiserror::Error; @@ -44,6 +42,14 @@ pub enum QueryError { #[error(transparent)] BadQuery(#[from] BadQuery), + /// Failed to serialize CQL request. + #[error("Failed to serialize CQL request: {0}")] + CqlRequestSerialization(#[from] CqlRequestSerializationError), + + /// Failed to deserialize frame body extensions. + #[error(transparent)] + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), + /// Received a RESULT server response, but failed to deserialize it. #[error(transparent)] CqlResultParseError(#[from] CqlResultParseError), @@ -93,30 +99,6 @@ impl From for QueryError { } } -impl From for QueryError { - fn from(value: DeserializationError) -> Self { - Self::InvalidMessage(value.to_string()) - } -} - -impl From for QueryError { - fn from(value: TypeCheckError) -> Self { - Self::InvalidMessage(value.to_string()) - } -} - -impl From for QueryError { - fn from(parse_error: ParseError) -> QueryError { - QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error)) - } -} - -impl From for QueryError { - fn from(frame_error: FrameError) -> QueryError { - QueryError::InvalidMessage(format!("Frame error: {}", frame_error)) - } -} - impl From for QueryError { fn from(timer_error: tokio::time::error::Elapsed) -> QueryError { QueryError::RequestTimeout(format!("{}", timer_error)) @@ -126,6 +108,7 @@ impl From for QueryError { impl From for QueryError { fn from(value: UserRequestError) -> Self { match value { + UserRequestError::CqlRequestSerialization(e) => e.into(), UserRequestError::DbError(err, msg) => QueryError::DbError(err, msg), UserRequestError::CqlResultParseError(e) => e.into(), UserRequestError::CqlErrorParseError(e) => e.into(), @@ -134,7 +117,7 @@ impl From for QueryError { // FIXME: make it typed. It needs to wait for ProtocolError refactor. QueryError::ProtocolError("Received unexpected response from the server. Expected RESULT or ERROR response.") } - UserRequestError::FrameError(e) => e.into(), + UserRequestError::BodyExtensionsParseError(e) => e.into(), UserRequestError::UnableToAllocStreamId => QueryError::UnableToAllocStreamId, UserRequestError::RepreparedIdChanged => QueryError::ProtocolError( "Prepared statement Id changed, md5 sum should stay the same", @@ -148,8 +131,10 @@ impl From for NewSessionError { match query_error { QueryError::DbError(e, msg) => NewSessionError::DbError(e, msg), QueryError::BadQuery(e) => NewSessionError::BadQuery(e), + QueryError::CqlRequestSerialization(e) => NewSessionError::CqlRequestSerialization(e), QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e), QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e), + QueryError::BodyExtensionsParseError(e) => NewSessionError::BodyExtensionsParseError(e), QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e), QueryError::ProtocolError(m) => NewSessionError::ProtocolError(m), QueryError::InvalidMessage(m) => NewSessionError::InvalidMessage(m), @@ -194,6 +179,14 @@ pub enum NewSessionError { #[error(transparent)] BadQuery(#[from] BadQuery), + /// Failed to serialize CQL request. + #[error("Failed to serialize CQL request: {0}")] + CqlRequestSerialization(#[from] CqlRequestSerializationError), + + /// Failed to deserialize frame body extensions. + #[error(transparent)] + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), + /// Received a RESULT server response, but failed to deserialize it. #[error(transparent)] CqlResultParseError(#[from] CqlResultParseError), @@ -380,10 +373,13 @@ pub struct ConnectionSetupRequestError { #[derive(Error, Debug, Clone)] #[non_exhaustive] pub enum ConnectionSetupRequestErrorKind { - // TODO: Make FrameError clonable. - /// An error occurred when parsing response frame header. + /// Failed to serialize CQL request. + #[error("Failed to serialize CQL request: {0}")] + CqlRequestSerialization(#[from] CqlRequestSerializationError), + + /// Failed to deserialize frame body extensions. #[error(transparent)] - FrameError(Arc), + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), /// Driver was unable to allocate a stream id to execute a setup request on. #[error("Unable to allocate stream id")] @@ -440,12 +436,6 @@ pub enum ConnectionSetupRequestErrorKind { MissingAuthentication, } -impl From for ConnectionSetupRequestErrorKind { - fn from(value: FrameError) -> Self { - ConnectionSetupRequestErrorKind::FrameError(Arc::new(value)) - } -} - impl ConnectionSetupRequestError { pub(crate) fn new( request_kind: CqlRequestKind, @@ -496,7 +486,7 @@ pub enum BrokenConnectionErrorKind { /// Failed to deserialize response frame header. #[error("Failed to deserialize frame: {0}")] - FrameError(FrameError), + FrameHeaderParseError(FrameHeaderParseError), /// Failed to handle a CQL event (server response received on stream -1). #[error("Failed to handle server event: {0}")] @@ -547,9 +537,9 @@ pub enum CqlEventHandlingError { #[error("Received unexpected server response on stream -1: {0}. Expected EVENT response")] UnexpectedResponse(CqlResponseKind), - /// Failed to deserialize a header of frame received on stream -1. + /// Failed to deserialize body extensions of frame received on stream -1. #[error("Failed to deserialize a header of frame received on stream -1: {0}")] - FrameError(#[from] FrameError), + BodyExtensionParseError(#[from] FrameBodyExtensionsParseError), /// Driver failed to send event data between the internal tasks. /// It implies that connection was broken for some reason. @@ -566,6 +556,8 @@ pub enum CqlEventHandlingError { /// requests. #[derive(Error, Debug)] pub(crate) enum UserRequestError { + #[error("Failed to serialize CQL request: {0}")] + CqlRequestSerialization(#[from] CqlRequestSerializationError), #[error("Database returned an error: {0}, Error message: {1}")] DbError(DbError, String), #[error(transparent)] @@ -579,7 +571,7 @@ pub(crate) enum UserRequestError { #[error(transparent)] BrokenConnectionError(#[from] BrokenConnectionError), #[error(transparent)] - FrameError(#[from] FrameError), + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), #[error("Unable to allocate stream id")] UnableToAllocStreamId, #[error("Prepared statement Id changed, md5 sum should stay the same")] @@ -595,7 +587,8 @@ impl From for UserRequestError { impl From for UserRequestError { fn from(value: RequestError) -> Self { match value { - RequestError::FrameError(e) => e.into(), + RequestError::CqlRequestSerialization(e) => e.into(), + RequestError::BodyExtensionsParseError(e) => e.into(), RequestError::CqlResponseParseError(e) => match e { // Only possible responses are RESULT and ERROR. If we failed parsing // other response, treat it as unexpected response. @@ -619,9 +612,13 @@ impl From for UserRequestError { #[derive(Error, Debug)] #[non_exhaustive] pub enum RequestError { - /// Failed to deserialize response frame header. + /// Failed to serialize CQL request. + #[error("Failed to serialize CQL request: {0}")] + CqlRequestSerialization(#[from] CqlRequestSerializationError), + + /// Failed to deserialize frame body extensions. #[error(transparent)] - FrameError(#[from] FrameError), + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), /// Failed to deserialize a CQL response (frame body). #[error(transparent)] @@ -639,7 +636,7 @@ pub enum RequestError { impl From for RequestError { fn from(value: ResponseParseError) -> Self { match value { - ResponseParseError::FrameError(e) => e.into(), + ResponseParseError::BodyExtensionsParseError(e) => e.into(), ResponseParseError::CqlResponseParseError(e) => e.into(), } } @@ -650,7 +647,7 @@ impl From for RequestError { #[derive(Error, Debug)] pub(crate) enum ResponseParseError { #[error(transparent)] - FrameError(#[from] FrameError), + BodyExtensionsParseError(#[from] FrameBodyExtensionsParseError), #[error(transparent)] CqlResponseParseError(#[from] CqlResponseParseError), } diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 51cc9fe6a7..20368cb968 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2841,6 +2841,7 @@ mod latency_awareness { match error { // "fast" errors, i.e. ones that are returned quickly after the query begins QueryError::BadQuery(_) + | QueryError::CqlRequestSerialization(_) | QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_) | QueryError::UnableToAllocStreamId @@ -2854,6 +2855,7 @@ mod latency_awareness { QueryError::DbError(_, _) | QueryError::CqlResultParseError(_) | QueryError::CqlErrorParseError(_) + | QueryError::BodyExtensionsParseError(_) | QueryError::InvalidMessage(_) | QueryError::ProtocolError(_) | QueryError::TimeoutError