Skip to content

Commit

Permalink
result/deser_cql_value: use DeserializeValue impls
Browse files Browse the repository at this point in the history
In the future, we will probably deprecate and remove `deser_cql_value`
altogether. For now, let's make it at least less bloaty.

To reduce code duplication, `deser_cql_value()` now uses
DeserializeValue impls for nearly all of the deserialized types.
Two notable exceptions are:
1. CQL Map - because it is represented as Vec<(CqlValue, CqlValue)>
   in CqlValue, and Vec<T> is only deserializable from CQL Set|Map.
   Therefore, MapIterator is deserialized using its DeserializeValue
   impl, and then collected into Vec.
2. CQL Tuple - because it is represented in CqlValue much differently
   than in DeserializeValue impls: Vec<CqlValue> vs (T1, T2, ..., Tn).
   Therefore, it's similarly to how it was before, just style is changed
   from imperative to iterator-based, and DeserializeValue impl
   is called instead of `deser_cql_value` there.

As a bonus, we get more descriptive error messages (as compared to old
`ParseError::BadIncomingData` ones).
  • Loading branch information
wprzytula committed Jun 12, 2024
1 parent fadeb48 commit f7bcf33
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 213 deletions.
299 changes: 88 additions & 211 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
use crate::cql_to_rust::{FromRow, FromRowError};
use crate::frame::response::event::SchemaChangeEvent;
use crate::frame::types::vint_decode;
use crate::frame::value::{
Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint,
};
use crate::frame::{frame_errors::ParseError, types};
use byteorder::{BigEndian, ReadBytesExt};
use crate::types::deserialize::value::{DeserializeValue, MapIterator, UdtIterator};
use crate::types::deserialize::FrameSlice;
use bytes::{Buf, Bytes};
use std::borrow::Cow;
use std::{
convert::{TryFrom, TryInto},
net::IpAddr,
result::Result as StdResult,
str,
};
use std::{convert::TryInto, net::IpAddr, result::Result as StdResult, str};
use uuid::Uuid;

#[cfg(feature = "chrono")]
Expand Down Expand Up @@ -655,6 +650,11 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult<CqlValue,
_ => return Ok(CqlValue::Empty),
}
}
// The `new_borrowed` version of FrameSlice is deficient in that it does not hold
// a `Bytes` reference to the frame, only a slice.
// This is not a problem here, fortunately, because none of CqlValue variants contain
// any `Bytes` - only exclusively owned types - so we never call FrameSlice::to_bytes().
let v = Some(FrameSlice::new_borrowed(buf));

Ok(match typ {
Custom(type_str) => {
Expand All @@ -664,239 +664,112 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult<CqlValue,
)));
}
Ascii => {
if !buf.is_ascii() {
return Err(ParseError::BadIncomingData(
"String is not ascii!".to_string(),
));
}
CqlValue::Ascii(str::from_utf8(buf)?.to_owned())
let s = String::deserialize(typ, v)?;
CqlValue::Ascii(s)
}
Boolean => {
if buf.len() != 1 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 1 not {}",
buf.len()
)));
}
CqlValue::Boolean(buf[0] != 0x00)
let b = bool::deserialize(typ, v)?;
CqlValue::Boolean(b)
}
Blob => {
let b = Vec::<u8>::deserialize(typ, v)?;
CqlValue::Blob(b)
}
Blob => CqlValue::Blob(buf.to_vec()),
Date => {
if buf.len() != 4 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 4 not {}",
buf.len()
)));
}

