diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index e84ca159e6..cec55d140f 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -188,7 +188,7 @@ pub fn parse_response_body_extensions( let trace_id = if flags & FLAG_TRACING != 0 { let buf = &mut &*body; - let trace_id = types::read_uuid(buf)?; + let trace_id = types::read_uuid(buf).map_err(frame_errors::ParseError::from)?; body.advance(16); Some(trace_id) } else { @@ -198,7 +198,7 @@ pub fn parse_response_body_extensions( let warnings = if flags & FLAG_WARNING != 0 { let body_len = body.len(); let buf = &mut &*body; - let warnings = types::read_string_list(buf)?; + let warnings = types::read_string_list(buf).map_err(frame_errors::ParseError::from)?; let buf_len = buf.len(); body.advance(body_len - buf_len); warnings @@ -209,7 +209,7 @@ pub fn parse_response_body_extensions( let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 { let body_len = body.len(); let buf = &mut &*body; - let payload_map = types::read_bytes_map(buf)?; + let payload_map = types::read_bytes_map(buf).map_err(frame_errors::ParseError::from)?; let buf_len = buf.len(); body.advance(body_len - buf_len); Some(payload_map) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 3afe7335d3..2a69ab2d87 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -785,13 +785,15 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult>()?; CqlValue::Tuple(t) diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index b079f97473..f0ca7cd938 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -1,5 +1,6 @@ //! CQL binary protocol in-wire types. +use super::frame_errors::LowLevelDeserializationError; use super::frame_errors::LowLevelSerializationError; use super::frame_errors::ParseError; use super::TryFromPrimitiveError; @@ -161,13 +162,15 @@ impl<'a> RawValue<'a> { } } -fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +fn read_raw_bytes<'a>( + count: usize, + buf: &mut &'a [u8], +) -> Result<&'a [u8], LowLevelDeserializationError> { if buf.len() < count { - return Err(ParseError::BadIncomingData(format!( - "Not enough bytes! expected: {} received: {}", - count, - buf.len(), - ))); + return Err(LowLevelDeserializationError::TooFewBytesReceived { + expected: count, + received: buf.len(), + }); } let (ret, rest) = buf.split_at(count); *buf = rest; @@ -183,7 +186,7 @@ pub fn write_int(v: i32, buf: &mut impl BufMut) { buf.put_i32(v); } -pub fn read_int_length(buf: &mut &[u8]) -> Result { +pub fn read_int_length(buf: &mut &[u8]) -> Result { let v = read_int(buf)?; let v: usize = v.try_into()?; @@ -258,7 +261,9 @@ fn type_short() { } // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 -pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { +pub fn read_bytes_opt<'a>( + buf: &mut &'a [u8], +) -> Result, LowLevelDeserializationError> { let len = read_int(buf)?; if len < 0 { return Ok(None); @@ -269,13 +274,13 @@ pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseE } // Same as read_bytes, but we assume the value won't be `null` -pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> { let len = read_int_length(buf)?; let v = read_raw_bytes(len, buf)?; Ok(v) } -pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { +pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, LowLevelDeserializationError> { let len = read_int(buf)?; match len { -2 => Ok(RawValue::Unset), @@ -284,14 +289,11 @@ pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { let v = read_raw_bytes(len as usize, buf)?; Ok(RawValue::Value(v)) } - len => Err(ParseError::BadIncomingData(format!( - "invalid value length: {}", - len, - ))), + len => Err(LowLevelDeserializationError::InvalidValueLength(len)), } } -pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { +pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> { let len = read_short_length(buf)?; let v = read_raw_bytes(len, buf)?; Ok(v) @@ -324,7 +326,9 @@ pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num Ok(()) } -pub fn read_bytes_map(buf: &mut &[u8]) -> Result>, ParseError> { +pub fn read_bytes_map( + buf: &mut &[u8], +) -> Result>, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -362,7 +366,7 @@ fn type_bytes_map() { assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val); } -pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { +pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> { let len = read_short_length(buf)?; let raw = read_raw_bytes(len, buf)?; let v = str::from_utf8(raw)?; @@ -386,7 +390,7 @@ fn type_string() { } } -pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { +pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> { let len = read_int_length(buf)?; let raw = read_raw_bytes(len, buf)?; let v = str::from_utf8(raw)?; @@ -411,7 +415,9 @@ fn type_long_string() { } } -pub fn read_string_map(buf: &mut &[u8]) -> Result, ParseError> { +pub fn read_string_map( + buf: &mut &[u8], +) -> Result, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -446,7 +452,7 @@ fn type_string_map() { assert_eq!(read_string_map(&mut &buf[..]).unwrap(), val); } -pub fn read_string_list(buf: &mut &[u8]) -> Result, ParseError> { +pub fn read_string_list(buf: &mut &[u8]) -> Result, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = Vec::with_capacity(len); for _ in 0..len { @@ -480,7 +486,9 @@ fn type_string_list() { assert_eq!(read_string_list(&mut &buf[..]).unwrap(), val); } -pub fn read_string_multimap(buf: &mut &[u8]) -> Result>, ParseError> { +pub fn read_string_multimap( + buf: &mut &[u8], +) -> Result>, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { @@ -518,7 +526,7 @@ fn type_string_multimap() { assert_eq!(read_string_multimap(&mut &buf[..]).unwrap(), val); } -pub fn read_uuid(buf: &mut &[u8]) -> Result { +pub fn read_uuid(buf: &mut &[u8]) -> Result { let raw = read_raw_bytes(16, buf)?; // It's safe to unwrap here because the conversion only fails @@ -542,10 +550,9 @@ fn type_uuid() { assert_eq!(u, u2); } -pub fn read_consistency(buf: &mut &[u8]) -> Result { +pub fn read_consistency(buf: &mut &[u8]) -> Result { let raw = read_short(buf)?; - Consistency::try_from(raw) - .map_err(|_| ParseError::BadIncomingData(format!("unknown consistency: {}", raw))) + Consistency::try_from(raw).map_err(LowLevelDeserializationError::UnknownConsistency) } pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) { @@ -575,7 +582,7 @@ fn type_consistency() { assert!(err_str.contains(&format!("{}", c))); } -pub fn read_inet(buf: &mut &[u8]) -> Result { +pub fn read_inet(buf: &mut &[u8]) -> Result { let len = buf.read_u8()?; let ip_addr = match len { 4 => { @@ -588,12 +595,7 @@ pub fn read_inet(buf: &mut &[u8]) -> Result { buf.advance(16); ret } - v => { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))) - } + v => return Err(LowLevelDeserializationError::InvalidInetLength(v)), }; let port = read_int(buf)?; diff --git a/scylla-proxy/src/proxy.rs b/scylla-proxy/src/proxy.rs index 041e42752b..7a79d7a9d1 100644 --- a/scylla-proxy/src/proxy.rs +++ b/scylla-proxy/src/proxy.rs @@ -5,6 +5,7 @@ use crate::frame::{ }; use crate::{RequestOpcode, TargetShard}; use bytes::Bytes; +use scylla_cql::frame::frame_errors::ParseError; use scylla_cql::frame::types::read_string_multimap; use std::collections::HashMap; use std::fmt::Display; @@ -803,6 +804,7 @@ impl Doorkeeper { .map_err(DoorkeeperError::ObtainingShardNumberFrame)?; let options = read_string_multimap(&mut supported_frame.body.as_ref()) + .map_err(ParseError::from) .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?; Ok(options)