Skip to content

Commit

Permalink
types: LowLevelDeserializationError
Browse files Browse the repository at this point in the history
  • Loading branch information
muzarski committed Jun 20, 2024
1 parent 4997853 commit d9b1247
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 40 deletions.
6 changes: 3 additions & 3 deletions scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ pub fn parse_response_body_extensions(

let trace_id = if flags & FLAG_TRACING != 0 {
let buf = &mut &*body;
let trace_id = types::read_uuid(buf)?;
let trace_id = types::read_uuid(buf).map_err(frame_errors::ParseError::from)?;
body.advance(16);
Some(trace_id)
} else {
Expand All @@ -198,7 +198,7 @@ pub fn parse_response_body_extensions(
let warnings = if flags & FLAG_WARNING != 0 {
let body_len = body.len();
let buf = &mut &*body;
let warnings = types::read_string_list(buf)?;
let warnings = types::read_string_list(buf).map_err(frame_errors::ParseError::from)?;
let buf_len = buf.len();
body.advance(body_len - buf_len);
warnings
Expand All @@ -209,7 +209,7 @@ pub fn parse_response_body_extensions(
let custom_payload = if flags & FLAG_CUSTOM_PAYLOAD != 0 {
let body_len = body.len();
let buf = &mut &*body;
let payload_map = types::read_bytes_map(buf)?;
let payload_map = types::read_bytes_map(buf).map_err(frame_errors::ParseError::from)?;
let buf_len = buf.len();
body.advance(body_len - buf_len);
Some(payload_map)
Expand Down
14 changes: 8 additions & 6 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,13 +785,15 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult<CqlValue,
let t = type_names
.iter()
.map(|typ| {
types::read_bytes_opt(buf).and_then(|v| {
v.map(|v| {
CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v)))
.map_err(Into::into)
types::read_bytes_opt(buf)
.map_err(ParseError::from)
.and_then(|v| {
v.map(|v| {
CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v)))
.map_err(Into::into)
})
.transpose()
})
.transpose()
})
})
.collect::<StdResult<_, _>>()?;
CqlValue::Tuple(t)
Expand Down
65 changes: 34 additions & 31 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! CQL binary protocol in-wire types.
use super::frame_errors::LowLevelDeserializationError;
use super::frame_errors::LowLevelSerializationError;
use super::frame_errors::ParseError;
use super::TryFromPrimitiveError;
Expand Down Expand Up @@ -161,13 +162,15 @@ impl<'a> RawValue<'a> {
}
}

fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
fn read_raw_bytes<'a>(
count: usize,
buf: &mut &'a [u8],
) -> Result<&'a [u8], LowLevelDeserializationError> {
if buf.len() < count {
return Err(ParseError::BadIncomingData(format!(
"Not enough bytes! expected: {} received: {}",
count,
buf.len(),
)));
return Err(LowLevelDeserializationError::TooFewBytesReceived {
expected: count,
received: buf.len(),
});
}
let (ret, rest) = buf.split_at(count);
*buf = rest;
Expand All @@ -183,7 +186,7 @@ pub fn write_int(v: i32, buf: &mut impl BufMut) {
buf.put_i32(v);
}

pub fn read_int_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
pub fn read_int_length(buf: &mut &[u8]) -> Result<usize, LowLevelDeserializationError> {
let v = read_int(buf)?;
let v: usize = v.try_into()?;

Expand Down Expand Up @@ -258,7 +261,9 @@ fn type_short() {
}

// https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208
pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result<Option<&'a [u8]>, ParseError> {
pub fn read_bytes_opt<'a>(
buf: &mut &'a [u8],
) -> Result<Option<&'a [u8]>, LowLevelDeserializationError> {
let len = read_int(buf)?;
if len < 0 {
return Ok(None);
Expand All @@ -269,13 +274,13 @@ pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result<Option<&'a [u8]>, ParseE
}

// Same as read_bytes, but we assume the value won't be `null`
pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> {
let len = read_int_length(buf)?;
let v = read_raw_bytes(len, buf)?;
Ok(v)
}

pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result<RawValue<'a>, ParseError> {
pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result<RawValue<'a>, LowLevelDeserializationError> {
let len = read_int(buf)?;
match len {
-2 => Ok(RawValue::Unset),
Expand All @@ -284,14 +289,11 @@ pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result<RawValue<'a>, ParseError> {
let v = read_raw_bytes(len as usize, buf)?;
Ok(RawValue::Value(v))
}
len => Err(ParseError::BadIncomingData(format!(
"invalid value length: {}",
len,
))),
len => Err(LowLevelDeserializationError::InvalidValueLength(len)),
}
}

pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let v = read_raw_bytes(len, buf)?;
Ok(v)
Expand Down Expand Up @@ -324,7 +326,9 @@ pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num
Ok(())
}

pub fn read_bytes_map(buf: &mut &[u8]) -> Result<HashMap<String, Vec<u8>>, ParseError> {
pub fn read_bytes_map(
buf: &mut &[u8],
) -> Result<HashMap<String, Vec<u8>>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = HashMap::with_capacity(len);
for _ in 0..len {
Expand Down Expand Up @@ -362,7 +366,7 @@ fn type_bytes_map() {
assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val);
}

pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> {
pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let raw = read_raw_bytes(len, buf)?;
let v = str::from_utf8(raw)?;
Expand All @@ -386,7 +390,7 @@ fn type_string() {
}
}

pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> {
pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, LowLevelDeserializationError> {
let len = read_int_length(buf)?;
let raw = read_raw_bytes(len, buf)?;
let v = str::from_utf8(raw)?;
Expand All @@ -411,7 +415,9 @@ fn type_long_string() {
}
}

pub fn read_string_map(buf: &mut &[u8]) -> Result<HashMap<String, String>, ParseError> {
pub fn read_string_map(
buf: &mut &[u8],
) -> Result<HashMap<String, String>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = HashMap::with_capacity(len);
for _ in 0..len {
Expand Down Expand Up @@ -446,7 +452,7 @@ fn type_string_map() {
assert_eq!(read_string_map(&mut &buf[..]).unwrap(), val);
}

pub fn read_string_list(buf: &mut &[u8]) -> Result<Vec<String>, ParseError> {
pub fn read_string_list(buf: &mut &[u8]) -> Result<Vec<String>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = Vec::with_capacity(len);
for _ in 0..len {
Expand Down Expand Up @@ -480,7 +486,9 @@ fn type_string_list() {
assert_eq!(read_string_list(&mut &buf[..]).unwrap(), val);
}

pub fn read_string_multimap(buf: &mut &[u8]) -> Result<HashMap<String, Vec<String>>, ParseError> {
pub fn read_string_multimap(
buf: &mut &[u8],
) -> Result<HashMap<String, Vec<String>>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = HashMap::with_capacity(len);
for _ in 0..len {
Expand Down Expand Up @@ -518,7 +526,7 @@ fn type_string_multimap() {
assert_eq!(read_string_multimap(&mut &buf[..]).unwrap(), val);
}

pub fn read_uuid(buf: &mut &[u8]) -> Result<Uuid, ParseError> {
pub fn read_uuid(buf: &mut &[u8]) -> Result<Uuid, LowLevelDeserializationError> {
let raw = read_raw_bytes(16, buf)?;

// It's safe to unwrap here because the conversion only fails
Expand All @@ -542,10 +550,10 @@ fn type_uuid() {
assert_eq!(u, u2);
}

pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, LowLevelDeserializationError> {
let raw = read_short(buf)?;
Consistency::try_from(raw)
.map_err(|_| ParseError::BadIncomingData(format!("unknown consistency: {}", raw)))
.map_err(|_| LowLevelDeserializationError::UnknownConsistency { raw_value: raw })
}

pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
Expand Down Expand Up @@ -575,7 +583,7 @@ fn type_consistency() {
assert!(err_str.contains(&format!("{}", c)));
}

pub fn read_inet(buf: &mut &[u8]) -> Result<SocketAddr, ParseError> {
pub fn read_inet(buf: &mut &[u8]) -> Result<SocketAddr, LowLevelDeserializationError> {
let len = buf.read_u8()?;
let ip_addr = match len {
4 => {
Expand All @@ -588,12 +596,7 @@ pub fn read_inet(buf: &mut &[u8]) -> Result<SocketAddr, ParseError> {
buf.advance(16);
ret
}
v => {
return Err(ParseError::BadIncomingData(format!(
"Invalid inet bytes length: {}",
v
)))
}
v => return Err(LowLevelDeserializationError::InvalidInetLength(v)),
};
let port = read_int(buf)?;

Expand Down
2 changes: 2 additions & 0 deletions scylla-proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::frame::{
};
use crate::{RequestOpcode, TargetShard};
use bytes::Bytes;
use scylla_cql::frame::frame_errors::ParseError;
use scylla_cql::frame::types::read_string_multimap;
use std::collections::HashMap;
use std::fmt::Display;
Expand Down Expand Up @@ -803,6 +804,7 @@ impl Doorkeeper {
.map_err(DoorkeeperError::ObtainingShardNumberFrame)?;

let options = read_string_multimap(&mut supported_frame.body.as_ref())
.map_err(ParseError::from)
.map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;

Ok(options)
Expand Down

0 comments on commit d9b1247

Please sign in to comment.