Skip to content

Commit

Permalink
Add faster specialization for deserializing vector<float>
Browse files Browse the repository at this point in the history
  • Loading branch information
pkolaczk committed Aug 15, 2024
1 parent de99af7 commit f3eb389
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 10 deletions.
15 changes: 12 additions & 3 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::frame::{frame_errors::ParseError, types};
use crate::types::deserialize::result::{RowIterator, TypedRowIterator};
use crate::types::deserialize::value::{
mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator,
VectorIterator,
};
use crate::types::deserialize::{DeserializationError, FrameSlice};
use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -829,9 +830,17 @@ pub fn deser_cql_value(
.collect::<StdResult<_, _>>()?;
CqlValue::Tuple(t)
}
Vector(_type_name, _) => {
let l = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::Vector(l)
// Specialization for faster deserialization of vectors of floats, which are currently
// the only type of vector
Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => {
let v = VectorIterator::<CqlValue>::deserialize_vector_of_float_to_vec_of_cql_value(
typ, v,
)?;
CqlValue::Vector(v)
}
Vector(_, _) => {
let v = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::Vector(v)
}
})
}
Expand Down
63 changes: 56 additions & 7 deletions scylla-cql/src/types/deserialize/value.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
//! Provides types for dealing with CQL value deserialization.
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash},
net::IpAddr,
};

use bytes::Bytes;
use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, mem, net::IpAddr};
use uuid::Uuid;

use std::fmt::{Display, Pointer};
Expand Down Expand Up @@ -759,7 +754,7 @@ pub struct VectorIterator<'frame, T> {
}

impl<'frame, T> VectorIterator<'frame, T> {
fn new(
pub fn new(
coll_typ: &'frame ColumnType,
elem_typ: &'frame ColumnType,
count: usize,
Expand All @@ -774,6 +769,57 @@ impl<'frame, T> VectorIterator<'frame, T> {
phantom_data: std::marker::PhantomData,
}
}

/// Faster specialization for deserializing a `vector<float>` into `Vec<CqlValue>`.
/// The generic code `Vec<CqlValue>::deserialize(...)` is much slower because it has to
/// match on the element type for every item in the vector.
/// Here we just hardcode `f32` and we can shortcut a lot of code.
///
/// This could be nicer if Rust had generic type specialization in stable,
/// but for now we need a separate method.
pub fn deserialize_vector_of_float_to_vec_of_cql_value(
typ: &'frame ColumnType,
v: Option<FrameSlice<'frame>>,
) -> Result<Vec<CqlValue>, DeserializationError> {

// Typecheck would make sure those never happen:
let ColumnType::Vector(elem_type, elem_count) = typ else {
panic!("Wrong column type: {:?}. Expected vector<>", typ);
};
if !matches!(elem_type.as_ref(), ColumnType::Float) {
panic!("Wrong element type: {:?}. Expected float", typ);
}

let elem_count = *elem_count as usize;
let mut frame = v.map(|s| s.as_slice()).unwrap_or_default();
let mut result = Vec::with_capacity(elem_count);

unsafe {
type Float = f32;

// Check length only once
if frame.len() < size_of::<Float>() * elem_count {
return Err(mk_deser_err::<Vec<CqlValue>>(
typ,
BuiltinDeserializationErrorKind::RawCqlBytesReadError(
LowLevelDeserializationError::TooFewBytesReceived {
expected: size_of::<Float>() * elem_count,
received: frame.len(),
},
),
));
}
// We know we have enough elements in the buffer, so now we can skip the checks
for _ in 0..elem_count {
// we did check for frame length earlier, so we can safely not check again
let (elem, remaining) = frame.split_first_chunk().unwrap_unchecked();
let elem = Float::from_be_bytes(*elem);
result.push(CqlValue::Float(elem));
frame = remaining;
}
}
Ok(result)
}
}

impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T>
Expand Down Expand Up @@ -828,6 +874,7 @@ where
{
type Item = Result<T, DeserializationError>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let raw = self.raw_iter.next()?.map_err(|err| {
mk_deser_err::<Self>(
Expand Down Expand Up @@ -883,6 +930,7 @@ where
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
.map_err(deser_error_replace_rust_name::<Self>)
}

ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v)
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
.map_err(deser_error_replace_rust_name::<Self>),
Expand Down Expand Up @@ -1435,6 +1483,7 @@ impl<'frame> VectorBytesSequenceIterator<'frame> {
impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> {
type Item = Result<Option<FrameSlice<'frame>>, LowLevelDeserializationError>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.remaining = self.remaining.checked_sub(1)?;
Some(self.slice.read_subslice(self.elem_len))
Expand Down

0 comments on commit f3eb389

Please sign in to comment.