Skip to content

Commit

Permalink
Merge pull request #1120 from wprzytula/migrate-tablets-to-new-deseri…
Browse files Browse the repository at this point in the history
…alization-framework

Migrate tablets to new deserialization framework
  • Loading branch information
Lorak-mmk authored Nov 12, 2024
2 parents e99d697 + 453602c commit 00f39a4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 29 deletions.
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ pub struct ResponseBodyWithExtensions {
pub trace_id: Option<Uuid>,
pub warnings: Vec<String>,
pub body: Bytes,
pub custom_payload: Option<HashMap<String, Vec<u8>>>,
pub custom_payload: Option<HashMap<String, Bytes>>,
}

pub fn parse_response_body_extensions(
Expand Down
15 changes: 9 additions & 6 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
use super::frame_errors::LowLevelDeserializationError;
use super::TryFromPrimitiveError;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::Bytes;
#[cfg(test)]
use bytes::BytesMut;
use bytes::{Buf, BufMut};
use std::collections::HashMap;
use std::convert::TryFrom;
Expand Down Expand Up @@ -314,12 +317,12 @@ pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num

pub fn read_bytes_map(
buf: &mut &[u8],
) -> Result<HashMap<String, Vec<u8>>, LowLevelDeserializationError> {
) -> Result<HashMap<String, Bytes>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = HashMap::with_capacity(len);
for _ in 0..len {
let key = read_string(buf)?.to_owned();
let val = read_bytes(buf)?.to_owned();
let val = Bytes::copy_from_slice(read_bytes(buf)?);
v.insert(key, val);
}
Ok(v)
Expand All @@ -344,10 +347,10 @@ where
#[test]
fn type_bytes_map() {
let mut val = HashMap::new();
val.insert("".to_owned(), vec![]);
val.insert("EXTENSION1".to_owned(), vec![1, 2, 3]);
val.insert("EXTENSION2".to_owned(), vec![4, 5, 6]);
let mut buf = Vec::new();
val.insert("".to_owned(), Bytes::new());
val.insert("EXTENSION1".to_owned(), Bytes::from_static(&[1, 2, 3]));
val.insert("EXTENSION2".to_owned(), Bytes::from_static(&[4, 5, 6]));
let mut buf = BytesMut::new();
write_bytes_map(&val, &mut buf).unwrap();
assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val);
}
Expand Down
2 changes: 1 addition & 1 deletion scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ pub(crate) struct QueryResponse {
pub(crate) tracing_id: Option<Uuid>,
pub(crate) warnings: Vec<String>,
#[allow(dead_code)] // This is not exposed to user (yet?)
pub(crate) custom_payload: Option<HashMap<String, Vec<u8>>>,
pub(crate) custom_payload: Option<HashMap<String, Bytes>>,
}

// A QueryResponse in which response can not be Response::Error
Expand Down
60 changes: 39 additions & 21 deletions scylla/src/transport/locator/tablets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use bytes::Bytes;
use itertools::Itertools;
use lazy_static::lazy_static;
use scylla_cql::cql_to_rust::FromCqlVal;
use scylla_cql::frame::response::result::{deser_cql_value, ColumnType, TableSpec};
use scylla_cql::types::deserialize::DeserializationError;
use scylla_cql::frame::response::result::{ColumnType, TableSpec};
use scylla_cql::types::deserialize::value::ListlikeIterator;
use scylla_cql::types::deserialize::{
DeserializationError, DeserializeValue, FrameSlice, TypeCheckError,
};
use thiserror::Error;
use tracing::warn;
use uuid::Uuid;
Expand All @@ -11,12 +14,15 @@ use crate::routing::{Shard, Token};
use crate::transport::Node;

use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc;

