diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 68757331fc..491ebda82a 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -35,6 +35,8 @@ pub enum FrameError { #[derive(Error, Debug)] pub enum ParseError { + #[error("Low-level deserialization failed: {0}")] + LowLevelDeserializationError(#[from] LowLevelDeserializationError), #[error("Could not serialize frame: {0}")] BadDataToSerialize(String), #[error("Could not deserialize frame: {0}")] @@ -52,3 +54,34 @@ pub enum ParseError { #[error(transparent)] CqlTypeError(#[from] CqlTypeError), } + +/// A low level deserialization error. +/// +/// This type of error is returned when deserialization +/// of some primitive value fails. +/// +/// Possible error kinds: +/// - generic io error - reading from buffer failed +/// - out of range integer conversion +/// - conversion errors - e.g. slice-to-array or primitive-to-enum +/// - not enough bytes in the buffer to deserialize a value +#[non_exhaustive] +#[derive(Error, Debug)] +pub enum LowLevelDeserializationError { + #[error(transparent)] + IoError(#[from] std::io::Error), + #[error(transparent)] + TryFromIntError(#[from] std::num::TryFromIntError), + #[error("Failed to convert slice into array: {0}")] + TryFromSliceError(#[from] std::array::TryFromSliceError), + #[error("Not enough bytes! expected: {expected}, received: {received}")] + TooFewBytesReceived { expected: usize, received: usize }, + #[error("Invalid value length: {0}")] + InvalidValueLength(i32), + #[error("Unknown consistency: {0}")] + UnknownConsistency(#[from] TryFromPrimitiveError), + #[error("Invalid inet bytes length: {0}")] + InvalidInetLength(u8), + #[error("UTF8 deserialization failed: {0}")] + UTF8DeserializationError(#[from] std::str::Utf8Error), +} diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index e84ca159e6..cec55d140f 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -188,7 +188,7 @@ pub fn parse_response_body_extensions( let trace_id = if flags & FLAG_TRACING != 0 { let buf = &mut &*body; - let trace_id = types::read_uuid(buf)?; + let trace_id = types::read_uuid(buf).map_err(frame_errors::ParseError::from)?; body.advance(16); Some(trace_id) } else { @@ -198,7 +198,7 @@ pub fn parse_response_body_extensions( let warnings = if flags & FLAG_WARNING != 0 { let body_len = body.len(); let buf = &mut &*body; - let warnings = types::read_string_list(buf)?; + let warnings = types::read_string_list(buf).map_err(frame_errors::ParseError::from)?; let buf_len = buf.len(); body.advance(body_len - buf_len); warnings @@ -209,7 +209,7 @@ pub fn parse_response_body_extensions( let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 { let body_len = body.len(); let buf = &mut &*body; - let payload_map = types::read_bytes_map(buf)?; + let payload_map = types::read_bytes_map(buf).map_err(frame_errors::ParseError::from)?; let buf_len = buf.len(); body.advance(body_len - buf_len); Some(payload_map) diff --git a/scylla-cql/src/frame/request/auth_response.rs b/scylla-cql/src/frame/request/auth_response.rs index dabbf20d34..83d718ee59 100644 --- a/scylla-cql/src/frame/request/auth_response.rs +++ b/scylla-cql/src/frame/request/auth_response.rs @@ -12,6 +12,6 @@ impl SerializableRequest for AuthResponse { const OPCODE: RequestOpcode = RequestOpcode::AuthResponse; fn serialize(&self, buf: &mut Vec) -> Result<(), ParseError> { - write_bytes_opt(self.response.as_ref(), buf) + Ok(write_bytes_opt(self.response.as_ref(), buf)?) } } diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 3afe7335d3..87bbaaeaa6 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -5,7 +5,9 @@ use crate::frame::value::{ }; use crate::frame::{frame_errors::ParseError, types}; use crate::types::deserialize::result::{RowIterator, TypedRowIterator}; -use crate::types::deserialize::value::{DeserializeValue, MapIterator, UdtIterator}; +use crate::types::deserialize::value::{ + mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator, +}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; use std::borrow::Cow; @@ -643,7 +645,10 @@ fn deser_prepared_metadata(buf: &mut &[u8]) -> StdResult StdResult { +pub fn deser_cql_value( + typ: &ColumnType, + buf: &mut &[u8], +) -> StdResult { use ColumnType::*; if buf.is_empty() { @@ -662,10 +667,10 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - return Err(ParseError::BadIncomingData(format!( - "Support for custom types is not yet implemented: {}", - type_str - ))); + return Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::CustomTypeNotSupported(type_str.to_string()), + )) } Ascii => { let s = String::deserialize(typ, v)?; @@ -784,14 +789,15 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { let t = type_names .iter() - .map(|typ| { - types::read_bytes_opt(buf).and_then(|v| { - v.map(|v| { - CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v))) - .map_err(Into::into) - }) + .map(|typ| -> StdResult<_, DeserializationError> { + let raw = types::read_bytes_opt(buf).map_err(|e| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(e), + ) + })?; + raw.map(|v| CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v)))) .transpose() - }) }) .collect::>()?; CqlValue::Tuple(t) diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index de311ac63c..77497cd37a 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -1,5 +1,6 @@ //! CQL binary protocol in-wire types. +use super::frame_errors::LowLevelDeserializationError; use super::frame_errors::ParseError; use super::TryFromPrimitiveError; use byteorder::{BigEndian, ReadBytesExt}; @@ -160,13 +161,15 @@ impl<'a> RawValue<'a> { } } -fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +fn read_raw_bytes<'a>( + count: usize, + buf: &mut &'a [u8], +) -> Result<&'a [u8], LowLevelDeserializationError> { if buf.len() < count { - return Err(ParseError::BadIncomingData(format!( - "Not enough bytes! expected: {} received: {}", - count, - buf.len(), - ))); + return Err(LowLevelDeserializationError::TooFewBytesReceived { + expected: count, + received: buf.len(), + }); } let (ret, rest) = buf.split_at(count); *buf = rest; @@ -182,14 +185,14 @@ pub fn write_int(v: i32, buf: &mut impl BufMut) { buf.put_i32(v); } -pub fn read_int_length(buf: &mut &[u8]) -> Result { +pub fn read_int_length(buf: &mut &[u8]) -> Result { let v = read_int(buf)?; let v: usize = v.try_into()?; Ok(v) } -fn write_int_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> { +fn write_int_length(v: usize, buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { let v: i32 = v.try_into()?; write_int(v, buf); @@ -240,7 +243,7 @@ pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result Result<(), ParseError> { +fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { let v: u16 = v.try_into()?; write_short(v, buf); Ok(()) @@ -257,7 +260,9 @@ fn type_short() { } // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 -pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { +pub fn read_bytes_opt<'a>( + buf: &mut &'a [u8], +) -> Result, LowLevelDeserializationError> { let len = read_int(buf)?; if len < 0 { return Ok(None); @@ -268,13 +273,13 @@ pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseE } // Same as read_bytes, but we assume the value won't be `null` -pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> { let len = read_int_length(buf)?; let v = read_raw_bytes(len, buf)?; Ok(v) } -pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { +pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, LowLevelDeserializationError> { let len = read_int(buf)?; match len { -2 => Ok(RawValue::Unset), @@ -283,20 +288,17 @@ pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { let v = read_raw_bytes(len as usize, buf)?; Ok(RawValue::Value(v)) } - len => Err(ParseError::BadIncomingData(format!( - "invalid value length: {}", - len, - ))), + len => Err(LowLevelDeserializationError::InvalidValueLength(len)), } } -pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> { let len = read_short_length(buf)?; let v = read_raw_bytes(len, buf)?; Ok(v) } -pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { write_int_length(v.len(), buf)?; buf.put_slice(v); Ok(()) @@ -305,7 +307,7 @@ pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { pub fn write_bytes_opt( v: Option>, buf: &mut impl BufMut, -) -> Result<(), ParseError> { +) -> Result<(), std::num::TryFromIntError> { match v { Some(bytes) => { write_int_length(bytes.as_ref().len(), buf)?; @@ -317,13 +319,15 @@ pub fn write_bytes_opt( Ok(()) } -pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { write_short_length(v.len(), buf)?; buf.put_slice(v); Ok(()) } -pub fn read_bytes_map(buf: &mut &[u8]) -> Result>, ParseError> { +pub fn read_bytes_map( + buf: &mut &[u8], +) -> Result>, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -334,7 +338,10 @@ pub fn read_bytes_map(buf: &mut &[u8]) -> Result>, Parse Ok(v) } -pub fn write_bytes_map(v: &HashMap, buf: &mut impl BufMut) -> Result<(), ParseError> +pub fn write_bytes_map( + v: &HashMap, + buf: &mut impl BufMut, +) -> Result<(), std::num::TryFromIntError> where B: AsRef<[u8]>, { @@ -358,14 +365,14 @@ fn type_bytes_map() { assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val); } -pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { +pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> { let len = read_short_length(buf)?; let raw = read_raw_bytes(len, buf)?; let v = str::from_utf8(raw)?; Ok(v) } -pub fn write_string(v: &str, buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_string(v: &str, buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { let raw = v.as_bytes(); write_short_length(v.len(), buf)?; buf.put_slice(raw); @@ -382,14 +389,14 @@ fn type_string() { } } -pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { +pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> { let len = read_int_length(buf)?; let raw = read_raw_bytes(len, buf)?; let v = str::from_utf8(raw)?; Ok(v) } -pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), std::num::TryFromIntError> { let raw = v.as_bytes(); let len = raw.len(); write_int_length(len, buf)?; @@ -407,7 +414,9 @@ fn type_long_string() { } } -pub fn read_string_map(buf: &mut &[u8]) -> Result, ParseError> { +pub fn read_string_map( + buf: &mut &[u8], +) -> Result, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -421,7 +430,7 @@ pub fn read_string_map(buf: &mut &[u8]) -> Result, Parse pub fn write_string_map( v: &HashMap, buf: &mut impl BufMut, -) -> Result<(), ParseError> { +) -> Result<(), std::num::TryFromIntError> { let len = v.len(); write_short_length(len, buf)?; for (key, val) in v.iter() { @@ -442,7 +451,7 @@ fn type_string_map() { assert_eq!(read_string_map(&mut &buf[..]).unwrap(), val); } -pub fn read_string_list(buf: &mut &[u8]) -> Result, ParseError> { +pub fn read_string_list(buf: &mut &[u8]) -> Result, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = Vec::with_capacity(len); for _ in 0..len { @@ -451,7 +460,10 @@ pub fn read_string_list(buf: &mut &[u8]) -> Result, ParseError> { Ok(v) } -pub fn write_string_list(v: &[String], buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_string_list( + v: &[String], + buf: &mut impl BufMut, +) -> Result<(), std::num::TryFromIntError> { let len = v.len(); write_short_length(len, buf)?; for v in v.iter() { @@ -473,7 +485,9 @@ fn type_string_list() { assert_eq!(read_string_list(&mut &buf[..]).unwrap(), val); } -pub fn read_string_multimap(buf: &mut &[u8]) -> Result>, ParseError> { +pub fn read_string_multimap( + buf: &mut &[u8], +) -> Result>, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -487,7 +501,7 @@ pub fn read_string_multimap(buf: &mut &[u8]) -> Result>, buf: &mut impl BufMut, -) -> Result<(), ParseError> { +) -> Result<(), std::num::TryFromIntError> { let len = v.len(); write_short_length(len, buf)?; for (key, val) in v.iter() { @@ -511,7 +525,7 @@ fn type_string_multimap() { assert_eq!(read_string_multimap(&mut &buf[..]).unwrap(), val); } -pub fn read_uuid(buf: &mut &[u8]) -> Result { +pub fn read_uuid(buf: &mut &[u8]) -> Result { let raw = read_raw_bytes(16, buf)?; // It's safe to unwrap here because the conversion only fails @@ -535,10 +549,9 @@ fn type_uuid() { assert_eq!(u, u2); } -pub fn read_consistency(buf: &mut &[u8]) -> Result { +pub fn read_consistency(buf: &mut &[u8]) -> Result { let raw = read_short(buf)?; - Consistency::try_from(raw) - .map_err(|_| ParseError::BadIncomingData(format!("unknown consistency: {}", raw))) + Consistency::try_from(raw).map_err(LowLevelDeserializationError::UnknownConsistency) } pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) { @@ -568,7 +581,7 @@ fn type_consistency() { assert!(err_str.contains(&format!("{}", c))); } -pub fn read_inet(buf: &mut &[u8]) -> Result { +pub fn read_inet(buf: &mut &[u8]) -> Result { let len = buf.read_u8()?; let ip_addr = match len { 4 => { @@ -581,12 +594,7 @@ pub fn read_inet(buf: &mut &[u8]) -> Result { buf.advance(16); ret } - v => { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))) - } + v => return Err(LowLevelDeserializationError::InvalidInetLength(v)), }; let port = read_int(buf)?; diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs index cfc98d5ce5..4471960a03 100644 --- a/scylla-cql/src/types/deserialize/frame_slice.rs +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -1,6 +1,6 @@ use bytes::Bytes; -use crate::frame::frame_errors::ParseError; +use crate::frame::frame_errors::LowLevelDeserializationError; use crate::frame::types; /// A reference to a part of the frame. @@ -139,7 +139,9 @@ impl<'frame> FrameSlice<'frame> { /// /// If the operation fails then the slice remains unchanged. #[inline] - pub(super) fn read_cql_bytes(&mut self) -> Result>, ParseError> { + pub(super) fn read_cql_bytes( + &mut self, + ) -> Result>, LowLevelDeserializationError> { // We copy the slice reference, not to mutate the FrameSlice in case of an error. let mut slice = self.frame_subslice; diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index c66f3c7328..5dfec4b12a 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -416,7 +416,6 @@ mod tests { use assert_matches::assert_matches; use bytes::Bytes; - use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{ColumnSpec, ColumnType}; use crate::types::deserialize::row::BuiltinDeserializationErrorKind; use crate::types::deserialize::{DeserializationError, FrameSlice}; @@ -651,13 +650,6 @@ mod tests { let err = super::super::value::tests::get_deser_err(err); assert_eq!(err.rust_name, std::any::type_name::()); assert_eq!(err.cql_type, ColumnType::BigInt); - let super::super::value::BuiltinDeserializationErrorKind::GenericParseError( - ParseError::DeserializationError(d), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = super::super::value::tests::get_deser_err(d); let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { expected: 8, got: 4, diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index f87aed66e6..8431ea17cc 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -14,7 +14,7 @@ use std::fmt::Display; use thiserror::Error; use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; -use crate::frame::frame_errors::ParseError; +use crate::frame::frame_errors::LowLevelDeserializationError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; use crate::frame::value::{ @@ -60,9 +60,7 @@ impl<'frame> DeserializeValue<'frame> for CqlValue { v: Option>, ) -> Result { let mut val = ensure_not_null_slice::(typ, v)?; - let cql = deser_cql_value(typ, &mut val).map_err(|err| { - mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) - })?; + let cql = deser_cql_value(typ, &mut val).map_err(deser_error_replace_rust_name::)?; Ok(cql) } } @@ -249,7 +247,7 @@ impl_emptiable_strict_type!( let scale = types::read_int(&mut val).map_err(|err| { mk_deser_err::( typ, - BuiltinDeserializationErrorKind::GenericParseError(err.into()), + BuiltinDeserializationErrorKind::BadDecimalScale(err.into()), ) })?; Ok(CqlDecimal::from_signed_be_bytes_slice_and_exponent( @@ -267,7 +265,7 @@ impl_emptiable_strict_type!( let scale = types::read_int(&mut val).map_err(|err| { mk_deser_err::( typ, - BuiltinDeserializationErrorKind::GenericParseError(err.into()), + BuiltinDeserializationErrorKind::BadDecimalScale(err.into()), ) })? as i64; let int_value = bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(val); @@ -381,25 +379,28 @@ impl_strict_type!( } let months_i64 = types::vint_decode(&mut val).map_err(|err| { - mk_err!(BuiltinDeserializationErrorKind::GenericParseError( - err.into() - )) + mk_err!(BuiltinDeserializationErrorKind::BadDate { + date_field: "months", + err: err.into() + }) })?; let months = i32::try_from(months_i64) .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; let days_i64 = types::vint_decode(&mut val).map_err(|err| { - mk_err!(BuiltinDeserializationErrorKind::GenericParseError( - err.into() - )) + mk_err!(BuiltinDeserializationErrorKind::BadDate { + date_field: "days", + err: err.into() + }) })?; let days = i32::try_from(days_i64) .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; let nanoseconds = types::vint_decode(&mut val).map_err(|err| { - mk_err!(BuiltinDeserializationErrorKind::GenericParseError( - err.into() - )) + mk_err!(BuiltinDeserializationErrorKind::BadDate { + date_field: "nanoseconds", + err: err.into() + }) })?; Ok(CqlDuration { @@ -727,7 +728,7 @@ where let raw = self.raw_iter.next()?.map_err(|err| { mk_deser_err::( self.coll_typ, - BuiltinDeserializationErrorKind::GenericParseError(err), + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), ) }); Some(raw.and_then(|raw| { @@ -906,7 +907,7 @@ where Some(Err(err)) => { return Some(Err(mk_deser_err::( self.coll_typ, - BuiltinDeserializationErrorKind::GenericParseError(err), + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), ))); } None => return None, @@ -916,7 +917,7 @@ where Some(Err(err)) => { return Some(Err(mk_deser_err::( self.coll_typ, - BuiltinDeserializationErrorKind::GenericParseError(err), + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), ))); } None => return None, @@ -1184,7 +1185,7 @@ impl<'frame> Iterator for UdtIterator<'frame> { keyspace: self.keyspace.to_owned(), field_types: self.all_fields.to_owned(), }, - BuiltinDeserializationErrorKind::GenericParseError(err), + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), )), // The field is just missing from the serialized form @@ -1277,7 +1278,7 @@ impl<'frame> FixedLengthBytesSequenceIterator<'frame> { } impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { - type Item = Result>, ParseError>; + type Item = Result>, LowLevelDeserializationError>; fn next(&mut self) -> Option { self.remaining = self.remaining.checked_sub(1)?; @@ -1307,7 +1308,7 @@ impl<'frame> From> for BytesSequenceIterator<'frame> { } impl<'frame> Iterator for BytesSequenceIterator<'frame> { - type Item = Result>, ParseError>; + type Item = Result>, LowLevelDeserializationError>; fn next(&mut self) -> Option { if self.slice.as_slice().is_empty() { @@ -1573,7 +1574,7 @@ pub struct BuiltinDeserializationError { pub kind: BuiltinDeserializationErrorKind, } -fn mk_deser_err( +pub(crate) fn mk_deser_err( cql_type: &ColumnType, kind: impl Into, ) -> DeserializationError { @@ -1596,8 +1597,20 @@ fn mk_deser_err_named( #[derive(Debug)] #[non_exhaustive] pub enum BuiltinDeserializationErrorKind { - /// A generic deserialization failure - legacy error type. - GenericParseError(ParseError), + /// Failed to deserialize one of date's fields. + BadDate { + date_field: &'static str, + err: LowLevelDeserializationError, + }, + + /// Failed to deserialize decimal's scale. + BadDecimalScale(LowLevelDeserializationError), + + /// Failed to deserialize raw bytes of cql value. + RawCqlBytesReadError(LowLevelDeserializationError), + + /// Returned on attempt to deserialize a value of custom type. + CustomTypeNotSupported(String), /// Expected non-null value, got null. ExpectedNonNull, @@ -1631,7 +1644,9 @@ pub enum BuiltinDeserializationErrorKind { impl Display for BuiltinDeserializationErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - BuiltinDeserializationErrorKind::GenericParseError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::BadDate { date_field, err } => write!(f, "malformed {} during 'date' deserialization: {}", date_field, err), + BuiltinDeserializationErrorKind::BadDecimalScale(err) => write!(f, "malformed decimal's scale: {}", err), + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err) => write!(f, "failed to read raw cql value bytes: {}", err), BuiltinDeserializationErrorKind::ExpectedNonNull => { f.write_str("expected a non-null value, got null") } @@ -1656,6 +1671,7 @@ impl Display for BuiltinDeserializationErrorKind { BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::CustomTypeNotSupported(typ) => write!(f, "Support for custom types is not yet implemented: {}", typ), } } } @@ -2356,7 +2372,10 @@ pub(super) mod tests { .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; let mut frame_slice = FrameSlice::new(bytes); let value = frame_slice.read_cql_bytes().map_err(|err| { - mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) })?; >::deserialize(typ, value) } diff --git a/scylla-proxy/src/proxy.rs b/scylla-proxy/src/proxy.rs index 041e42752b..7a79d7a9d1 100644 --- a/scylla-proxy/src/proxy.rs +++ b/scylla-proxy/src/proxy.rs @@ -5,6 +5,7 @@ use crate::frame::{ }; use crate::{RequestOpcode, TargetShard}; use bytes::Bytes; +use scylla_cql::frame::frame_errors::ParseError; use scylla_cql::frame::types::read_string_multimap; use std::collections::HashMap; use std::fmt::Display; @@ -803,6 +804,7 @@ impl Doorkeeper { .map_err(DoorkeeperError::ObtainingShardNumberFrame)?; let options = read_string_multimap(&mut supported_frame.body.as_ref()) + .map_err(ParseError::from) .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?; Ok(options) diff --git a/scylla/src/transport/locator/tablets.rs b/scylla/src/transport/locator/tablets.rs index 14a3ea2b32..c946100933 100644 --- a/scylla/src/transport/locator/tablets.rs +++ b/scylla/src/transport/locator/tablets.rs @@ -1,8 +1,8 @@ use itertools::Itertools; use lazy_static::lazy_static; use scylla_cql::cql_to_rust::FromCqlVal; -use scylla_cql::frame::frame_errors::ParseError; use scylla_cql::frame::response::result::{deser_cql_value, ColumnType, TableSpec}; +use scylla_cql::types::deserialize::DeserializationError; use thiserror::Error; use tracing::warn; use uuid::Uuid; @@ -16,7 +16,7 @@ use std::sync::Arc; #[derive(Error, Debug)] pub(crate) enum TabletParsingError { #[error(transparent)] - Parse(#[from] ParseError), + Deserialization(#[from] DeserializationError), #[error("Shard id for tablet is negative: {0}")] ShardNum(i32), } @@ -616,7 +616,7 @@ mod tests { HashMap::from([(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), vec![1, 2, 3])]); assert_matches::assert_matches!( RawTablet::from_custom_payload(&custom_payload), - Some(Err(TabletParsingError::Parse(_))) + Some(Err(TabletParsingError::Deserialization(_))) ); } @@ -646,7 +646,7 @@ mod tests { assert_matches::assert_matches!( RawTablet::from_custom_payload(&custom_payload), - Some(Err(TabletParsingError::Parse(_))) + Some(Err(TabletParsingError::Deserialization(_))) ); }