let date_value = buf.read_u32::<BigEndian>()?;
CqlValue::Date(CqlDate(date_value))
let d = CqlDate::deserialize(typ, v)?;
CqlValue::Date(d)
}
Counter => {
if buf.len() != 8 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 8 not {}",
buf.len()
)));
}
CqlValue::Counter(crate::frame::value::Counter(buf.read_i64::<BigEndian>()?))
let c = crate::frame::response::result::Counter::deserialize(typ, v)?;
CqlValue::Counter(c)
}
Decimal => {
let scale = types::read_int(buf)?;
let bytes = buf.to_vec();
let big_decimal: CqlDecimal =
CqlDecimal::from_signed_be_bytes_and_exponent(bytes, scale);

CqlValue::Decimal(big_decimal)
let d = CqlDecimal::deserialize(typ, v)?;
CqlValue::Decimal(d)
}
Double => {
if buf.len() != 8 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 8 not {}",
buf.len()
)));
}
CqlValue::Double(buf.read_f64::<BigEndian>()?)
let d = f64::deserialize(typ, v)?;
CqlValue::Double(d)
}
Float => {
if buf.len() != 4 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 4 not {}",
buf.len()
)));
}
CqlValue::Float(buf.read_f32::<BigEndian>()?)
let f = f32::deserialize(typ, v)?;
CqlValue::Float(f)
}
Int => {
if buf.len() != 4 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 4 not {}",
buf.len()
)));
}
CqlValue::Int(buf.read_i32::<BigEndian>()?)
let i = i32::deserialize(typ, v)?;
CqlValue::Int(i)
}
SmallInt => {
if buf.len() != 2 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 2 not {}",
buf.len()
)));
}

CqlValue::SmallInt(buf.read_i16::<BigEndian>()?)
let si = i16::deserialize(typ, v)?;
CqlValue::SmallInt(si)
}
TinyInt => {
if buf.len() != 1 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 1 not {}",
buf.len()
)));
}
CqlValue::TinyInt(buf.read_i8()?)
let ti = i8::deserialize(typ, v)?;
CqlValue::TinyInt(ti)
}
BigInt => {
if buf.len() != 8 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 8 not {}",
buf.len()
)));
}
CqlValue::BigInt(buf.read_i64::<BigEndian>()?)
let bi = i64::deserialize(typ, v)?;
CqlValue::BigInt(bi)
}
Text => {
let s = String::deserialize(typ, v)?;
CqlValue::Text(s)
}
Text => CqlValue::Text(str::from_utf8(buf)?.to_owned()),
Timestamp => {
if buf.len() != 8 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 8 not {}",
buf.len()
)));
}
let millis = buf.read_i64::<BigEndian>()?;

CqlValue::Timestamp(CqlTimestamp(millis))
let t = CqlTimestamp::deserialize(typ, v)?;
CqlValue::Timestamp(t)
}
Time => {
if buf.len() != 8 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 8 not {}",
buf.len()
)));
}
let nanoseconds: i64 = buf.read_i64::<BigEndian>()?;

// Valid values are in the range 0 to 86399999999999
if !(0..=86399999999999).contains(&nanoseconds) {
return Err(ParseError::BadIncomingData(format! {
"Invalid time value only 0 to 86399999999999 allowed: {}.", nanoseconds
}));
}

