diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 4740d38c5a..359c8e3cea 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -7,6 +7,7 @@ use crate::frame::{frame_errors::ParseError, types}; use crate::types::deserialize::result::{RowIterator, TypedRowIterator}; use crate::types::deserialize::value::{ mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator, + VectorIterator, }; use crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; @@ -829,9 +830,17 @@ pub fn deser_cql_value( .collect::>()?; CqlValue::Tuple(t) } - Vector(_type_name, _) => { - let l = Vec::::deserialize(typ, v)?; - CqlValue::Vector(l) + // Specialization for faster deserialization of vectors of floats, which are currently + // the only type of vector + Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => { + let v = VectorIterator::::deserialize_vector_of_float_to_vec_of_cql_value( + typ, v, + )?; + CqlValue::Vector(v) + } + Vector(_, _) => { + let v = Vec::::deserialize(typ, v)?; + CqlValue::Vector(v) } }) } diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 0197f05a6b..249c25c503 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,12 +1,7 @@ //! Provides types for dealing with CQL value deserialization. -use std::{ - collections::{BTreeMap, BTreeSet, HashMap, HashSet}, - hash::{BuildHasher, Hash}, - net::IpAddr, -}; - use bytes::Bytes; +use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, mem, net::IpAddr}; use uuid::Uuid; use std::fmt::{Display, Pointer}; @@ -759,7 +754,7 @@ pub struct VectorIterator<'frame, T> { } impl<'frame, T> VectorIterator<'frame, T> { - fn new( + pub fn new( coll_typ: &'frame ColumnType, elem_typ: &'frame ColumnType, count: usize, @@ -774,6 +769,57 @@ impl<'frame, T> VectorIterator<'frame, T> { phantom_data: std::marker::PhantomData, } } + + /// Faster specialization for deserializing a `vector` into `Vec`. + /// The generic code `Vec::deserialize(...)` is much slower because it has to + /// match on the element type for every item in the vector. + /// Here we just hardcode `f32` and we can shortcut a lot of code. + /// + /// This could be nicer if Rust had generic type specialization in stable, + /// but for now we need a separate method. + pub fn deserialize_vector_of_float_to_vec_of_cql_value( + typ: &'frame ColumnType, + v: Option>, + ) -> Result, DeserializationError> { + + // Typecheck would make sure those never happen: + let ColumnType::Vector(elem_type, elem_count) = typ else { + panic!("Wrong column type: {:?}. Expected vector<>", typ); + }; + if !matches!(elem_type.as_ref(), ColumnType::Float) { + panic!("Wrong element type: {:?}. Expected float", typ); + } + + let elem_count = *elem_count as usize; + let mut frame = v.map(|s| s.as_slice()).unwrap_or_default(); + let mut result = Vec::with_capacity(elem_count); + + unsafe { + type Float = f32; + + // Check length only once + if frame.len() < size_of::() * elem_count { + return Err(mk_deser_err::>( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError( + LowLevelDeserializationError::TooFewBytesReceived { + expected: size_of::() * elem_count, + received: frame.len(), + }, + ), + )); + } + // We know we have enough elements in the buffer, so now we can skip the checks + for _ in 0..elem_count { + // we did check for frame length earlier, so we can safely not check again + let (elem, remaining) = frame.split_first_chunk().unwrap_unchecked(); + let elem = Float::from_be_bytes(*elem); + result.push(CqlValue::Float(elem)); + frame = remaining; + } + } + Ok(result) + } } impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T> @@ -828,6 +874,7 @@ where { type Item = Result; + #[inline] fn next(&mut self) -> Option { let raw = self.raw_iter.next()?.map_err(|err| { mk_deser_err::( @@ -883,6 +930,7 @@ where .and_then(|it| it.collect::>()) .map_err(deser_error_replace_rust_name::) } + ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v) .and_then(|it| it.collect::>()) .map_err(deser_error_replace_rust_name::), @@ -1435,6 +1483,7 @@ impl<'frame> VectorBytesSequenceIterator<'frame> { impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> { type Item = Result>, LowLevelDeserializationError>; + #[inline] fn next(&mut self) -> Option { self.remaining = self.remaining.checked_sub(1)?; Some(self.slice.read_subslice(self.elem_len))