Skip to content

Commit

Permalink
serialization and actually hopefully working parser
Browse files Browse the repository at this point in the history
  • Loading branch information
smoczy123 committed Jan 7, 2025
1 parent bd92659 commit 9c92169
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 59 deletions.
6 changes: 3 additions & 3 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl ColumnType<'_> {
ColumnType::Tuple(_) => None,
ColumnType::Uuid => Some(16),
ColumnType::Varint => None,
ColumnType::Vector(elem_type, _) => None,
ColumnType::Vector(_, _) => None,
}
}
}
Expand Down Expand Up @@ -915,10 +915,10 @@ fn deser_type_generic<'frame, 'result, StrT: Into<Cow<'result, str>>>(
0x0000 => {
let type_str = read_string(buf).map_err(CqlTypeParseError::CustomTypeNameParseError)?;
let type_cow: Cow<'result, str> = type_str.into();
if let Ok(typ) = type_parser::TypeParser::parse(&type_cow) {
if let Ok(typ) = type_parser::TypeParser::parse(type_cow.clone()) {
typ
} else {
Ascii
Custom(type_cow)
}
}
0x0001 => Ascii,
Expand Down
110 changes: 68 additions & 42 deletions scylla-cql/src/frame/response/type_parser.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
use crate::frame::frame_errors::CqlTypeParseError;
use std::{borrow::Cow, char};
use std::{borrow::Cow, char, str::from_utf8};

use super::result::ColumnType;

/// Pieces of a parsed user-defined type, in the order produced by
/// `TypeParser::get_udt_parameters`: `(keyspace, type name, fields)`.
/// Each entry of `fields` is a `(field name, field type)` pair.
type UDTParameters<'result> = (
// Keyspace identifier.
Cow<'result, str>,
// Name of the user-defined type.
Cow<'result, str>,
// Fields, in the order they were parsed.
Vec<(Cow<'result, str>, ColumnType<'result>)>,
);

pub(crate) struct TypeParser<'result> {
pos: usize,
str: &'result str,
str: Cow<'result, str>,
}