CqlValue::Time(CqlTime(nanoseconds))
let t = CqlTime::deserialize(typ, v)?;
CqlValue::Time(t)
}
Timeuuid => {
if buf.len() != 16 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 16 not {}",
buf.len()
)));
}
let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed.");
CqlValue::Timeuuid(CqlTimeuuid::from(uuid))
let t = CqlTimeuuid::deserialize(typ, v)?;
CqlValue::Timeuuid(t)
}
Duration => {
let months = i32::try_from(vint_decode(buf)?)?;
let days = i32::try_from(vint_decode(buf)?)?;
let nanoseconds = vint_decode(buf)?;

CqlValue::Duration(CqlDuration {
months,
days,
nanoseconds,
})
let d = CqlDuration::deserialize(typ, v)?;
CqlValue::Duration(d)
}
Inet => {
let i = IpAddr::deserialize(typ, v)?;
CqlValue::Inet(i)
}
Inet => CqlValue::Inet(match buf.len() {
4 => {
let ret = IpAddr::from(<[u8; 4]>::try_from(&buf[0..4])?);
buf.advance(4);
ret
}
16 => {
let ret = IpAddr::from(<[u8; 16]>::try_from(&buf[0..16])?);
buf.advance(16);
ret
}
v => {
return Err(ParseError::BadIncomingData(format!(
"Invalid inet bytes length: {}",
v
)));
}
}),
Uuid => {
if buf.len() != 16 {
return Err(ParseError::BadIncomingData(format!(
"Buffer length should be 16 not {}",
buf.len()
)));
}
let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed.");
let uuid = uuid::Uuid::deserialize(typ, v)?;
CqlValue::Uuid(uuid)
}
Varint => CqlValue::Varint(CqlVarint::from_signed_bytes_be(buf.to_vec())),
List(type_name) => {
let len: usize = types::read_int(buf)?.try_into()?;
let mut res = Vec::with_capacity(len);
for _ in 0..len {
let mut b = types::read_bytes(buf)?;
res.push(deser_cql_value(type_name, &mut b)?);
}
CqlValue::List(res)
Varint => {
let vi = CqlVarint::deserialize(typ, v)?;
CqlValue::Varint(vi)
}
Map(key_type, value_type) => {
let len: usize = types::read_int(buf)?.try_into()?;
let mut res = Vec::with_capacity(len);
for _ in 0..len {
let mut b = types::read_bytes(buf)?;
let key = deser_cql_value(key_type, &mut b)?;
b = types::read_bytes(buf)?;
let val = deser_cql_value(value_type, &mut b)?;
res.push((key, val));
}
CqlValue::Map(res)
List(_type_name) => {
let l = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::List(l)
}
Set(type_name) => {
let len: usize = types::read_int(buf)?.try_into()?;
let mut res = Vec::with_capacity(len);
for _ in 0..len {
// TODO: is `null` allowed as set element? Should we use read_bytes_opt?
let mut b = types::read_bytes(buf)?;
res.push(deser_cql_value(type_name, &mut b)?);
}
CqlValue::Set(res)
Map(_key_type, _value_type) => {
let iter = MapIterator::<'_, CqlValue, CqlValue>::deserialize(typ, v)?;
let m: Vec<(CqlValue, CqlValue)> = iter.collect::<StdResult<_, _>>()?;
CqlValue::Map(m)
}
Set(_type_name) => {
let s = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::Set(s)
}
UserDefinedType {
type_name,
keyspace,
field_types,
..
} => {
let mut fields: Vec<(String, Option<CqlValue>)> = Vec::new();

for (field_name, field_type) in field_types {
// If a field is added to a UDT and we read an old (frozen ?) version of it,
// the driver will fail to parse the whole UDT.
// This is why we break the parsing after we reach the end of the serialized UDT.
if buf.is_empty() {
break;
}

let mut field_value: Option<CqlValue> = None;
if let Some(mut field_val_bytes) = types::read_bytes_opt(buf)? {
field_value = Some(deser_cql_value(field_type, &mut field_val_bytes)?);
}

fields.push((field_name.clone(), field_value));
}
let iter = UdtIterator::deserialize(typ, v)?;
let fields: Vec<(String, Option<CqlValue>)> = iter
.map(|res| {
res.and_then(|((col_name, col_type), v)| {
let val = Option::<CqlValue>::deserialize(col_type, v.flatten())?;
Ok((col_name.clone(), val))
})
})
.collect::<StdResult<_, _>>()?;

CqlValue::UserDefinedType {
keyspace: keyspace.clone(),
Expand All @@ -905,15 +778,19 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult<CqlValue,
}
}
Tuple(type_names) => {
let mut res = Vec::with_capacity(type_names.len());
for type_name in type_names {
match types::read_bytes_opt(buf)? {
Some(mut b) => res.push(Some(deser_cql_value(type_name, &mut b)?)),
None => res.push(None),
};
}

CqlValue::Tuple(res)
let t = type_names
.iter()
.map(|typ| {
types::read_bytes_opt(buf).and_then(|v| {
v.map(|v| {
CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v)))
.map_err(Into::into)
})
.transpose()
})
})
.collect::<StdResult<_, _>>()?;
CqlValue::Tuple(t)
}
})
}
Expand Down
Loading

0 comments on commit f7bcf33

Please sign in to comment.