refactor: impl ReferenceSerialization (#230)
* refactor: impl `ReferenceSerialization`

* refactor: ColumnId -> Ulid

* fix: column mismatch on `Insert`
KKould authored Oct 17, 2024
1 parent bce7cd0 commit d584167
Showing 95 changed files with 2,166 additions and 2,164 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -59,13 +59,15 @@ regex = { version = "1" }
rocksdb = { version = "0.22.0" }
rust_decimal = { version = "1" }
serde = { version = "1", features = ["derive", "rc"] }
serde_macros = { path = "serde_macros" }
siphasher = { version = "1", features = ["serde"] }
sqlparser = { version = "0.34", features = ["serde"] }
strum_macros = { version = "0.26.2" }
thiserror = { version = "1" }
tokio = { version = "1.36", features = ["full"], optional = true }
tracing = { version = "0.1" }
typetag = { version = "0.2" }
ulid = { version = "1", features = ["serde"] }

[dev-dependencies]
cargo-tarpaulin = { version = "0.27" }
@@ -83,7 +85,7 @@ pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
members = [
"tests/sqllogictest",
"tests/macros-test"
]
, "serde_macros"]

[profile.release]
lto = true
14 changes: 14 additions & 0 deletions serde_macros/Cargo.toml
@@ -0,0 +1,14 @@
[package]
name = "serde_macros"
version = "0.1.0"
edition = "2021"

[dependencies]
darling = "0.20"
proc-macro2 = "1"
quote = "1"
syn = "2"

[lib]
path = "src/lib.rs"
proc-macro = true
15 changes: 15 additions & 0 deletions serde_macros/src/lib.rs
@@ -0,0 +1,15 @@
mod reference_serialization;

use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(ReferenceSerialization, attributes(reference_serialization))]
pub fn reference_serialization(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

let result = reference_serialization::handle(ast);
match result {
Ok(codegen) => codegen.into(),
Err(e) => e.to_compile_error().into(),
}
}
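
Note: the `crate::serdes::ReferenceSerialization` trait itself is not among the files shown in this commit. Inferred from the impls the derive emits below, its shape is presumably close to the following; this is a hedged reconstruction, not code from the diff, and the names `ReferenceTables`, `TableCache`, `Transaction`, and `DatabaseError` are taken from the generated code rather than from the trait's real definition in src/serdes.

// Reconstructed sketch of the trait the derive targets; details may differ
// from the actual definition in the crate.
pub trait ReferenceSerialization: Sized {
    fn encode<W: std::io::Write>(
        &self,
        writer: &mut W,
        is_direct: bool,
        reference_tables: &mut ReferenceTables,
    ) -> Result<(), DatabaseError>;

    fn decode<T: Transaction, R: std::io::Read>(
        reader: &mut R,
        drive: Option<(&T, &TableCache)>,
        reference_tables: &ReferenceTables,
    ) -> Result<Self, DatabaseError>;
}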
205 changes: 205 additions & 0 deletions serde_macros/src/reference_serialization.rs
@@ -0,0 +1,205 @@
use darling::ast::Data;
use darling::{FromDeriveInput, FromField, FromVariant};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{
AngleBracketedGenericArguments, DeriveInput, Error, GenericArgument, PathArguments, Type,
TypePath,
};

#[derive(Debug, FromDeriveInput)]
#[darling(attributes(record))]
struct SerializationOpts {
ident: Ident,
data: Data<SerializationVariantOpts, SerializationFieldOpt>,
}

#[derive(Debug, FromVariant)]
#[darling(attributes(record))]
struct SerializationVariantOpts {
ident: Ident,
fields: darling::ast::Fields<SerializationFieldOpt>,
}

#[derive(Debug, FromField)]
#[darling(attributes(record))]
struct SerializationFieldOpt {
ident: Option<Ident>,
ty: Type,
}

fn process_type(ty: &Type) -> TokenStream {
if let Type::Path(TypePath { path, .. }) = ty {
let ident = &path.segments.last().unwrap().ident;

match ident.to_string().as_str() {
"Vec" | "Option" | "Arc" | "Box" | "PhantomData" | "Bound" | "CountMinSketch" => {
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &path.segments.last().unwrap().arguments
{
if let Some(GenericArgument::Type(inner_ty)) = args.first() {
let inner_processed = process_type(inner_ty);

return quote! {
#ident::<#inner_processed>
};
}
}
}
_ => {}
}

quote! { #ty }
} else {
quote! { #ty }
}
}

pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
let record_opts: SerializationOpts = SerializationOpts::from_derive_input(&ast)?;
let struct_name = &record_opts.ident;

Ok(match record_opts.data {
Data::Struct(data_struct) => {
let mut encode_fields: Vec<TokenStream> = Vec::new();
let mut decode_fields: Vec<TokenStream> = Vec::new();
let mut init_fields: Vec<TokenStream> = Vec::new();
let mut is_tuple = false;

for (i, field_opts) in data_struct.fields.into_iter().enumerate() {
is_tuple = is_tuple || field_opts.ident.is_none();

let field_name = field_opts
.ident
.unwrap_or_else(|| Ident::new(&format!("filed_{}", i), Span::call_site()));
let ty = process_type(&field_opts.ty);

encode_fields.push(quote! {
#field_name.encode(writer, is_direct, reference_tables)?;
});
decode_fields.push(quote! {
let #field_name = #ty::decode(reader, drive, reference_tables)?;
});
init_fields.push(quote! {
#field_name,
})
}
let init_stream = if is_tuple {
quote! { #struct_name ( #(#init_fields)* ) }
} else {
quote! { #struct_name { #(#init_fields)* } }
};

quote! {
impl crate::serdes::ReferenceSerialization for #struct_name {
fn encode<W: std::io::Write>(
&self,
writer: &mut W,
is_direct: bool,
reference_tables: &mut crate::serdes::ReferenceTables,
) -> Result<(), crate::errors::DatabaseError> {
let #init_stream = self;

#(#encode_fields)*

Ok(())
}

fn decode<T: crate::storage::Transaction, R: std::io::Read>(
reader: &mut R,
drive: Option<(&T, &crate::storage::TableCache)>,
reference_tables: &crate::serdes::ReferenceTables,
) -> Result<Self, crate::errors::DatabaseError> {
#(#decode_fields)*

Ok(#init_stream)
}
}
}
}
Data::Enum(data_enum) => {
let mut variant_encode_fields: Vec<TokenStream> = Vec::new();
let mut variant_decode_fields: Vec<TokenStream> = Vec::new();

for (i, variant_opts) in data_enum.into_iter().enumerate() {
let i = i as u8;
let mut encode_fields: Vec<TokenStream> = Vec::new();
let mut decode_fields: Vec<TokenStream> = Vec::new();
let mut init_fields: Vec<TokenStream> = Vec::new();
let enum_name = variant_opts.ident;
let mut is_tuple = false;

for (i, field_opts) in variant_opts.fields.into_iter().enumerate() {
is_tuple = is_tuple || field_opts.ident.is_none();

let field_name = field_opts
.ident
.unwrap_or_else(|| Ident::new(&format!("filed_{}", i), Span::call_site()));
let ty = process_type(&field_opts.ty);

encode_fields.push(quote! {
#field_name.encode(writer, is_direct, reference_tables)?;
});
decode_fields.push(quote! {
let #field_name = #ty::decode(reader, drive, reference_tables)?;
});
init_fields.push(quote! {
#field_name,
})
}

let init_stream = if is_tuple {
quote! { #struct_name::#enum_name ( #(#init_fields)* ) }
} else {
quote! { #struct_name::#enum_name { #(#init_fields)* } }
};
variant_encode_fields.push(quote! {
#init_stream => {
std::io::Write::write_all(writer, &[#i])?;

#(#encode_fields)*
}
});
variant_decode_fields.push(quote! {
#i => {
#(#decode_fields)*

#init_stream
}
});
}

quote! {
impl crate::serdes::ReferenceSerialization for #struct_name {
fn encode<W: std::io::Write>(
&self,
writer: &mut W,
is_direct: bool,
reference_tables: &mut crate::serdes::ReferenceTables,
) -> Result<(), crate::errors::DatabaseError> {
match self {
#(#variant_encode_fields)*
}

Ok(())
}

fn decode<T: crate::storage::Transaction, R: std::io::Read>(
reader: &mut R,
drive: Option<(&T, &crate::storage::TableCache)>,
reference_tables: &crate::serdes::ReferenceTables,
) -> Result<Self, crate::errors::DatabaseError> {
let mut type_bytes = [0u8; 1];
std::io::Read::read_exact(reader, &mut type_bytes)?;

Ok(match type_bytes[0] {
#(#variant_decode_fields)*
_ => unreachable!(),
})
}
}
}
}
})
}
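
For a concrete sense of what this macro produces, deriving `ReferenceSerialization` on a plain struct such as `ExtSource` (derived in src/binder/copy.rs below) expands to roughly the code implied by the quote! templates above. This is an illustrative reconstruction, not captured `cargo expand` output:

// Approximate expansion for: struct ExtSource { path: PathBuf, format: FileFormat }
impl crate::serdes::ReferenceSerialization for ExtSource {
    fn encode<W: std::io::Write>(
        &self,
        writer: &mut W,
        is_direct: bool,
        reference_tables: &mut crate::serdes::ReferenceTables,
    ) -> Result<(), crate::errors::DatabaseError> {
        // Destructure the struct, then encode each field in declaration order.
        let ExtSource { path, format } = self;
        path.encode(writer, is_direct, reference_tables)?;
        format.encode(writer, is_direct, reference_tables)?;
        Ok(())
    }

    fn decode<T: crate::storage::Transaction, R: std::io::Read>(
        reader: &mut R,
        drive: Option<(&T, &crate::storage::TableCache)>,
        reference_tables: &crate::serdes::ReferenceTables,
    ) -> Result<Self, crate::errors::DatabaseError> {
        // Decode the fields in the same order, then rebuild the struct.
        let path = PathBuf::decode(reader, drive, reference_tables)?;
        let format = FileFormat::decode(reader, drive, reference_tables)?;
        Ok(ExtSource { path, format })
    }
}

Enums get the same treatment per variant, except that a one-byte discriminant is written before the variant's fields and matched on during decode.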
30 changes: 26 additions & 4 deletions src/binder/copy.rs
@@ -2,23 +2,45 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use super::*;
use crate::errors::DatabaseError;
use crate::planner::operator::copy_from_file::CopyFromFileOperator;
use crate::planner::operator::copy_to_file::CopyToFileOperator;
use crate::planner::operator::Operator;
use serde::{Deserialize, Serialize};
use serde_macros::ReferenceSerialization;
use sqlparser::ast::{CopyOption, CopySource, CopyTarget};

use super::*;

#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
#[derive(
Debug,
PartialEq,
PartialOrd,
Ord,
Hash,
Eq,
Clone,
Serialize,
Deserialize,
ReferenceSerialization,
)]
pub struct ExtSource {
pub path: PathBuf,
pub format: FileFormat,
}

/// File format.
#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, Serialize, Deserialize)]
#[derive(
Debug,
PartialEq,
PartialOrd,
Ord,
Hash,
Eq,
Clone,
Serialize,
Deserialize,
ReferenceSerialization,
)]
pub enum FileFormat {
Csv {
/// Delimiter to parse.
16 changes: 8 additions & 8 deletions src/binder/expr.rs
@@ -11,14 +11,14 @@ use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::scala::ScalarFunction;
use crate::expression::function::table::TableFunction;
use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction};
use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction};
use crate::expression::function::FunctionSummary;
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::types::value::{DataValue, Utf8Type};
use crate::types::LogicalType;
use crate::types::{ColumnId, LogicalType};

macro_rules! try_alias {
($context:expr, $full_name:expr) => {
@@ -231,11 +231,11 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
sub_query: LogicalPlan,
) -> Result<(ScalarExpression, LogicalPlan), DatabaseError> {
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_ref_table(self.context.temp_table(), 0);
alias_column.set_ref_table(self.context.temp_table(), ColumnId::new());

let alias_expr = ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
alias_column,
)))),
};
@@ -246,7 +246,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
fn bind_subquery(
&mut self,
subquery: &Query,
) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
) -> Result<(LogicalPlan, ColumnRef), DatabaseError> {
let BinderContext {
table_cache,
transaction,
@@ -601,13 +601,13 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if let Some(function) = self.context.scala_functions.get(&summary) {
return Ok(ScalarExpression::ScalaFunction(ScalarFunction {
args,
inner: function.clone(),
inner: ArcScalarFunctionImpl(function.clone()),
}));
}
if let Some(function) = self.context.table_functions.get(&summary) {
return Ok(ScalarExpression::TableFunction(TableFunction {
args,
inner: function.clone(),
inner: ArcTableFunctionImpl(function.clone()),
}));
}

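The `ColumnId -> Ulid` part of this commit shows up in this file as `ColumnId::new()` replacing the hard-coded `0`. A minimal sketch of what that likely means, assuming `ColumnId` now wraps `ulid::Ulid` (the dependency added in Cargo.toml above); the actual definition in src/types is not among the files shown:

// Hedged sketch: ColumnId as a ULID-backed identifier. Ulid::new() mints a
// 128-bit, time-ordered, globally unique id, so fresh columns no longer need
// a table-local counter.
use ulid::Ulid;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnId(Ulid);

impl ColumnId {
    pub fn new() -> Self {
        ColumnId(Ulid::new())
    }
}

fn main() {
    let id = ColumnId::new();
    println!("{:?}", id);
}
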
2 changes: 2 additions & 0 deletions src/binder/insert.rs
@@ -19,6 +19,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
idents: &[Ident],
expr_rows: &Vec<Vec<Expr>>,
is_overwrite: bool,
is_mapping_by_name: bool,
) -> Result<LogicalPlan, DatabaseError> {
// FIXME: Make it better to detect the current BindStep
self.context.allow_default = true;
@@ -97,6 +98,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
Operator::Insert(InsertOperator {
table_name,
is_overwrite,
is_mapping_by_name,
}),
vec![values_plan],
))
