diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs index 4f8dc8beb..dc3953c8c 100644 --- a/scylla-cql/src/frame/mod.rs +++ b/scylla-cql/src/frame/mod.rs @@ -184,7 +184,7 @@ pub struct ResponseBodyWithExtensions { pub trace_id: Option, pub warnings: Vec, pub body: Bytes, - pub custom_payload: Option>>, + pub custom_payload: Option>, } pub fn parse_response_body_extensions( diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 70f28f6c2..2ea5a8b6b 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -3,6 +3,9 @@ use super::frame_errors::LowLevelDeserializationError; use super::TryFromPrimitiveError; use byteorder::{BigEndian, ReadBytesExt}; +use bytes::Bytes; +#[cfg(test)] +use bytes::BytesMut; use bytes::{Buf, BufMut}; use std::collections::HashMap; use std::convert::TryFrom; @@ -314,12 +317,12 @@ pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), std::num pub fn read_bytes_map( buf: &mut &[u8], -) -> Result>, LowLevelDeserializationError> { +) -> Result, LowLevelDeserializationError> { let len = read_short_length(buf)?; let mut v = HashMap::with_capacity(len); for _ in 0..len { let key = read_string(buf)?.to_owned(); - let val = read_bytes(buf)?.to_owned(); + let val = Bytes::copy_from_slice(read_bytes(buf)?); v.insert(key, val); } Ok(v) @@ -344,10 +347,10 @@ where #[test] fn type_bytes_map() { let mut val = HashMap::new(); - val.insert("".to_owned(), vec![]); - val.insert("EXTENSION1".to_owned(), vec![1, 2, 3]); - val.insert("EXTENSION2".to_owned(), vec![4, 5, 6]); - let mut buf = Vec::new(); + val.insert("".to_owned(), Bytes::new()); + val.insert("EXTENSION1".to_owned(), Bytes::from_static(&[1, 2, 3])); + val.insert("EXTENSION2".to_owned(), Bytes::from_static(&[4, 5, 6])); + let mut buf = BytesMut::new(); write_bytes_map(&val, &mut buf).unwrap(); assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val); } diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index 7610650c6..185286d2f 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -217,7 +217,7 @@ pub(crate) struct QueryResponse { pub(crate) tracing_id: Option, pub(crate) warnings: Vec, #[allow(dead_code)] // This is not exposed to user (yet?) - pub(crate) custom_payload: Option>>, + pub(crate) custom_payload: Option>, } // A QueryResponse in which response can not be Response::Error diff --git a/scylla/src/transport/locator/tablets.rs b/scylla/src/transport/locator/tablets.rs index 0f959059c..750818d65 100644 --- a/scylla/src/transport/locator/tablets.rs +++ b/scylla/src/transport/locator/tablets.rs @@ -1,8 +1,11 @@ +use bytes::Bytes; use itertools::Itertools; use lazy_static::lazy_static; -use scylla_cql::cql_to_rust::FromCqlVal; -use scylla_cql::frame::response::result::{deser_cql_value, ColumnType, TableSpec}; -use scylla_cql::types::deserialize::DeserializationError; +use scylla_cql::frame::response::result::{ColumnType, TableSpec}; +use scylla_cql::types::deserialize::value::ListlikeIterator; +use scylla_cql::types::deserialize::{ + DeserializationError, DeserializeValue, FrameSlice, TypeCheckError, +}; use thiserror::Error; use tracing::warn; use uuid::Uuid; @@ -11,12 +14,15 @@ use crate::routing::{Shard, Token}; use crate::transport::Node; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; #[derive(Error, Debug)] pub(crate) enum TabletParsingError { #[error(transparent)] Deserialization(#[from] DeserializationError), + #[error(transparent)] + TypeCheck(#[from] TypeCheckError), #[error("Shard id for tablet is negative: {0}")] ShardNum(i32), } @@ -35,7 +41,8 @@ pub(crate) struct RawTablet { replicas: RawTabletReplicas, } -type RawTabletPayload = (i64, i64, Vec<(Uuid, i32)>); +type RawTabletPayload<'frame, 'metadata> = + (i64, i64, ListlikeIterator<'frame, 'metadata, (Uuid, i32)>); lazy_static! { static ref RAW_TABLETS_CQL_TYPE: ColumnType<'static> = ColumnType::Tuple(vec![ @@ -52,29 +59,37 @@ const CUSTOM_PAYLOAD_TABLETS_V1_KEY: &str = "tablets-routing-v1"; impl RawTablet { pub(crate) fn from_custom_payload( - payload: &HashMap>, + payload: &HashMap, ) -> Option> { let payload = payload.get(CUSTOM_PAYLOAD_TABLETS_V1_KEY)?; - let cql_value = match deser_cql_value(&RAW_TABLETS_CQL_TYPE, &mut payload.as_slice()) { - Ok(r) => r, - Err(e) => return Some(Err(e.into())), + + if let Err(err) = + >::type_check(RAW_TABLETS_CQL_TYPE.deref()) + { + return Some(Err(err.into())); }; - // This could only fail if the type was wrong, but we do pass correct type - // to `deser_cql_value`. let (first_token, last_token, replicas): RawTabletPayload = - FromCqlVal::from_cql(cql_value).unwrap(); + match >::deserialize( + RAW_TABLETS_CQL_TYPE.deref(), + Some(FrameSlice::new(payload)), + ) { + Ok(tuple) => tuple, + Err(err) => return Some(Err(err.into())), + }; let replicas = match replicas - .into_iter() - .map(|(uuid, shard_num)| match shard_num.try_into() { - Ok(s) => Ok((uuid, s)), - Err(_) => Err(shard_num), + .map(|res| { + res.map_err(TabletParsingError::from) + .and_then(|(uuid, shard_num)| match shard_num.try_into() { + Ok(s) => Ok((uuid, s)), + Err(_) => Err(TabletParsingError::ShardNum(shard_num)), + }) }) - .collect::, _>>() + .collect::, TabletParsingError>>() { Ok(r) => r, - Err(shard_num) => return Some(Err(TabletParsingError::ShardNum(shard_num))), + Err(err) => return Some(Err(err)), }; Some(Ok(RawTablet { @@ -590,6 +605,7 @@ mod tests { use std::collections::{HashMap, HashSet}; use std::sync::Arc; + use bytes::Bytes; use scylla_cql::frame::response::result::{ColumnType, CqlValue, TableSpec}; use scylla_cql::types::serialize::value::SerializeValue; use scylla_cql::types::serialize::CellWriter; @@ -618,8 +634,10 @@ mod tests { #[test] fn test_raw_tablet_deser_trash() { - let custom_payload = - HashMap::from([(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), vec![1, 2, 3])]); + let custom_payload = HashMap::from([( + CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), + Bytes::from_static(&[1, 2, 3]), + )]); assert_matches::assert_matches!( RawTablet::from_custom_payload(&custom_payload), Some(Err(TabletParsingError::Deserialization(_))) @@ -648,7 +666,7 @@ mod tests { SerializeValue::serialize(&value, &col_type, CellWriter::new(&mut data)).unwrap(); debug!("{:?}", data); - custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), data); + custom_payload.insert(CUSTOM_PAYLOAD_TABLETS_V1_KEY.to_string(), Bytes::from(data)); assert_matches::assert_matches!( RawTablet::from_custom_payload(&custom_payload), @@ -688,7 +706,7 @@ mod tests { // Skipping length because `SerializeValue::serialize` adds length at the // start of serialized value while Scylla sends the value without initial // length. - data[4..].to_vec(), + Bytes::copy_from_slice(&data[4..]), ); let tablet = RawTablet::from_custom_payload(&custom_payload)