impl<'result> TypeParser<'result> {
fn new(str: &str) -> TypeParser {
/// Creates a parser over `str` with the cursor placed at the first byte.
fn new(str: Cow<'result, str>) -> TypeParser<'result> {
    let pos = 0;
    TypeParser { str, pos }
}

pub(crate) fn parse(str: &'result str) -> Result<ColumnType<'result>, CqlTypeParseError> {
/// Parses a type description string into a [`ColumnType`].
///
/// Builds a fresh parser over `str` and runs it from the start.
///
/// # Errors
/// Propagates any [`CqlTypeParseError`] reported by the parser.
pub(crate) fn parse(str: Cow<'result, str>) -> Result<ColumnType<'result>, CqlTypeParseError> {
    TypeParser::new(str).do_parse()
}
Expand All @@ -30,14 +36,17 @@ impl<'result> TypeParser<'result> {
c.is_alphanumeric() || c == '+' || c == '-' || c == '_' || c == '.' || c == '&'
}

fn read_next_identifier(&mut self) -> &'result str {
/// Consumes the longest run of identifier characters starting at the current
/// position and returns it. The result is empty when the cursor is at the end
/// of the input or does not point at an identifier character.
///
/// The returned value borrows from the input when the parser holds a borrowed
/// string and is copied when the parser owns its string.
fn read_next_identifier(&mut self) -> Cow<'result, str> {
    let start = self.pos;
    while !self.is_eos() {
        let byte = self.str.as_bytes()[self.pos];
        // The input is scanned byte by byte. A non-ASCII byte is part of a
        // multi-byte UTF-8 sequence, and casting it to `char` would turn it
        // into an unrelated Latin-1 character (e.g. 0xC3 -> 'Ã') that may
        // pass the identifier test. Stopping in the middle of such a sequence
        // makes the slice below panic, so only ASCII bytes are accepted.
        if !byte.is_ascii() || !TypeParser::is_identifier_char(byte as char) {
            break;
        }
        self.pos += 1;
    }
    match &self.str {
        Cow::Borrowed(s) => Cow::Borrowed(&s[start..self.pos]),
        Cow::Owned(s) => Cow::Owned(s[start..self.pos].to_owned()),
    }
}

fn skip_blank(&mut self) -> usize {
Expand All @@ -62,20 +71,22 @@ impl<'result> TypeParser<'result> {
}
self.pos += 1;
}
return false;
false
}

fn get_simple_abstract_type(name: &str) -> Result<ColumnType, CqlTypeParseError> {
fn get_simple_abstract_type(
name: Cow<'result, str>,
) -> Result<ColumnType<'result>, CqlTypeParseError> {
let string_class_name: String;
let class_name: &str;
let class_name: Cow<'result, str>;
if name.contains("org.apache.cassandra.db.marshal.") {
class_name = name
} else {
string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + name;
class_name = &string_class_name;
string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + &name;
class_name = Cow::Owned(string_class_name);
}

match class_name {
match class_name.as_ref() {
"org.apache.cassandra.db.marshal.AsciiType" => Ok(ColumnType::Ascii),
"org.apache.cassandra.db.marshal.BooleanType" => Ok(ColumnType::Boolean),
"org.apache.cassandra.db.marshal.BytesType" => Ok(ColumnType::Blob),
Expand Down Expand Up @@ -108,15 +119,15 @@ impl<'result> TypeParser<'result> {
if self.is_eos() {
return Ok(parameters);
}
if self.str.as_bytes()[self.pos] != '(' as u8 {
if self.str.as_bytes()[self.pos] != b'(' {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
self.pos += 1;
loop {
if !self.skip_blank_and_comma() {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
if self.str.as_bytes()[self.pos] == ')' as u8 {
if self.str.as_bytes()[self.pos] == b')' {
self.pos += 1;
return Ok(parameters);
}
Expand All @@ -126,12 +137,12 @@ impl<'result> TypeParser<'result> {
}

fn get_vector_parameters(&mut self) -> Result<(ColumnType<'result>, u32), CqlTypeParseError> {
if self.is_eos() || self.str.as_bytes()[self.pos] != '(' as u8 {
if self.is_eos() || self.str.as_bytes()[self.pos] != b'(' {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
self.pos += 1;
self.skip_blank_and_comma();
if self.str.as_bytes()[self.pos] == ')' as u8 {
if self.str.as_bytes()[self.pos] == b')' {
return Err(CqlTypeParseError::AbstractTypeParseError());
}

Expand All @@ -140,21 +151,22 @@ impl<'result> TypeParser<'result> {
while !self.is_eos() && char::is_numeric(self.str.as_bytes()[self.pos] as char) {
self.pos += 1;
}
let len = u32::from_str_radix(&self.str[start..self.pos], 10)
let len = self.str[start..self.pos]
.parse::<u32>()
.map_err(|_| CqlTypeParseError::AbstractTypeParseError())?;
if self.is_eos() || self.str.as_bytes()[self.pos] != ')' as u8 {
if self.is_eos() || self.str.as_bytes()[self.pos] != b')' {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
self.pos += 1;
Ok((typ, len))
}

fn from_hex(s: &str) -> Result<Vec<u8>, CqlTypeParseError> {
fn from_hex(s: Cow<'result, str>) -> Result<Vec<u8>, CqlTypeParseError> {
if s.len() % 2 != 0 {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
for c in s.chars() {
if !c.is_digit(16) {
if !c.is_ascii_hexdigit() {
return Err(CqlTypeParseError::AbstractTypeParseError());
}
}
Expand All @@ -167,32 +179,46 @@ impl<'result> TypeParser<'result> {
Ok(bytes)
}

fn get_udt_parameters(
&mut self,
) -> Result<
(
Cow<'result, str>,
Cow<'result, str>,
Vec<(Cow<'result, str>, ColumnType<'result>)>,
),
CqlTypeParseError,
> {
unimplemented!("get_udt_parameters");
/// Parses the parenthesised parameter list of a user-defined type.
///
/// Expected layout (Cassandra `UserType` string form):
/// `(<keyspace>, <hex type name>, <hex field name>:<field type>, ...)`
///
/// The type name and every field name are hex-encoded UTF-8 and are decoded
/// before being returned. Each field type is parsed recursively.
///
/// # Errors
/// Returns `CqlTypeParseError::AbstractTypeParseError` when the list does not
/// start with `(`, is not terminated by `)`, a name is not valid hex-encoded
/// UTF-8, or a field name is not followed by `:`. Errors from parsing a field
/// type are propagated.
fn get_udt_parameters(&mut self) -> Result<UDTParameters<'result>, CqlTypeParseError> {
    if self.is_eos() || self.str.as_bytes()[self.pos] != b'(' {
        return Err(CqlTypeParseError::AbstractTypeParseError());
    }
    self.pos += 1;

    self.skip_blank_and_comma();
    let keyspace = self.read_next_identifier();
    self.skip_blank_and_comma();
    let name_bytes = TypeParser::from_hex(self.read_next_identifier())?;
    let name = from_utf8(&name_bytes)
        .map_err(|_| CqlTypeParseError::AbstractTypeParseError())?
        .to_owned();

    let mut fields = Vec::new();
    loop {
        // A `false` result means there is no further token, i.e. the closing
        // parenthesis is missing.
        if !self.skip_blank_and_comma() {
            return Err(CqlTypeParseError::AbstractTypeParseError());
        }
        if self.str.as_bytes()[self.pos] == b')' {
            self.pos += 1;
            return Ok((keyspace, Cow::Owned(name), fields));
        }

        // Field names are hex-encoded just like the type name.
        let field_name_bytes = TypeParser::from_hex(self.read_next_identifier())?;
        let field_name = from_utf8(&field_name_bytes)
            .map_err(|_| CqlTypeParseError::AbstractTypeParseError())?
            .to_owned();

        // The field name and its type are separated by ':'. The separator is
        // not an identifier character, so it must be consumed explicitly
        // before the type can be parsed.
        self.skip_blank();
        if self.is_eos() || self.str.as_bytes()[self.pos] != b':' {
            return Err(CqlTypeParseError::AbstractTypeParseError());
        }
        self.pos += 1;
        self.skip_blank();

        let field_type = self.do_parse()?;
        fields.push((Cow::Owned(field_name), field_type));
    }
}

fn get_complex_abstract_type(
&mut self,
name: &str,
name: Cow<'result, str>,
) -> Result<ColumnType<'result>, CqlTypeParseError> {
let string_class_name: String;
let class_name: &str;
let class_name: Cow<'result, str>;
if name.contains("org.apache.cassandra.db.marshal.") {
class_name = name
} else {
string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + name;
class_name = &string_class_name;
string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + &name;
class_name = Cow::Owned(string_class_name);
}
match class_name {
match class_name.as_ref() {
"org.apache.cassandra.db.marshal.ListType" => {
let mut params = self.get_type_parameters()?;
if params.len() != 1 {
Expand Down Expand Up @@ -236,7 +262,7 @@ impl<'result> TypeParser<'result> {
field_types: fields,
})
}
_ => return Err(CqlTypeParseError::AbstractTypeParseError()),
_ => Err(CqlTypeParseError::AbstractTypeParseError()),
}
}

Expand All @@ -251,17 +277,17 @@ impl<'result> TypeParser<'result> {
return Ok(ColumnType::Blob);
}

if self.str.as_bytes()[self.pos] == ':' as u8 {
if self.str.as_bytes()[self.pos] == b':' {
self.pos += 1;
let _ = usize::from_str_radix(name, 16)
let _ = usize::from_str_radix(&name, 16)
.map_err(|_| CqlTypeParseError::AbstractTypeParseError());
name = self.read_next_identifier();
}
self.skip_blank();
if !self.is_eos() && self.str.as_bytes()[self.pos] == '(' as u8 {
return Ok(self.get_complex_abstract_type(name)?);
if !self.is_eos() && self.str.as_bytes()[self.pos] == b'(' {
self.get_complex_abstract_type(name)
} else {
return Ok(TypeParser::get_simple_abstract_type(name)?);
TypeParser::get_simple_abstract_type(name)
}
}
}
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ mod legacy {
CqlValue::Map(m) => serialize_map(m.iter().map(|p| (&p.0, &p.1)), m.len(), buf),
CqlValue::Tuple(t) => serialize_tuple(t.iter(), buf),

CqlValue::Vector(v) => {
CqlValue::Vector(_) => {
unimplemented!("Vector serialization is not implemented yet");
}

Expand Down
2 changes: 0 additions & 2 deletions scylla-cql/src/types/deserialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,6 @@ where
pub struct ConstLengthVectorIterator<'frame, 'metadata, T> {
coll_typ: &'metadata ColumnType<'metadata>,
elem_typ: &'metadata ColumnType<'metadata>,
count: usize,
raw_iter: VectorBytesSequenceIterator<'frame>,
phantom_data: std::marker::PhantomData<T>,
}
Expand All @@ -933,7 +932,6 @@ impl<'frame, 'metadata, T> ConstLengthVectorIterator<'frame, 'metadata, T> {
Self {
coll_typ,
elem_typ,
count,
raw_iter: VectorBytesSequenceIterator::new(count, elem_len, slice),
phantom_data: std::marker::PhantomData,
}
Expand Down
11 changes: 11 additions & 0 deletions scylla-cql/src/types/deserialize/value_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::fmt::Debug;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

use crate::frame::response::result::{ColumnType, CqlValue};
use crate::frame::response::type_parser::TypeParser;
use crate::frame::value::{
Counter, CqlDate, CqlDecimal, CqlDecimalBorrowed, CqlDuration, CqlTime, CqlTimestamp,
CqlTimeuuid, CqlVarint, CqlVarintBorrowed,
Expand All @@ -26,6 +27,16 @@ use super::{
UdtTypeCheckErrorKind,
};

/// Checks that a Cassandra `VectorType` class string is parsed into a
/// fixed-length vector column type.
#[test]
fn test_cassandra_type_parser() {
    // `Int32Type` is the Cassandra marshal class for the CQL `int` type;
    // there is no marshal class called `Integer`.
    let type_name =
        "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 5)";
    assert_eq!(
        TypeParser::parse(Cow::Borrowed(type_name)).unwrap(),
        ColumnType::Vector(Box::new(ColumnType::Int), 5)
    );
}

#[test]
fn test_deserialize_bytes() {
const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3];
Expand Down
Loading

0 comments on commit 9c92169

Please sign in to comment.