Skip to content

Commit

Permalink
Merge pull request #871 from piodul/serialize-trait-genericide
Browse files Browse the repository at this point in the history
serialize: make new serialization traits object-safe
  • Loading branch information
piodul authored Dec 9, 2023
2 parents d7fa0ce + 5733a2f commit 6c9ec8f
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 836 deletions.
14 changes: 5 additions & 9 deletions scylla-cql/src/frame/value_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::frame::{response::result::CqlValue, types::RawValue, value::BatchValuesIterator};
use crate::types::serialize::row::{RowSerializationContext, SerializeRow};
use crate::types::serialize::value::SerializeCql;
use crate::types::serialize::{BufBackedCellWriter, BufBackedRowWriter};
use crate::types::serialize::{CellWriter, RowWriter};

use super::response::result::{ColumnSpec, ColumnType, TableSpec};
use super::value::{
Expand All @@ -24,10 +24,8 @@ where
let mut result: Vec<u8> = Vec::new();
Value::serialize(&val, &mut result).unwrap();

T::preliminary_type_check(&typ).unwrap();

let mut new_result: Vec<u8> = Vec::new();
let writer = BufBackedCellWriter::new(&mut new_result);
let writer = CellWriter::new(&mut new_result);
SerializeCql::serialize(&val, &typ, writer).unwrap();

assert_eq!(result, new_result);
Expand All @@ -37,7 +35,7 @@ where

fn serialized_only_new<T: SerializeCql>(val: T, typ: ColumnType) -> Vec<u8> {
let mut result: Vec<u8> = Vec::new();
let writer = BufBackedCellWriter::new(&mut result);
let writer = CellWriter::new(&mut result);
SerializeCql::serialize(&val, &typ, writer).unwrap();
result
}
Expand Down Expand Up @@ -995,9 +993,8 @@ fn serialize_values<T: ValueList + SerializeRow>(
serialized.write_to_request(&mut old_serialized);

let ctx = RowSerializationContext { columns };
<T as SerializeRow>::preliminary_type_check(&ctx).unwrap();
let mut new_serialized = vec![0, 0];
let mut writer = BufBackedRowWriter::new(&mut new_serialized);
let mut writer = RowWriter::new(&mut new_serialized);
<T as SerializeRow>::serialize(&vl, &ctx, &mut writer).unwrap();
let value_count: u16 = writer.value_count().try_into().unwrap();
let is_empty = writer.value_count() == 0;
Expand All @@ -1014,9 +1011,8 @@ fn serialize_values<T: ValueList + SerializeRow>(

fn serialize_values_only_new<T: SerializeRow>(vl: T, columns: &[ColumnSpec]) -> Vec<u8> {
let ctx = RowSerializationContext { columns };
<T as SerializeRow>::preliminary_type_check(&ctx).unwrap();
let mut serialized = vec![0, 0];
let mut writer = BufBackedRowWriter::new(&mut serialized);
let mut writer = RowWriter::new(&mut serialized);
<T as SerializeRow>::serialize(&vl, &ctx, &mut writer).unwrap();
let value_count: u16 = writer.value_count().try_into().unwrap();
let is_empty = writer.value_count() == 0;
Expand Down
5 changes: 1 addition & 4 deletions scylla-cql/src/types/serialize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ pub mod row;
pub mod value;
pub mod writers;

pub use writers::{
BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder,
CellWriter, CountingCellWriter, RowWriter,
};
pub use writers::{CellValueBuilder, CellWriter, RowWriter};
#[derive(Debug, Clone, Error)]
pub struct SerializationError(Arc<dyn Error + Send + Sync>);

Expand Down
184 changes: 58 additions & 126 deletions scylla-cql/src/types/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::frame::value::{SerializedValues, ValueList};
use crate::frame::{response::result::ColumnSpec, types::RawValue};

use super::value::SerializeCql;
use super::{CellWriter, RowWriter, SerializationError};
use super::{RowWriter, SerializationError};

/// Contains information needed to serialize a row.
pub struct RowSerializationContext<'a> {
Expand All @@ -33,41 +33,22 @@ impl<'a> RowSerializationContext<'a> {
}

pub trait SerializeRow {
/// Checks if it _might_ be possible to serialize the row according to the
/// information in the context.
///
/// This function is intended to serve as an optimization in the future,
/// if we were ever to introduce prepared statements parametrized by types.
///
/// Sometimes, a row cannot be fully type checked right away without knowing
/// the exact values of the columns (e.g. when deserializing to `CqlValue`),
/// but it's fine to do full type checking later in `serialize`.
fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>;

/// Serializes the row according to the information in the given context.
///
/// The function may assume that `preliminary_type_check` was called,
/// though it must not do anything unsafe if this assumption does not hold.
fn serialize<W: RowWriter>(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError>;

fn is_empty(&self) -> bool;
}

macro_rules! fallback_impl_contents {
() => {
fn preliminary_type_check(
_ctx: &RowSerializationContext<'_>,
) -> Result<(), SerializationError> {
Ok(())
}
fn serialize<W: RowWriter>(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
serialize_legacy_row(self, ctx, writer)
}
Expand All @@ -80,8 +61,10 @@ macro_rules! fallback_impl_contents {

macro_rules! impl_serialize_row_for_unit {
() => {
fn preliminary_type_check(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
_writer: &mut RowWriter,
) -> Result<(), SerializationError> {
if !ctx.columns().is_empty() {
return Err(mk_typck_err::<Self>(
Expand All @@ -91,14 +74,6 @@ macro_rules! impl_serialize_row_for_unit {
},
));
}
Ok(())
}

fn serialize<W: RowWriter>(
&self,
_ctx: &RowSerializationContext<'_>,
_writer: &mut W,
) -> Result<(), SerializationError> {
// Row is empty - do nothing
Ok(())
}
Expand All @@ -120,26 +95,10 @@ impl SerializeRow for [u8; 0] {

macro_rules! impl_serialize_row_for_slice {
() => {
fn preliminary_type_check(
ctx: &RowSerializationContext<'_>,
) -> Result<(), SerializationError> {
// While we don't know how many columns will be there during serialization,
// we can at least check that all provided columns match T.
for col in ctx.columns() {
<T as SerializeCql>::preliminary_type_check(&col.typ).map_err(|err| {
mk_typck_err::<Self>(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
name: col.name.clone(),
err,
})
})?;
}
Ok(())
}

fn serialize<W: RowWriter>(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
if ctx.columns().len() != self.len() {
return Err(mk_typck_err::<Self>(
Expand Down Expand Up @@ -181,26 +140,10 @@ impl<T: SerializeCql> SerializeRow for Vec<T> {

macro_rules! impl_serialize_row_for_map {
() => {
fn preliminary_type_check(
ctx: &RowSerializationContext<'_>,
) -> Result<(), SerializationError> {
// While we don't know the column count or their names,
// we can go over all columns and check that their types match T.
for col in ctx.columns() {
<T as SerializeCql>::preliminary_type_check(&col.typ).map_err(|err| {
mk_typck_err::<Self>(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
name: col.name.clone(),
err,
})
})?;
}
Ok(())
}

fn serialize<W: RowWriter>(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
// Unfortunately, column names aren't guaranteed to be unique.
// We need to track not-yet-used columns in order to see
Expand All @@ -219,8 +162,8 @@ macro_rules! impl_serialize_row_for_map {
Some(v) => {
<T as SerializeCql>::serialize(v, &col.typ, writer.make_cell_writer())
.map_err(|err| {
mk_typck_err::<Self>(
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
mk_ser_err::<Self>(
BuiltinSerializationErrorKind::ColumnSerializationFailed {
name: col.name.clone(),
err,
},
Expand Down Expand Up @@ -267,15 +210,11 @@ impl<T: SerializeCql, S: BuildHasher> SerializeRow for HashMap<&str, T, S> {
impl_serialize_row_for_map!();
}

impl<T: SerializeRow> SerializeRow for &T {
fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError> {
<T as SerializeRow>::preliminary_type_check(ctx)
}

fn serialize<W: RowWriter>(
impl<T: SerializeRow + ?Sized> SerializeRow for &T {
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
<T as SerializeRow>::serialize(self, ctx, writer)
}
Expand All @@ -302,34 +241,10 @@ macro_rules! impl_tuple {
$length:expr
) => {
impl<$($typs: SerializeCql),*> SerializeRow for ($($typs,)*) {
fn preliminary_type_check(
ctx: &RowSerializationContext<'_>,
) -> Result<(), SerializationError> {
match ctx.columns() {
[$($tidents),*] => {
$(
<$typs as SerializeCql>::preliminary_type_check(&$tidents.typ).map_err(|err| {
mk_typck_err::<Self>(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
name: $tidents.name.clone(),
err,
})
})?;
)*
}
_ => return Err(mk_typck_err::<Self>(
BuiltinTypeCheckErrorKind::WrongColumnCount {
actual: $length,
asked_for: ctx.columns().len(),
},
)),
};
Ok(())
}

fn serialize<W: RowWriter>(
fn serialize(
&self,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
let ($($tidents,)*) = match ctx.columns() {
[$($tidents),*] => ($($tidents,)*),
Expand Down Expand Up @@ -445,17 +360,10 @@ macro_rules! impl_serialize_row_via_value_list {
where
Self: $crate::frame::value::ValueList,
{
fn preliminary_type_check(
_ctx: &$crate::types::serialize::row::RowSerializationContext<'_>,
) -> ::std::result::Result<(), $crate::types::serialize::SerializationError> {
// No-op - the old interface didn't offer type safety
::std::result::Result::Ok(())
}

fn serialize<W: $crate::types::serialize::writers::RowWriter>(
fn serialize(
&self,
ctx: &$crate::types::serialize::row::RowSerializationContext<'_>,
writer: &mut W,
writer: &mut $crate::types::serialize::writers::RowWriter,
) -> ::std::result::Result<(), $crate::types::serialize::SerializationError> {
$crate::types::serialize::row::serialize_legacy_row(self, ctx, writer)
}
Expand Down Expand Up @@ -492,7 +400,7 @@ macro_rules! impl_serialize_row_via_value_list {
pub fn serialize_legacy_row<T: ValueList>(
r: &T,
ctx: &RowSerializationContext<'_>,
writer: &mut impl RowWriter,
writer: &mut RowWriter,
) -> Result<(), SerializationError> {
let serialized =
<T as ValueList>::serialized(r).map_err(|err| SerializationError(Arc::new(err)))?;
Expand Down Expand Up @@ -596,12 +504,6 @@ pub enum BuiltinTypeCheckErrorKind {

/// A value required by the statement is not provided by the Rust type.
ColumnMissingForValue { name: String },

/// One of the columns failed to type check.
ColumnTypeCheckFailed {
name: String,
err: SerializationError,
},
}

impl Display for BuiltinTypeCheckErrorKind {
Expand All @@ -622,9 +524,6 @@ impl Display for BuiltinTypeCheckErrorKind {
"value for column {name} was provided, but there is no bind marker for this column in the query"
)
}
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { name, err } => {
write!(f, "failed to check column {name}: {err}")
}
}
}
}
Expand Down Expand Up @@ -660,7 +559,7 @@ pub enum ValueListToSerializeRowAdapterError {
mod tests {
use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec};
use crate::frame::value::{MaybeUnset, SerializedValues, ValueList};
use crate::types::serialize::BufBackedRowWriter;
use crate::types::serialize::RowWriter;

use super::{RowSerializationContext, SerializeRow};

Expand Down Expand Up @@ -688,7 +587,7 @@ mod tests {
<_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap();

let mut new_data = Vec::new();
let mut new_data_writer = BufBackedRowWriter::new(&mut new_data);
let mut new_data_writer = RowWriter::new(&mut new_data);
let ctx = RowSerializationContext {
columns: &[
col_spec("a", ColumnType::Int),
Expand Down Expand Up @@ -725,7 +624,7 @@ mod tests {
unsorted_row.add_named_value("c", &None::<i64>).unwrap();

let mut unsorted_row_data = Vec::new();
let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data);
let mut unsorted_row_data_writer = RowWriter::new(&mut unsorted_row_data);
let ctx = RowSerializationContext {
columns: &[
col_spec("a", ColumnType::Int),
Expand All @@ -740,4 +639,37 @@ mod tests {
// Skip the value count
assert_eq!(&sorted_row_data[2..], unsorted_row_data);
}

#[test]
fn test_dyn_serialize_row() {
let row = (
1i32,
"Ala ma kota",
None::<i64>,
MaybeUnset::Unset::<String>,
);
let ctx = RowSerializationContext {
columns: &[
col_spec("a", ColumnType::Int),
col_spec("b", ColumnType::Text),
col_spec("c", ColumnType::BigInt),
col_spec("d", ColumnType::Ascii),
],
};

let mut typed_data = Vec::new();
let mut typed_data_writer = RowWriter::new(&mut typed_data);
<_ as SerializeRow>::serialize(&row, &ctx, &mut typed_data_writer).unwrap();

let row = &row as &dyn SerializeRow;
let mut erased_data = Vec::new();
let mut erased_data_writer = RowWriter::new(&mut erased_data);
<_ as SerializeRow>::serialize(&row, &ctx, &mut erased_data_writer).unwrap();

assert_eq!(
typed_data_writer.value_count(),
erased_data_writer.value_count(),
);
assert_eq!(typed_data, erased_data);
}
}
Loading

0 comments on commit 6c9ec8f

Please sign in to comment.