diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 491ebda82a..0d7a50f083 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -47,6 +47,8 @@ pub enum ParseError { IoError(#[from] std::io::Error), #[error("type not yet implemented, id: {0}")] TypeNotImplemented(u16), + #[error("invalid custom type: {0}")] + InvalidCustomType(String), #[error(transparent)] SerializeValuesError(#[from] SerializeValuesError), #[error(transparent)] diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 87bbaaeaa6..de373dc81c 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -7,11 +7,14 @@ use crate::frame::{frame_errors::ParseError, types}; use crate::types::deserialize::result::{RowIterator, TypedRowIterator}; use crate::types::deserialize::value::{ mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator, + VectorIterator, }; use crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; use std::borrow::Cow; -use std::{convert::TryInto, net::IpAddr, result::Result as StdResult, str}; +use std::{convert::TryInto, mem, net::IpAddr, result::Result as StdResult, str}; +use std::mem::ManuallyDrop; +use std::ops::Deref; use uuid::Uuid; #[derive(Debug)] @@ -69,6 +72,7 @@ pub enum ColumnType { Tuple(Vec), Uuid, Varint, + Vector(Box, u32), } #[derive(Clone, Debug, PartialEq)] @@ -94,6 +98,7 @@ pub enum CqlValue { List(Vec), Map(Vec<(CqlValue, CqlValue)>), Set(Vec), + Vector(DropOptimizedVec), UserDefinedType { keyspace: String, type_name: String, @@ -112,6 +117,52 @@ pub enum CqlValue { Varint(CqlVarint), } + +#[derive(Clone, Debug, PartialEq)] +pub struct DropOptimizedVec { + data: Vec, + drop_elements: bool +} + +impl DropOptimizedVec { + pub fn new(data: Vec, drop_elements: bool) -> DropOptimizedVec { + DropOptimizedVec { + data, + drop_elements, 
+ } + } + + pub fn dropping(data: Vec) -> DropOptimizedVec { + Self::new(data, true) + } + + pub fn non_dropping(data: Vec) -> DropOptimizedVec { + Self::new(data, false) + } + + pub fn into_vec(mut self) -> Vec { + let mut vec = vec![]; + mem::swap(&mut self.data, &mut vec); + vec + } +} + +impl Deref for DropOptimizedVec { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl Drop for DropOptimizedVec { + fn drop(&mut self) { + if !self.drop_elements { + unsafe { self.data.set_len(0); } + } + } +} + impl<'a> TableSpec<'a> { pub const fn borrowed(ks: &'a str, table: &'a str) -> Self { Self { @@ -352,6 +403,7 @@ impl CqlValue { pub fn as_list(&self) -> Option<&Vec> { match self { Self::List(s) => Some(s), + Self::Vector(s) => Some(&s), _ => None, } } @@ -381,6 +433,7 @@ impl CqlValue { match self { Self::List(s) => Some(s), Self::Set(s) => Some(s), + Self::Vector(s) => Some(s.into_vec()), _ => None, } } @@ -489,8 +542,19 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { Ok(match id { 0x0000 => { let type_str: String = types::read_string(buf)?.to_string(); - match type_str.as_str() { + let type_parts: Vec<_> = type_str.split(&[',', '(', ')']).collect(); + match type_parts[0] { "org.apache.cassandra.db.marshal.DurationType" => Duration, + "org.apache.cassandra.db.marshal.VectorType" => { + if type_parts.len() < 3 { + return Err(ParseError::InvalidCustomType(type_str)); + } + let elem_type = parse_type_str(type_parts[1].trim())?; + let Ok(dimensions) = type_parts[2].trim().parse() else { + return Err(ParseError::InvalidCustomType(type_str)); + }; + Vector(Box::new(elem_type), dimensions) + } _ => Custom(type_str), } } @@ -552,6 +616,18 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { }) } +fn parse_type_str(name: &str) -> StdResult { + match name { + "org.apache.cassandra.db.marshal.BigIntType" => Ok(ColumnType::BigInt), + "org.apache.cassandra.db.marshal.DoubleType" => Ok(ColumnType::Double), + 
"org.apache.cassandra.db.marshal.FloatType" => Ok(ColumnType::Float), + "org.apache.cassandra.db.marshal.IntType" => Ok(ColumnType::Int), + "org.apache.cassandra.db.marshal.SmallIntType" => Ok(ColumnType::SmallInt), + "org.apache.cassandra.db.marshal.TinyIntType" => Ok(ColumnType::TinyInt), + _ => Err(ParseError::InvalidCustomType(name.to_string())), + } +} + fn deser_col_specs( buf: &mut &[u8], global_table_spec: &Option>, @@ -802,6 +878,18 @@ pub fn deser_cql_value( .collect::>()?; CqlValue::Tuple(t) } + // Specialization for faster deserialization of vectors of floats, which are currently + // the only type of vector + Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => { + let v = VectorIterator::::deserialize_vector_of_float_to_vec_of_cql_value( + typ, v, + )?; + CqlValue::Vector(DropOptimizedVec::non_dropping(v)) + } + Vector(_, _) => { + let v = Vec::::deserialize(typ, v)?; + CqlValue::Vector(DropOptimizedVec::dropping(v)) + } }) } diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 0cc791be31..f59a35156b 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -161,7 +161,7 @@ impl<'a> RawValue<'a> { } } -fn read_raw_bytes<'a>( +pub fn read_raw_bytes<'a>( count: usize, buf: &mut &'a [u8], ) -> Result<&'a [u8], LowLevelDeserializationError> { diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index a9c368d195..3f7c6ee764 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -1414,6 +1414,7 @@ impl Value for CqlValue { CqlValue::Ascii(s) | CqlValue::Text(s) => s.serialize(buf), CqlValue::List(v) | CqlValue::Set(v) => v.serialize(buf), + CqlValue::Vector(v) => v.serialize(buf), CqlValue::Blob(b) => b.serialize(buf), CqlValue::Boolean(b) => b.serialize(buf), diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 61d9f345c5..071c454589 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -29,7 +29,7 @@ pub mod 
_macro_internal { pub use crate::frame::response::cql_to_rust::{ FromCqlVal, FromCqlValError, FromRow, FromRowError, }; - pub use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row}; + pub use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, DropOptimizedVec, Row}; pub use crate::frame::value::{ LegacySerializedValues, SerializedResult, Value, ValueList, ValueTooBig, }; diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs index 4471960a03..b628ea0225 100644 --- a/scylla-cql/src/types/deserialize/frame_slice.rs +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -155,6 +155,35 @@ impl<'frame> FrameSlice<'frame> { original_frame: self.original_frame, })) } + + /// Reads and consumes a fixed number of bytes item from the beginning of the frame, + /// returning a subslice that encompasses that item. + /// + /// If this slice is empty, returns `Ok(None)`. + /// Otherwise, if the slice does not contain enough data, it returns `Err`. + /// If the operation fails then the slice remains unchanged. + #[inline] + pub(super) fn read_subslice( + &mut self, + count: usize, + ) -> Result>, LowLevelDeserializationError> { + if self.is_empty() { + return Ok(None); + } + + // We copy the slice reference, not to mutate the FrameSlice in case of an error. + let mut slice = self.frame_subslice; + + let cql_bytes = types::read_raw_bytes(count, &mut slice)?; + + // `read_raw_bytes` hasn't failed, so now we must update the FrameSlice. + self.frame_subslice = slice; + + Ok(Some(Self { + frame_subslice: cql_bytes, + original_frame: self.original_frame, + })) + } } #[cfg(test)] diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 074a7c298a..f714142d2c 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,15 +1,10 @@ //! Provides types for dealing with CQL value deserialization. 
-use std::{ - collections::{BTreeMap, BTreeSet, HashMap, HashSet}, - hash::{BuildHasher, Hash}, - net::IpAddr, -}; - use bytes::Bytes; +use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, net::IpAddr}; use uuid::Uuid; -use std::fmt::Display; +use std::fmt::{Display, Pointer}; use thiserror::Error; @@ -747,23 +742,200 @@ where } } +// vectors + +/// An iterator over a CQL vector +pub struct VectorIterator<'frame, T> { + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + count: usize, + raw_iter: VectorBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} + +impl<'frame, T> VectorIterator<'frame, T> { + pub fn new( + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + count: usize, + elem_len: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + elem_typ, + count, + raw_iter: VectorBytesSequenceIterator::new(count, elem_len, slice), + phantom_data: std::marker::PhantomData, + } + } + + /// Faster specialization for deserializing a `vector` into `Vec`. + /// The generic code `Vec::deserialize(...)` is much slower because it has to + /// match on the element type for every item in the vector. + /// Here we just hardcode `f32` and we can shortcut a lot of code. + /// + /// This could be nicer if Rust had generic type specialization in stable, + /// but for now we need a separate method. + pub fn deserialize_vector_of_float_to_vec_of_cql_value( + typ: &'frame ColumnType, + v: Option>, + ) -> Result, DeserializationError> { + + // Typecheck would make sure those never happen: + let ColumnType::Vector(elem_type, elem_count) = typ else { + panic!("Wrong column type: {:?}. Expected vector<>", typ); + }; + if !matches!(elem_type.as_ref(), ColumnType::Float) { + panic!("Wrong element type: {:?}. 
Expected float", typ); + } + + let elem_count = *elem_count as usize; + let mut frame = v.map(|s| s.as_slice()).unwrap_or_default(); + let mut result = Vec::with_capacity(elem_count); + + unsafe { + type Float = f32; + + // Check length only once + if frame.len() < size_of::() * elem_count { + return Err(mk_deser_err::>( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError( + LowLevelDeserializationError::TooFewBytesReceived { + expected: size_of::() * elem_count, + received: frame.len(), + }, + ), + )); + } + // We know we have enough elements in the buffer, so now we can skip the checks + for _ in 0..elem_count { + // we did check for frame length earlier, so we can safely not check again + let (elem, remaining) = frame.split_at_unchecked(size_of::()); + let elem = Float::from_be_bytes(*elem.as_ptr().cast()); + result.push(CqlValue::Float(elem)); + frame = remaining; + } + } + Ok(result) + } +} + +impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::Vector(el_t, _) => { + if !matches!(el_t.as_ref(), ColumnType::Float) { + return Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::VectorError( + VectorTypeCheckErrorKind::UnsupportedElementType, + ), + )); + } + >::type_check(el_t).map_err(|err| { + mk_typck_err::(typ, VectorTypeCheckErrorKind::ElementTypeCheckFailed(err)) + }) + } + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::VectorError(VectorTypeCheckErrorKind::NotVector), + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_frame_slice::(typ, v)?; + let ColumnType::Vector(elem_typ, dimensions) = typ else { + unreachable!("Typecheck should have prevented this scenario!") + }; + // TODO: This should be better associated with ColumnType + let size_of_cql_float = 4; + Ok(Self::new( + typ, + elem_typ, + *dimensions as 
usize, + size_of_cql_float, + v, + )) + } +} + +impl<'frame, T> Iterator for VectorIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + type Item = Result; + + #[inline] + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?.map_err(|err| { + mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) + }); + Some(raw.and_then(|raw| { + T::deserialize(self.elem_typ, raw).map_err(|err| { + mk_deser_err::( + self.coll_typ, + VectorDeserializationErrorKind::ElementDeserializationFailed(err), + ) + }) + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.count, Some(self.count)) + } +} + impl<'frame, T> DeserializeValue<'frame> for Vec where T: DeserializeValue<'frame>, { fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { - // It makes sense for both Set and List to deserialize to Vec. - ListlikeIterator::<'frame, T>::type_check(typ) - .map_err(typck_error_replace_rust_name::) + // It makes sense for both Set, List and Vector to deserialize to Vec. 
+ match typ { + ColumnType::List(_) | ColumnType::Set(_) => { + ListlikeIterator::<'frame, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::), + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList, + ), + )), + } } fn deserialize( typ: &'frame ColumnType, v: Option>, ) -> Result { - ListlikeIterator::<'frame, T>::deserialize(typ, v) - .and_then(|it| it.collect::>()) - .map_err(deser_error_replace_rust_name::) + match typ { + ColumnType::List(_) | ColumnType::Set(_) => { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } + + ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::), + _ => unreachable!("Should be prevented by typecheck"), + } } } @@ -1286,6 +1458,38 @@ impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { } } +/// Iterates over a sequence of CQL vector items from a frame subslice, expecting +/// a particular number of items. +/// +/// The iterator does not consider it to be an error if there are some bytes +/// remaining in the slice after parsing requested amount of items. 
+#[derive(Clone, Copy, Debug)] +pub struct VectorBytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, + elem_len: usize, + remaining: usize, +} + +impl<'frame> VectorBytesSequenceIterator<'frame> { + fn new(count: usize, elem_len: usize, slice: FrameSlice<'frame>) -> Self { + Self { + slice, + elem_len, + remaining: count, + } + } +} + +impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> { + type Item = Result>, LowLevelDeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + Some(self.slice.read_subslice(self.elem_len)) + } +} + /// Iterates over a sequence of `[bytes]` items from a frame subslice. /// /// The `[bytes]` items are parsed until the end of subslice is reached. @@ -1384,6 +1588,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL set or list. SetOrListError(SetOrListTypeCheckErrorKind), + /// A type check failure specific to a CQL vector. + VectorError(VectorTypeCheckErrorKind), + /// A type check failure specific to a CQL map. 
MapError(MapTypeCheckErrorKind), @@ -1401,6 +1608,13 @@ impl From for BuiltinTypeCheckErrorKind { } } +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: VectorTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::VectorError(value) + } +} + impl From for BuiltinTypeCheckErrorKind { #[inline] fn from(value: MapTypeCheckErrorKind) -> Self { @@ -1429,6 +1643,7 @@ impl Display for BuiltinTypeCheckErrorKind { write!(f, "expected one of the CQL types: {expected:?}") } BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::VectorError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), @@ -1452,7 +1667,7 @@ impl Display for SetOrListTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { SetOrListTypeCheckErrorKind::NotSetOrList => { - f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set nor a list") + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set, list nor a vector") } SetOrListTypeCheckErrorKind::NotSet => { f.write_str("the CQL type the Rust type was attempted to be type checked against was not a set") @@ -1464,6 +1679,33 @@ impl Display for SetOrListTypeCheckErrorKind { } } +/// Describes why type checking of a vector type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum VectorTypeCheckErrorKind { + NotVector, + /// Element type not a CQL float + UnsupportedElementType, + /// Incompatible element types. 
+ ElementTypeCheckFailed(TypeCheckError), +} + +impl Display for VectorTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorTypeCheckErrorKind::NotVector => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was not a vector") + } + VectorTypeCheckErrorKind::UnsupportedElementType => { + f.write_str("only float elements are supported in CQL vectors") + } + VectorTypeCheckErrorKind::ElementTypeCheckFailed(err) => { + write!(f, "the vector element types between the CQL type and the Rust type failed to type check against each other: {}", err) + } + } + } +} + /// Describes why type checking of a map type failed. #[derive(Debug, Clone)] #[non_exhaustive] @@ -1695,7 +1937,10 @@ pub enum BuiltinDeserializationErrorKind { ExpectedNonNull, /// The length of read value in bytes is different than expected for the Rust type. - ByteLengthMismatch { expected: usize, got: usize }, + ByteLengthMismatch { + expected: usize, + got: usize, + }, /// Expected valid ASCII string. ExpectedAscii, @@ -1708,11 +1953,15 @@ pub enum BuiltinDeserializationErrorKind { ValueOverflow, /// The length of read value in bytes is not suitable for IP address. - BadInetLength { got: usize }, + BadInetLength { + got: usize, + }, /// A deserialization failure specific to a CQL set or list. SetOrListError(SetOrListDeserializationErrorKind), + VectorError(VectorDeserializationErrorKind), + /// A deserialization failure specific to a CQL map. 
MapError(MapDeserializationErrorKind), @@ -1751,6 +2000,7 @@ impl Display for BuiltinDeserializationErrorKind { "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" ), BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::VectorError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), BuiltinDeserializationErrorKind::UdtError(err) => err.fmt(f), @@ -1790,6 +2040,21 @@ impl From for BuiltinDeserializationErrorKind } } +/// Describes why deserialization of a vector type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum VectorDeserializationErrorKind { + /// One of the elements of the vector failed to deserialize. + ElementDeserializationFailed(DeserializationError), +} + +impl From for BuiltinDeserializationErrorKind { + #[inline] + fn from(err: VectorDeserializationErrorKind) -> Self { + Self::VectorError(err) + } +} + /// Describes why deserialization of a map type failed. 
#[derive(Debug)] #[non_exhaustive] diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 2a7040b789..2a8de41084 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -4,6 +4,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Display; use std::hash::BuildHasher; use std::net::IpAddr; +use std::ops::Deref; use std::sync::Arc; use thiserror::Error; @@ -447,15 +448,30 @@ impl SerializeValue for Vec { typ: &ColumnType, writer: CellWriter<'b>, ) -> Result, SerializationError> { - serialize_sequence( - std::any::type_name::(), - self.len(), - self.iter(), - typ, - writer, - ) + match typ { + ColumnType::List(_) | ColumnType::Set(_) => serialize_sequence( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ), + ColumnType::Vector(_, _) => serialize_vector( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ), + _ => Err(mk_typck_err_named( + std::any::type_name::(), + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )), + } } } + impl<'a, T: SerializeValue + 'a> SerializeValue for &'a [T] { fn serialize<'b>( &self, @@ -561,6 +577,7 @@ fn serialize_cql_value<'b>( } CqlValue::Uuid(u) => <_ as SerializeValue>::serialize(&u, typ, writer), CqlValue::Varint(v) => <_ as SerializeValue>::serialize(&v, typ, writer), + CqlValue::Vector(v) => <_ as SerializeValue>::serialize(v.deref(), typ, writer), } } @@ -831,6 +848,49 @@ fn serialize_sequence<'t, 'b, T: SerializeValue + 't>( .map_err(|_| mk_ser_err_named(rust_name, typ, BuiltinSerializationErrorKind::SizeOverflow)) } +fn serialize_vector<'t, 'b, T: SerializeValue + 't>( + rust_name: &'static str, + len: usize, + iter: impl Iterator, + typ: &ColumnType, + writer: CellWriter<'b>, +) -> Result, SerializationError> { + let ColumnType::Vector(elt, dim) = typ else { + unreachable!("serialize_vector can be only called for vectors") + }; + + if 
!matches!(elt.as_ref(), ColumnType::Float) { + return Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::VectorError( + VectorTypeCheckErrorKind::UnsupportedElementType(elt.as_ref().clone()), + ), + )); + } + if len != *dim as usize { + return Err(mk_ser_err_named( + rust_name, + typ, + VectorSerializationErrorKind::InvalidNumberOfElements(len, *dim), + )); + } + + let mut builder = writer.into_fixed_len_value_builder(4); + for el in iter { + T::serialize(el, elt, builder.make_sub_writer()).map_err(|err| { + mk_ser_err_named( + rust_name, + typ, + VectorSerializationErrorKind::ElementSerializationFailed(err), + ) + })?; + } + + builder + .finish() + .map_err(|_| mk_ser_err_named(rust_name, typ, BuiltinSerializationErrorKind::SizeOverflow)) +} + fn serialize_mapping<'t, 'b, K: SerializeValue + 't, V: SerializeValue + 't>( rust_name: &'static str, len: usize, @@ -1099,6 +1159,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL set or list. SetOrListError(SetOrListTypeCheckErrorKind), + /// A type check failure specific to a CQL vector. + VectorError(VectorTypeCheckErrorKind), + /// A type check failure specific to a CQL map. MapError(MapTypeCheckErrorKind), @@ -1147,6 +1210,7 @@ impl Display for BuiltinTypeCheckErrorKind { f.write_str("the separate empty representation is not valid for this type") } BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::VectorError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), @@ -1171,6 +1235,9 @@ pub enum BuiltinSerializationErrorKind { /// A serialization failure specific to a CQL set or list. SetOrListError(SetOrListSerializationErrorKind), + /// A serialization failure specific to a CQL vector. + VectorError(VectorSerializationErrorKind), + /// A serialization failure specific to a CQL map. 
MapError(MapSerializationErrorKind), @@ -1187,6 +1254,12 @@ impl From for BuiltinSerializationErrorKind { } } +impl From for BuiltinSerializationErrorKind { + fn from(value: VectorSerializationErrorKind) -> Self { + BuiltinSerializationErrorKind::VectorError(value) + } +} + impl From for BuiltinSerializationErrorKind { fn from(value: MapSerializationErrorKind) -> Self { BuiltinSerializationErrorKind::MapError(value) @@ -1215,6 +1288,7 @@ impl Display for BuiltinSerializationErrorKind { f.write_str("the Rust value is out of range supported by the CQL type") } BuiltinSerializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinSerializationErrorKind::VectorError(err) => err.fmt(f), BuiltinSerializationErrorKind::MapError(err) => err.fmt(f), BuiltinSerializationErrorKind::TupleError(err) => err.fmt(f), BuiltinSerializationErrorKind::UdtError(err) => err.fmt(f), @@ -1282,7 +1356,25 @@ impl Display for SetOrListTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { SetOrListTypeCheckErrorKind::NotSetOrList => { - f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set or a list") + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set, list or a vector") + } + } + } +} + +/// Describes why type checking of a vector type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum VectorTypeCheckErrorKind { + /// Element type of the vector is not supported. Currently, vectors support only float elements. 
+ UnsupportedElementType(ColumnType), +} + +impl Display for VectorTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorTypeCheckErrorKind::UnsupportedElementType(typ) => { + write!(f, "serializing vectors of {:?} is not supported", typ) } } } @@ -1312,6 +1404,31 @@ impl Display for SetOrListSerializationErrorKind { } } +/// Describes why serialization of a vector type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum VectorSerializationErrorKind { + /// The number of elements in the serialized collection does not match + /// the number of vector dimensions + InvalidNumberOfElements(usize, u32), + + /// One of the elements of the vector failed to serialize. + ElementSerializationFailed(SerializationError), +} + +impl Display for VectorSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorSerializationErrorKind::InvalidNumberOfElements(actual, expected) => { + write!(f, "number of vector elements ({}) does not much the number of declared dimensions ({})", actual, expected) + } + VectorSerializationErrorKind::ElementSerializationFailed(err) => { + write!(f, "failed to serialize one of the elements: {err}") + } + } + } +} + /// Describes why type checking of a tuple failed. #[derive(Debug, Clone)] #[non_exhaustive] diff --git a/scylla-cql/src/types/serialize/writers.rs b/scylla-cql/src/types/serialize/writers.rs index cfd36202bb..c43fd425fe 100644 --- a/scylla-cql/src/types/serialize/writers.rs +++ b/scylla-cql/src/types/serialize/writers.rs @@ -71,6 +71,7 @@ impl<'buf> RowWriter<'buf> { /// in nothing being written. pub struct CellWriter<'buf> { buf: &'buf mut Vec, + cell_len: Option, } impl<'buf> CellWriter<'buf> { @@ -79,12 +80,25 @@ impl<'buf> CellWriter<'buf> { /// The newly created row writer will append data to the end of the vec. 
#[inline] pub fn new(buf: &'buf mut Vec) -> Self { - Self { buf } + Self { + buf, + cell_len: None, + } + } + + /// Creates a new cell writer based on an existing Vec, for fixed-length cells. + /// This cell writer will serialize each cell directly, without prepending it + /// with cell length. + /// + /// The newly created row writer will append data to the end of the vec + pub fn with_cell_len(buf: &'buf mut Vec, cell_len: Option) -> Self { + Self { buf, cell_len } } /// Sets this value to be null, consuming this object. #[inline] pub fn set_null(self) -> WrittenCellProof<'buf> { + assert!(self.cell_len.is_none()); self.buf.extend_from_slice(&(-1i32).to_be_bytes()); WrittenCellProof::new() } @@ -92,6 +106,7 @@ impl<'buf> CellWriter<'buf> { /// Sets this value to represent an unset value, consuming this object. #[inline] pub fn set_unset(self) -> WrittenCellProof<'buf> { + assert!(self.cell_len.is_none()); self.buf.extend_from_slice(&(-2i32).to_be_bytes()); WrittenCellProof::new() } @@ -107,12 +122,15 @@ impl<'buf> CellWriter<'buf> { #[inline] pub fn set_value(self, contents: &[u8]) -> Result, CellOverflowError> { let value_len: i32 = contents.len().try_into().map_err(|_| CellOverflowError)?; - self.buf.extend_from_slice(&value_len.to_be_bytes()); + match self.cell_len { + Some(len) => assert_eq!(len, contents.len()), + None => self.buf.extend_from_slice(&value_len.to_be_bytes()), + } self.buf.extend_from_slice(contents); Ok(WrittenCellProof::new()) } - /// Turns this writter into a [`CellValueBuilder`] which can be used + /// Turns this writer into a [`CellValueBuilder`] which can be used /// to gradually initialize the CQL value. 
/// /// This method should be used if you don't have all of the data @@ -122,6 +140,13 @@ impl<'buf> CellWriter<'buf> { pub fn into_value_builder(self) -> CellValueBuilder<'buf> { CellValueBuilder::new(self.buf) } + + /// Turns this writer into a [`CellValueBuilder`] which can be used + /// to gradually initialize the CQL value of CQL vector type. + #[inline] + pub fn into_fixed_len_value_builder(self, len: usize) -> CellValueBuilder<'buf> { + CellValueBuilder::fixed_len(self.buf, len) + } } /// Allows appending bytes to a non-null, non-unset cell. @@ -136,6 +161,8 @@ pub struct CellValueBuilder<'buf> { // Starting position of the value in the buffer. starting_pos: usize, + + cell_len: Option, } impl<'buf> CellValueBuilder<'buf> { @@ -149,7 +176,28 @@ impl<'buf> CellValueBuilder<'buf> { // won't be misinterpreted. let starting_pos = buf.len(); buf.extend_from_slice(&(-3i32).to_be_bytes()); - Self { buf, starting_pos } + Self { + buf, + starting_pos, + cell_len: None, + } + } + + #[inline] + fn fixed_len(buf: &'buf mut Vec, cell_len: usize) -> Self { + // "Length" of a [bytes] frame can either be a non-negative i32, + // -1 (null) or -2 (not set). Push an invalid value here. It will be + // overwritten eventually either by set_null, set_unset or Drop. + // If the CellSerializer is not dropped as it should, this will trigger + // an error on the DB side and the serialized data + // won't be misinterpreted. + let starting_pos = buf.len(); + buf.extend_from_slice(&(-3i32).to_be_bytes()); + Self { + buf, + starting_pos, + cell_len: Some(cell_len), + } } /// Appends raw bytes to this cell. @@ -162,7 +210,7 @@ impl<'buf> CellValueBuilder<'buf> { /// and returns an object that allows to fill it in. #[inline] pub fn make_sub_writer(&mut self) -> CellWriter<'_> { - CellWriter::new(self.buf) + CellWriter::with_cell_len(self.buf, self.cell_len) } /// Finishes serializing the value. 
diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index fafa8afdca..2b4900ab47 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -184,6 +184,7 @@ enum PreCqlType { type_: PreCollectionType, }, Tuple(Vec), + Vector(Box, u16), UserDefinedType { frozen: bool, name: String, @@ -207,6 +208,9 @@ impl PreCqlType { .map(|t| t.into_cql_type(keyspace_name, udts)) .collect(), ), + PreCqlType::Vector(t, dim) => { + CqlType::Vector(Box::new(t.into_cql_type(keyspace_name, udts)), dim) + } PreCqlType::UserDefinedType { frozen, name } => { let definition = match udts .get(keyspace_name) @@ -232,6 +236,7 @@ pub enum CqlType { type_: CollectionType, }, Tuple(Vec), + Vector(Box, u16), UserDefinedType { frozen: bool, // Using Arc here in order not to have many copies of the same definition @@ -1099,6 +1104,9 @@ fn topo_sort_udts(udts: &mut Vec) -> Result<(), Quer PreCqlType::Tuple(types) => types .iter() .for_each(|type_| do_with_referenced_udts(what, type_)), + PreCqlType::Vector(t, _) => { + do_with_referenced_udts(what, t); + } PreCqlType::UserDefinedType { name, .. 
} => what(name), } } @@ -1608,6 +1616,12 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_ })?; Ok((PreCqlType::Tuple(types), p)) + } else if let Ok(p) = p.accept("vector<") { + let (inner_type, p) = parse_cql_type(p)?; + let p = p.accept(",")?.skip_white(); + let (dim, p) = parse_u16(p)?; + let p = p.accept(">")?; + Ok((PreCqlType::Vector(Box::new(inner_type), dim), p)) } else if let Ok((typ, p)) = parse_native_type(p) { Ok((PreCqlType::Native(typ), p)) } else if let Ok((name, p)) = parse_user_defined_type(p) { @@ -1639,6 +1653,15 @@ fn parse_user_defined_type(p: ParserState) -> ParseResult<(&str, ParserState)> { Ok((tok, p)) } +fn parse_u16(p: ParserState) -> ParseResult<(u16, ParserState)> { + let (tok, p) = p.take_while(|c| c.is_numeric()); + if let Ok(value) = tok.parse() { + Ok((value, p)) + } else { + Err(p.error(ParseErrorCause::Expected("positive integer"))) + } +} + fn freeze_type(type_: PreCqlType) -> PreCqlType { match type_ { PreCqlType::Collection { type_, .. } => PreCqlType::Collection { diff --git a/scylla/src/utils/pretty.rs b/scylla/src/utils/pretty.rs index bd3f06487a..b811168723 100644 --- a/scylla/src/utils/pretty.rs +++ b/scylla/src/utils/pretty.rs @@ -121,6 +121,11 @@ where CommaSeparatedDisplayer(v.iter().map(CqlValueDisplayer)).fmt(f)?; f.write_str("]")?; } + CqlValue::Vector(v) => { + f.write_str("[")?; + CommaSeparatedDisplayer(v.iter().map(CqlValueDisplayer)).fmt(f)?; + f.write_str("]")?; + } CqlValue::Set(v) => { f.write_str("{")?; CommaSeparatedDisplayer(v.iter().map(CqlValueDisplayer)).fmt(f)?;