Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CQL Vector type #1022

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add faster specialization for deserializing vector<float>
pkolaczk committed Aug 15, 2024
commit faea1f94f9fbe380e378763ed651e708e90257b6
15 changes: 12 additions & 3 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ use crate::frame::{frame_errors::ParseError, types};
use crate::types::deserialize::result::{RowIterator, TypedRowIterator};
use crate::types::deserialize::value::{
mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator,
VectorIterator,
};
use crate::types::deserialize::{DeserializationError, FrameSlice};
use bytes::{Buf, Bytes};
@@ -829,9 +830,17 @@ pub fn deser_cql_value(
.collect::<StdResult<_, _>>()?;
CqlValue::Tuple(t)
}
Vector(_type_name, _) => {
let l = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::Vector(l)
// Specialization for faster deserialization of vectors of floats, which are currently
// the only type of vector
Vector(elem_type, _) if matches!(elem_type.as_ref(), Float) => {
let v = VectorIterator::<CqlValue>::deserialize_vector_of_float_to_vec_of_cql_value(
typ, v,
)?;
CqlValue::Vector(v)
}
Vector(_, _) => {
let v = Vec::<CqlValue>::deserialize(typ, v)?;
CqlValue::Vector(v)
}
})
}
63 changes: 56 additions & 7 deletions scylla-cql/src/types/deserialize/value.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
//! Provides types for dealing with CQL value deserialization.

use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash},
net::IpAddr,
};

use bytes::Bytes;
use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, net::IpAddr};
use uuid::Uuid;

use std::fmt::{Display, Pointer};
@@ -759,7 +754,7 @@ pub struct VectorIterator<'frame, T> {
}

impl<'frame, T> VectorIterator<'frame, T> {
fn new(
pub fn new(
coll_typ: &'frame ColumnType,
elem_typ: &'frame ColumnType,
count: usize,
@@ -774,6 +769,57 @@ impl<'frame, T> VectorIterator<'frame, T> {
phantom_data: std::marker::PhantomData,
}
}

/// Faster specialization for deserializing a `vector<float>` into `Vec<CqlValue>`.
/// The generic code `Vec<CqlValue>::deserialize(...)` is much slower because it has to
/// match on the element type for every item in the vector.
/// Here we just hardcode `f32` and we can shortcut a lot of code.
///
/// This could be nicer if Rust had generic type specialization in stable,
/// but for now we need a separate method.
pub fn deserialize_vector_of_float_to_vec_of_cql_value(
typ: &'frame ColumnType,
v: Option<FrameSlice<'frame>>,
) -> Result<Vec<CqlValue>, DeserializationError> {

// Typecheck would make sure those never happen:
let ColumnType::Vector(elem_type, elem_count) = typ else {
panic!("Wrong column type: {:?}. Expected vector<>", typ);
};
if !matches!(elem_type.as_ref(), ColumnType::Float) {
panic!("Wrong element type: {:?}. Expected float", typ);
}

let elem_count = *elem_count as usize;
let mut frame = v.map(|s| s.as_slice()).unwrap_or_default();
let mut result = Vec::with_capacity(elem_count);

unsafe {
type Float = f32;

// Check length only once
if frame.len() < size_of::<Float>() * elem_count {
return Err(mk_deser_err::<Vec<CqlValue>>(
typ,
BuiltinDeserializationErrorKind::RawCqlBytesReadError(
LowLevelDeserializationError::TooFewBytesReceived {
expected: size_of::<Float>() * elem_count,
received: frame.len(),
},
),
));
}
// We know we have enough elements in the buffer, so now we can skip the checks
for _ in 0..elem_count {
// we did check for frame length earlier, so we can safely not check again
let (elem, remaining) = frame.split_at_unchecked(size_of::<Float>());
let elem = Float::from_be_bytes(*elem.as_ptr().cast());
result.push(CqlValue::Float(elem));
frame = remaining;
}
}
Ok(result)
}
}

impl<'frame, T> DeserializeValue<'frame> for VectorIterator<'frame, T>
@@ -828,6 +874,7 @@ where
{
type Item = Result<T, DeserializationError>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let raw = self.raw_iter.next()?.map_err(|err| {
mk_deser_err::<Self>(
@@ -883,6 +930,7 @@ where
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
.map_err(deser_error_replace_rust_name::<Self>)
}

ColumnType::Vector(_, _) => VectorIterator::<'frame, T>::deserialize(typ, v)
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
.map_err(deser_error_replace_rust_name::<Self>),
@@ -1435,6 +1483,7 @@ impl<'frame> VectorBytesSequenceIterator<'frame> {
impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> {
type Item = Result<Option<FrameSlice<'frame>>, LowLevelDeserializationError>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.remaining = self.remaining.checked_sub(1)?;
Some(self.slice.read_subslice(self.elem_len))