Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate tablets to new deserialization framework #1120

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ pub struct ResponseBodyWithExtensions {
pub trace_id: Option<Uuid>,
pub warnings: Vec<String>,
pub body: Bytes,
pub custom_payload: Option<HashMap<String, Vec<u8>>>,
pub custom_payload: Option<HashMap<String, Bytes>>,
}

pub fn parse_response_body_extensions(
Expand Down
15 changes: 9 additions & 6 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
use super::frame_errors::LowLevelDeserializationError;
use super::TryFromPrimitiveError;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::Bytes;
#[cfg(test)]
use bytes::BytesMut;
use bytes::{Buf, BufMut};
use std::collections::HashMap;
use std::convert::TryFrom;
Expand Down Expand Up @@ -314,12 +317,12 @@ pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num

pub fn read_bytes_map(
buf: &mut &[u8],
) -> Result<HashMap<String, Vec<u8>>, LowLevelDeserializationError> {
) -> Result<HashMap<String, Bytes>, LowLevelDeserializationError> {
let len = read_short_length(buf)?;
let mut v = HashMap::with_capacity(len);
for _ in 0..len {
let key = read_string(buf)?.to_owned();
let val = read_bytes(buf)?.to_owned();
let val = Bytes::copy_from_slice(read_bytes(buf)?);
v.insert(key, val);
}
Ok(v)
Expand All @@ -344,10 +347,10 @@ where
#[test]
fn type_bytes_map() {
let mut val = HashMap::new();
val.insert("".to_owned(), vec![]);
val.insert("EXTENSION1".to_owned(), vec![1, 2, 3]);
val.insert("EXTENSION2".to_owned(), vec![4, 5, 6]);
let mut buf = Vec::new();
val.insert("".to_owned(), Bytes::new());
val.insert("EXTENSION1".to_owned(), Bytes::from_static(&[1, 2, 3]));
val.insert("EXTENSION2".to_owned(), Bytes::from_static(&[4, 5, 6]));
let mut buf = BytesMut::new();
write_bytes_map(&val, &mut buf).unwrap();
assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val);
}
Expand Down
2 changes: 1 addition & 1 deletion scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ pub(crate) struct QueryResponse {
pub(crate) tracing_id: Option<Uuid>,
pub(crate) warnings: Vec<String>,
#[allow(dead_code)] // This is not exposed to user (yet?)
pub(crate) custom_payload: Option<HashMap<String, Vec<u8>>>,
pub(crate) custom_payload: Option<HashMap<String, Bytes>>,
}

// A QueryResponse in which response can not be Response::Error
Expand Down
60 changes: 39 additions & 21 deletions scylla/src/transport/locator/tablets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use bytes::Bytes;
use itertools::Itertools;
use lazy_static::lazy_static;
use scylla_cql::cql_to_rust::FromCqlVal;
use scylla_cql::frame::response::result::{deser_cql_value, ColumnType, TableSpec};
use scylla_cql::types::deserialize::DeserializationError;
use scylla_cql::frame::response::result::{ColumnType, TableSpec};
use scylla_cql::types::deserialize::value::ListlikeIterator;
use scylla_cql::types::deserialize::{
DeserializationError, DeserializeValue, FrameSlice, TypeCheckError,
};
use thiserror::Error;
use tracing::warn;
use uuid::Uuid;
Expand All @@ -11,12 +14,15 @@ use crate::routing::{Shard, Token};
use crate::transport::Node;

use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc;

#[derive(Error, Debug)]
pub(crate) enum TabletParsingError {
#[error(transparent)]
Deserialization(#[from] DeserializationError),
#[error(transparent)]
TypeCheck(#[from] TypeCheckError),
#[error("Shard id for tablet is negative: {0}")]
ShardNum(i32),
}
Expand All @@ -35,7 +41,8 @@ pub(crate) struct RawTablet {
replicas: RawTabletReplicas,
}

type RawTabletPayload = (i64, i64, Vec<(Uuid, i32)>);
type RawTabletPayload<'frame, 'metadata> =
(i64, i64, ListlikeIterator<'frame, 'metadata, (Uuid, i32)>);

lazy_static! {
static ref RAW_TABLETS_CQL_TYPE: ColumnType<'static> = ColumnType::Tuple(vec![
Expand All @@ -52,29 +59,37 @@ const CUSTOM_PAYLOAD_TABLETS_V1_KEY: &str = "tablets-routing-v1";

impl RawTablet {
pub(crate) fn from_custom_payload(
payload: &HashMap<String, Vec<u8>>,
payload: &HashMap<String, Bytes>,
) -> Option<Result<RawTablet, TabletParsingError>> {
let payload = payload.get(CUSTOM_PAYLOAD_TABLETS_V1_KEY)?;
let cql_value = match deser_cql_value(&RAW_TABLETS_CQL_TYPE, &mut payload.as_slice()) {
Ok(r) => r,
Err(e) => return Some(Err(e.into())),

if let Err(err) =
<RawTabletPayload as DeserializeValue<'_, '_>>::type_check(RAW_TABLETS_CQL_TYPE.deref())
{
return Some(Err(err.into()));
};

// This could only fail if the type was wrong, but we do pass correct type
// to `deser_cql_value`.
let (first_token, last_token, replicas): RawTabletPayload =
FromCqlVal::from_cql(cql_value).unwrap();
match <RawTabletPayload as DeserializeValue<'_, '_>>::deserialize(
RAW_TABLETS_CQL_TYPE.deref(),
Some(FrameSlice::new(payload)),
) {
Ok(tuple) => tuple,
Err(err) => return Some(Err(err.into())),
};

let replicas = match replicas
.into_iter()
.map(|(uuid, shard_num)| match shard_num.try_into() {
Ok(s) => Ok((uuid, s)),
Err(_) => Err(shard_num),
.map(|res| {
res.map_err(TabletParsingError::from)
.and_then(|(uuid, shard_num)| match shard_num.try_into() {
Ok(s) => Ok((uuid, s)),
Err(_) => Err(TabletParsingError::ShardNum(shard_num)),
})
})
.collect::<Result<Vec<(Uuid, Shard)>, _>>()
.collect::<Result<Vec<(Uuid, Shard)>, TabletParsingError>>()
{
Ok(r) => r,
Err(shard_num) => return Some(Err(TabletParsingError::ShardNum(shard_num))),
Err(err) => return Some(Err(err)),
};

Some(Ok(RawTablet {
Expand Down Expand Up @@ -590,6 +605,7 @@ mod tests {
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use bytes::Bytes;
use scylla_cql::frame::response::result::{ColumnType, CqlValue, TableSpec};
use scylla_cql::types::serialize::value::SerializeValue;
use scylla_cql::types::serialize::CellWriter;
Expand Down Expand Up @@ -618,8 +634,10 @@ mod tests {

#[test]
fn test_raw_tablet_deser_trash() {
let custom_payload =
HashMap::from([(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), vec![1, 2, 3])]);
let custom_payload = HashMap::from([(
CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(),
Bytes::from_static(&[1, 2, 3]),
)]);
assert_matches::assert_matches!(
RawTablet::from_custom_payload(&custom_payload),
Some(Err(TabletParsingError::Deserialization(_)))
Expand Down Expand Up @@ -648,7 +666,7 @@ mod tests {
SerializeValue::serialize(&value, &col_type, CellWriter::new(&mut data)).unwrap();
debug!("{:?}", data);

custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), data);
custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), Bytes::from(data));

assert_matches::assert_matches!(
RawTablet::from_custom_payload(&custom_payload),
Expand Down Expand Up @@ -688,7 +706,7 @@ mod tests {
// Skipping length because `SerializeValue::serialize` adds length at the
// start of serialized value while Scylla sends the value without initial
// length.
data[4..].to_vec(),
Bytes::copy_from_slice(&data[4..]),
);

let tablet = RawTablet::from_custom_payload(&custom_payload)
Expand Down
Loading