#[derive(Error, Debug)]
pub(crate) enum TabletParsingError {
#[error(transparent)]
Deserialization(#[from] DeserializationError),
#[error(transparent)]
TypeCheck(#[from] TypeCheckError),
#[error("Shard id for tablet is negative: {0}")]
ShardNum(i32),
}
Expand All @@ -35,7 +41,8 @@ pub(crate) struct RawTablet {
replicas: RawTabletReplicas,
}

type RawTabletPayload = (i64, i64, Vec<(Uuid, i32)>);
type RawTabletPayload<'frame, 'metadata> =
(i64, i64, ListlikeIterator<'frame, 'metadata, (Uuid, i32)>);

lazy_static! {
static ref RAW_TABLETS_CQL_TYPE: ColumnType<'static> = ColumnType::Tuple(vec![
Expand All @@ -52,29 +59,37 @@ const CUSTOM_PAYLOAD_TABLETS_V1_KEY: &str = "tablets-routing-v1";

impl RawTablet {
pub(crate) fn from_custom_payload(
payload: &HashMap<String, Vec<u8>>,
payload: &HashMap<String, Bytes>,
) -> Option<Result<RawTablet, TabletParsingError>> {
let payload = payload.get(CUSTOM_PAYLOAD_TABLETS_V1_KEY)?;
let cql_value = match deser_cql_value(&RAW_TABLETS_CQL_TYPE, &mut payload.as_slice()) {
Ok(r) => r,
Err(e) => return Some(Err(e.into())),

if let Err(err) =
<RawTabletPayload as DeserializeValue<'_, '_>>::type_check(RAW_TABLETS_CQL_TYPE.deref())
{
return Some(Err(err.into()));
};

// This could only fail if the type was wrong, but we do pass correct type
// to `deser_cql_value`.
let (first_token, last_token, replicas): RawTabletPayload =
FromCqlVal::from_cql(cql_value).unwrap();
match <RawTabletPayload as DeserializeValue<'_, '_>>::deserialize(
RAW_TABLETS_CQL_TYPE.deref(),
Some(FrameSlice::new(payload)),
) {
Ok(tuple) => tuple,
Err(err) => return Some(Err(err.into())),
};

let replicas = match replicas
.into_iter()
.map(|(uuid, shard_num)| match shard_num.try_into() {
Ok(s) => Ok((uuid, s)),
Err(_) => Err(shard_num),
.map(|res| {
res.map_err(TabletParsingError::from)
.and_then(|(uuid, shard_num)| match shard_num.try_into() {
Ok(s) => Ok((uuid, s)),
Err(_) => Err(TabletParsingError::ShardNum(shard_num)),
})
})
.collect::<Result<Vec<(Uuid, Shard)>, _>>()
.collect::<Result<Vec<(Uuid, Shard)>, TabletParsingError>>()
{
Ok(r) => r,
Err(shard_num) => return Some(Err(TabletParsingError::ShardNum(shard_num))),
Err(err) => return Some(Err(err)),
};

Some(Ok(RawTablet {
Expand Down Expand Up @@ -590,6 +605,7 @@ mod tests {
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use bytes::Bytes;
use scylla_cql::frame::response::result::{ColumnType, CqlValue, TableSpec};
use scylla_cql::types::serialize::value::SerializeValue;
use scylla_cql::types::serialize::CellWriter;
Expand Down Expand Up @@ -618,8 +634,10 @@ mod tests {

#[test]
fn test_raw_tablet_deser_trash() {
let custom_payload =
HashMap::from([(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), vec![1, 2, 3])]);
let custom_payload = HashMap::from([(
CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(),
Bytes::from_static(&[1, 2, 3]),
)]);
assert_matches::assert_matches!(
RawTablet::from_custom_payload(&custom_payload),
Some(Err(TabletParsingError::Deserialization(_)))
Expand Down Expand Up @@ -648,7 +666,7 @@ mod tests {
SerializeValue::serialize(&value, &col_type, CellWriter::new(&mut data)).unwrap();
debug!("{:?}", data);

custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), data);
custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), Bytes::from(data));

assert_matches::assert_matches!(
RawTablet::from_custom_payload(&custom_payload),
Expand Down Expand Up @@ -688,7 +706,7 @@ mod tests {
// Skipping length because `SerializeValue::serialize` adds length at the
// start of serialized value while Scylla sends the value without initial
// length.
data[4..].to_vec(),
Bytes::copy_from_slice(&data[4..]),
);

let tablet = RawTablet::from_custom_payload(&custom_payload)
Expand Down

0 comments on commit 00f39a4

Please sign in to comment.