From b83db3d3bad9c8a437ea3aba0c97f1479a86e6a3 Mon Sep 17 00:00:00 2001 From: Kould Date: Wed, 28 Aug 2024 01:07:25 +0800 Subject: [PATCH] chore: clean macro --- Cargo.toml | 1 + README.md | 4 +- examples/hello_world.rs | 2 - src/catalog/column.rs | 12 +- src/catalog/mod.rs | 4 +- src/catalog/table.rs | 4 +- src/db.rs | 33 +-- src/expression/function/mod.rs | 4 +- src/function/current_date.rs | 47 +++- src/function/numbers.rs | 83 ++++++- src/lib.rs | 4 +- src/macros/mod.rs | 221 ++++++++++++++++++ src/marcos/mod.rs | 414 --------------------------------- tests/macros-test/Cargo.toml | 11 + tests/macros-test/src/main.rs | 196 ++++++++++++++++ 15 files changed, 566 insertions(+), 474 deletions(-) create mode 100644 src/macros/mod.rs delete mode 100644 src/marcos/mod.rs create mode 100644 tests/macros-test/Cargo.toml create mode 100644 tests/macros-test/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index c557485d..26ca3577 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,4 +81,5 @@ pprof = { version = "0.13", features = ["flamegraph", "criterion"] } [workspace] members = [ "tests/sqllogictest", + "tests/macros-test" ] \ No newline at end of file diff --git a/README.md b/README.md index 0ab5ce35..624ee3b7 100755 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ kould23333/fncksql:latest ~~~ ### Features -- ORM Mapping: `features = ["marcos"]` +- ORM Mapping: `features = ["macros"]` ```rust #[derive(Default, Debug, PartialEq)] struct MyStruct { @@ -114,7 +114,7 @@ implement_from_tuple!( ) ); ``` -- User-Defined Function: `features = ["marcos"]` +- User-Defined Function: `features = ["macros"]` ```rust function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| { let plus_binary_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; diff --git a/examples/hello_world.rs b/examples/hello_world.rs index ea72f426..a0f36b6c 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -1,9 +1,7 @@ use fnck_sql::db::DataBaseBuilder; use fnck_sql::errors::DatabaseError; use fnck_sql::implement_from_tuple; -use fnck_sql::types::tuple::{SchemaRef, Tuple}; use fnck_sql::types::value::DataValue; -use fnck_sql::types::LogicalType; use itertools::Itertools; #[derive(Default, Debug, PartialEq)] diff --git a/src/catalog/column.rs b/src/catalog/column.rs index fe3e4d68..13c940f8 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -27,11 +27,7 @@ pub struct ColumnSummary { } impl ColumnCatalog { - pub(crate) fn new( - column_name: String, - nullable: bool, - column_desc: ColumnDesc, - ) -> ColumnCatalog { + pub fn new(column_name: String, nullable: bool, column_desc: ColumnDesc) -> ColumnCatalog { ColumnCatalog { summary: ColumnSummary { id: None, @@ -87,6 +83,10 @@ impl ColumnCatalog { self.summary.name = name; } + pub fn set_id(&mut self, id: ColumnId) { + self.summary.id = Some(id); + } + pub fn set_table_name(&mut self, table_name: TableName) { self.summary.table_name = Some(table_name); } @@ -119,7 +119,7 @@ pub struct ColumnDesc { } impl ColumnDesc { - pub(crate) const fn new( + pub const fn new( column_datatype: LogicalType, is_primary: bool, is_unique: bool, diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index f1f1fe83..27e63f94 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -3,5 +3,5 @@ pub(crate) use self::column::*; pub(crate) use self::table::*; -mod column; -mod table; +pub mod column; +pub mod table; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 752c8103..bfa30bd6 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -64,7 +64,7 @@ impl TableCatalog { self.indexes.iter() } - pub(crate) fn schema_ref(&self) -> &SchemaRef { + pub fn schema_ref(&self) -> &SchemaRef { &self.schema_ref } @@ -139,7 +139,7 @@ impl TableCatalog { Ok(self.indexes.last().unwrap()) } - pub(crate) fn new( + pub fn new( name: TableName, columns: Vec, ) -> Result { diff --git a/src/db.rs b/src/db.rs index 670f327c..bf0f4eec 100644 --- a/src/db.rs +++ b/src/db.rs @@ -303,20 +303,12 @@ impl DBTransaction<'_, S> { #[cfg(test)] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::{DataBaseBuilder, DatabaseError}; - use crate::expression::function::scala::{FuncMonotonicity, ScalarFunctionImpl}; - use crate::expression::function::FunctionSummary; - use crate::expression::ScalarExpression; - use crate::expression::{BinaryOperator, UnaryOperator}; - use crate::scala_function; use crate::storage::{Storage, TableCache, Transaction}; - use crate::types::evaluator::EvaluatorFactory; - use crate::types::tuple::{create_table, Tuple}; - use crate::types::value::{DataValue, ValueRef}; + use crate::types::tuple::create_table; + use crate::types::value::DataValue; use crate::types::LogicalType; - use serde::Deserialize; - use serde::Serialize; use std::sync::Arc; use tempfile::TempDir; @@ -356,25 +348,12 @@ mod test { Ok(()) } - scala_function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { - let plus_binary_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; - let value = plus_binary_evaluator.binary_eval(&v1, &v2); - - let plus_unary_evaluator = EvaluatorFactory::unary_create(LogicalType::Integer, UnaryOperator::Minus)?; - Ok(plus_unary_evaluator.unary_eval(&value)) - })); - + /// use [CurrentDate](crate::function::current_date::CurrentDate) on this case #[test] fn test_udf() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let fnck_sql = DataBaseBuilder::path(temp_dir.path()) - .register_scala_function(TestFunction::new()) - .build()?; - let _ = fnck_sql - .run("CREATE TABLE test (id int primary key, c1 int, c2 int default test(1, 2));")?; - let _ = fnck_sql - .run("INSERT INTO test VALUES (1, 2, 2), (0, 1, 1), (2, 1, 1), (3, 3, default);")?; - let (schema, tuples) = fnck_sql.run("select test(c1, 1), c2 from test")?; + let fnck_sql = DataBaseBuilder::path(temp_dir.path()).build()?; + let (schema, tuples) = fnck_sql.run("select current_date()")?; println!("{}", create_table(&schema, &tuples)); Ok(()) diff --git a/src/expression/function/mod.rs b/src/expression/function/mod.rs index 352a14b2..85e515b0 100644 --- a/src/expression/function/mod.rs +++ b/src/expression/function/mod.rs @@ -6,6 +6,6 @@ pub mod table; #[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize, Deserialize)] pub struct FunctionSummary { - pub(crate) name: String, - pub(crate) arg_types: Vec, + pub name: String, + pub arg_types: Vec, } diff --git a/src/function/current_date.rs b/src/function/current_date.rs index eaeb4f6e..7e7a4d85 100644 --- a/src/function/current_date.rs +++ b/src/function/current_date.rs @@ -4,7 +4,6 @@ use crate::expression::function::scala::FuncMonotonicity; use crate::expression::function::scala::ScalarFunctionImpl; use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; -use crate::scala_function; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -13,6 +12,46 @@ use serde::Deserialize; use serde::Serialize; use std::sync::Arc; -scala_function!(CurrentDate::current_date() -> LogicalType::Date => (|| { - Ok(DataValue::Date32(Some(Local::now().num_days_from_ce()))) -})); +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct CurrentDate { + summary: FunctionSummary, +} + +impl CurrentDate { + #[allow(unused_mut)] + pub(crate) fn new() -> Arc { + let function_name = "current_date".to_lowercase(); + + Arc::new(Self { + summary: FunctionSummary { + name: function_name, + arg_types: Vec::new(), + }, + }) + } +} + +#[typetag::serde] +impl ScalarFunctionImpl for CurrentDate { + #[allow(unused_variables, clippy::redundant_closure_call)] + fn eval( + &self, + _: &[ScalarExpression], + _: &Tuple, + _: &[ColumnRef], + ) -> Result { + Ok(DataValue::Date32(Some(Local::now().num_days_from_ce()))) + } + + fn monotonicity(&self) -> Option { + todo!() + } + + fn return_type(&self) -> &LogicalType { + &LogicalType::Date + } + + fn summary(&self) -> &FunctionSummary { + &self.summary + } +} diff --git a/src/function/numbers.rs b/src/function/numbers.rs index 3cbc3756..4ee0e036 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -5,24 +5,85 @@ use crate::errors::DatabaseError; use crate::expression::function::table::TableFunctionImpl; use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; -use crate::table_function; use crate::types::tuple::SchemaRef; use crate::types::tuple::Tuple; -use crate::types::value::{DataValue, ValueRef}; +use crate::types::value::DataValue; use crate::types::LogicalType; use lazy_static::lazy_static; use serde::Deserialize; use serde::Serialize; use std::sync::Arc; -table_function!(Numbers::numbers(LogicalType::Integer) -> [number: LogicalType::Integer] => (|v1: ValueRef| { - let num = v1.i32().ok_or_else(|| DatabaseError::NotNull)?; +lazy_static! { + static ref NUMBERS: TableCatalog = { + TableCatalog::new( + Arc::new("numbers".to_lowercase()), + vec![ColumnCatalog::new( + "number".to_lowercase(), + true, + ColumnDesc::new(LogicalType::Integer, false, false, None), + )], + ) + .unwrap() + }; +} - Ok(Box::new((0..num) - .map(|i| Ok(Tuple { +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Numbers { + summary: FunctionSummary, +} + +impl Numbers { + #[allow(unused_mut)] + pub(crate) fn new() -> Arc { + let function_name = "numbers".to_lowercase(); + + Arc::new(Self { + summary: FunctionSummary { + name: function_name, + arg_types: vec![LogicalType::Integer], + }, + }) + } +} + +#[typetag::serde] +impl TableFunctionImpl for Numbers { + #[allow(unused_variables, clippy::redundant_closure_call)] + fn eval( + &self, + args: &[ScalarExpression], + ) -> Result>>, DatabaseError> { + let tuple = Tuple { id: None, - values: vec![ - Arc::new(DataValue::Int32(Some(i))), - ] - }))) as Box>>) -})); + values: Vec::new(), + }; + + let mut value = args[0].eval(&tuple, &[])?; + + if value.logical_type() != LogicalType::Integer { + value = Arc::new(DataValue::clone(&value).cast(&LogicalType::Integer)?); + } + let num = value.i32().ok_or_else(|| DatabaseError::NotNull)?; + + Ok(Box::new((0..num).map(|i| { + Ok(Tuple { + id: None, + values: vec![Arc::new(DataValue::Int32(Some(i)))], + }) + })) + as Box>>) + } + + fn output_schema(&self) -> &SchemaRef { + NUMBERS.schema_ref() + } + + fn summary(&self) -> &FunctionSummary { + &self.summary + } + + fn table(&self) -> &'static TableCatalog { + &NUMBERS + } +} diff --git a/src/lib.rs b/src/lib.rs index f3c3544a..c0ed8b05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,7 @@ //! ) //! ); //! -//! #[cfg(feature = "marcos")] +//! #[cfg(feature = "macros")] //! fn main() -> Result<(), DatabaseError> { //! let database = DataBaseBuilder::path("./hello_world").build()?; //! @@ -104,7 +104,7 @@ pub mod execution; pub mod expression; mod function; #[cfg(feature = "marcos")] -pub mod marcos; +pub mod macros; mod optimizer; pub mod parser; pub mod planner; diff --git a/src/macros/mod.rs b/src/macros/mod.rs new file mode 100644 index 00000000..2e35fe7f --- /dev/null +++ b/src/macros/mod.rs @@ -0,0 +1,221 @@ +/// # Examples +/// +/// ``` +///struct MyStruct { +/// c1: i32, +/// c2: String, +///} +/// +///implement_from_tuple!( +/// MyStruct, ( +/// c1: i32 => |inner: &mut MyStruct, value| { +/// if let DataValue::Int32(Some(val)) = value { +/// inner.c1 = val; +/// } +/// }, +/// c2: String => |inner: &mut MyStruct, value| { +/// if let DataValue::Utf8(Some(val)) = value { +/// inner.c2 = val; +/// } +/// } +/// ) +/// ); +/// ``` +#[macro_export] +macro_rules! implement_from_tuple { + ($struct_name:ident, ($($field_name:ident : $field_type:ty => $closure:expr),+)) => { + impl From<(&::fnck_sql::types::tuple::SchemaRef, ::fnck_sql::types::tuple::Tuple)> for $struct_name { + fn from((schema, tuple): (&::fnck_sql::types::tuple::SchemaRef, ::fnck_sql::types::tuple::Tuple)) -> Self { + fn try_get(tuple: &::fnck_sql::types::tuple::Tuple, schema: &::fnck_sql::types::tuple::SchemaRef, field_name: &str) -> Option<::fnck_sql::types::value::DataValue> { + let ty = ::fnck_sql::types::LogicalType::type_trans::()?; + let (idx, _) = schema + .iter() + .enumerate() + .find(|(_, col)| col.name() == field_name)?; + + ::fnck_sql::types::value::DataValue::clone(&tuple.values[idx]) + .cast(&ty) + .ok() + } + + let mut struct_instance = $struct_name::default(); + $( + if let Some(value) = try_get::<$field_type>(&tuple, schema, stringify!($field_name)) { + $closure( + &mut struct_instance, + value + ); + } + )+ + struct_instance + } + } + }; +} + +/// # Examples +/// +/// ``` +/// scala_function!(MyFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| { +/// DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus) +/// }); +/// +/// let fnck_sql = DataBaseBuilder::path("./example") +/// .register_scala_function(TestFunction::new()) +/// .build() +/// ?; +/// ``` +#[macro_export] +macro_rules! scala_function { + ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> $return_ty:expr => $closure:expr) => { + #[derive(Debug, ::serde::Serialize, ::serde::Deserialize)] + pub(crate) struct $struct_name { + summary: ::fnck_sql::expression::function::FunctionSummary + } + + impl $struct_name { + #[allow(unused_mut)] + pub(crate) fn new() -> Arc { + let function_name = stringify!($function_name).to_lowercase(); + + let mut arg_types = Vec::new(); + $({ + arg_types.push($arg_ty); + })* + + Arc::new(Self { + summary: ::fnck_sql::expression::function::FunctionSummary { + name: function_name, + arg_types + } + }) + } + } + + #[typetag::serde] + impl ::fnck_sql::expression::function::scala::ScalarFunctionImpl for $struct_name { + #[allow(unused_variables, clippy::redundant_closure_call)] + fn eval(&self, args: &[::fnck_sql::expression::ScalarExpression], tuple: &::fnck_sql::types::tuple::Tuple, schema: &[::fnck_sql::catalog::column::ColumnRef]) -> Result<::fnck_sql::types::value::DataValue, ::fnck_sql::errors::DatabaseError> { + let mut _index = 0; + + $closure($({ + let mut value = args[_index].eval(tuple, schema)?; + _index += 1; + + if value.logical_type() != $arg_ty { + value = Arc::new(::fnck_sql::types::value::DataValue::clone(&value).cast(&$arg_ty)?); + } + value + }, )*) + } + + fn monotonicity(&self) -> Option<::fnck_sql::expression::function::scala::FuncMonotonicity> { + todo!() + } + + fn return_type(&self) -> &::fnck_sql::types::LogicalType { + &$return_ty + } + + fn summary(&self) -> &::fnck_sql::expression::function::FunctionSummary { + &self.summary + } + } + }; +} + +/// # Examples +/// +/// ``` +/// table_function!(MyTableFunction::test_numbers(LogicalType::Integer) -> [c1: LogicalType::Integer, c2: LogicalType::Integer] => (|v1: ValueRef| { +/// let num = v1.i32().unwrap(); +/// +/// Ok(Box::new((0..num) +/// .into_iter() +/// .map(|i| Ok(Tuple { +/// id: None, +/// values: vec![ +/// Arc::new(DataValue::Int32(Some(i))), +/// Arc::new(DataValue::Int32(Some(i))), +/// ] +/// }))) as Box>>) +/// })); +/// +/// let fnck_sql = DataBaseBuilder::path("./example") +/// .register_table_function(MyTableFunction::new()) +/// .build() +/// ?; +/// ``` +#[macro_export] +macro_rules! table_function { + ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> [$($output_name:ident: $output_ty:expr),*] => $closure:expr) => { + ::lazy_static::lazy_static! { + static ref $function_name: ::fnck_sql::catalog::table::TableCatalog = { + let mut columns = Vec::new(); + + $({ + columns.push(::fnck_sql::catalog::column::ColumnCatalog::new(stringify!($output_name).to_lowercase(), true, ::fnck_sql::catalog::column::ColumnDesc::new($output_ty, false, false, None))); + })* + ::fnck_sql::catalog::table::TableCatalog::new(Arc::new(stringify!($function_name).to_lowercase()), columns).unwrap() + }; + } + + #[derive(Debug, ::serde::Serialize, ::serde::Deserialize)] + pub(crate) struct $struct_name { + summary: ::fnck_sql::expression::function::FunctionSummary + } + + impl $struct_name { + #[allow(unused_mut)] + pub(crate) fn new() -> Arc { + let function_name = stringify!($function_name).to_lowercase(); + + let mut arg_types = Vec::new(); + $({ + arg_types.push($arg_ty); + })* + + Arc::new(Self { + summary: ::fnck_sql::expression::function::FunctionSummary { + name: function_name, + arg_types + } + }) + } + } + + #[typetag::serde] + impl ::fnck_sql::expression::function::table::TableFunctionImpl for $struct_name { + #[allow(unused_variables, clippy::redundant_closure_call)] + fn eval(&self, args: &[::fnck_sql::expression::ScalarExpression]) -> Result>>, ::fnck_sql::errors::DatabaseError> { + let mut _index = 0; + let tuple = ::fnck_sql::types::tuple::Tuple { + id: None, + values: Vec::new(), + }; + + $closure($({ + let mut value = args[_index].eval(&tuple, &[])?; + _index += 1; + + if value.logical_type() != $arg_ty { + value = Arc::new(::fnck_sql::types::value::DataValue::clone(&value).cast(&$arg_ty)?); + } + value + }, )*) + } + + fn output_schema(&self) -> &::fnck_sql::types::tuple::SchemaRef { + $function_name.schema_ref() + } + + fn summary(&self) -> &::fnck_sql::expression::function::FunctionSummary { + &self.summary + } + + fn table(&self) -> &'static ::fnck_sql::catalog::table::TableCatalog { + &$function_name + } + } + }; +} diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs deleted file mode 100644 index d85dd246..00000000 --- a/src/marcos/mod.rs +++ /dev/null @@ -1,414 +0,0 @@ -/// # Examples -/// -/// ``` -///struct MyStruct { -/// c1: i32, -/// c2: String, -///} -/// -///implement_from_tuple!( -/// MyStruct, ( -/// c1: i32 => |inner: &mut MyStruct, value| { -/// if let DataValue::Int32(Some(val)) = value { -/// inner.c1 = val; -/// } -/// }, -/// c2: String => |inner: &mut MyStruct, value| { -/// if let DataValue::Utf8(Some(val)) = value { -/// inner.c2 = val; -/// } -/// } -/// ) -/// ); -/// ``` -#[macro_export] -macro_rules! implement_from_tuple { - ($struct_name:ident, ($($field_name:ident : $field_type:ty => $closure:expr),+)) => { - impl From<(&SchemaRef, Tuple)> for $struct_name { - fn from((schema, tuple): (&SchemaRef, Tuple)) -> Self { - fn try_get(tuple: &Tuple, schema: &SchemaRef, field_name: &str) -> Option { - let ty = LogicalType::type_trans::()?; - let (idx, _) = schema - .iter() - .enumerate() - .find(|(_, col)| col.name() == field_name)?; - - DataValue::clone(&tuple.values[idx]) - .cast(&ty) - .ok() - } - - let mut struct_instance = $struct_name::default(); - $( - if let Some(value) = try_get::<$field_type>(&tuple, schema, stringify!($field_name)) { - $closure( - &mut struct_instance, - value - ); - } - )+ - struct_instance - } - } - }; -} - -/// # Examples -/// -/// ``` -/// scala_function!(MyFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| { -/// DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus) -/// }); -/// -/// let fnck_sql = DataBaseBuilder::path("./example") -/// .register_scala_function(TestFunction::new()) -/// .build() -/// ?; -/// ``` -#[macro_export] -macro_rules! scala_function { - ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> $return_ty:expr => $closure:expr) => { - #[derive(Debug, Serialize, Deserialize)] - pub(crate) struct $struct_name { - summary: FunctionSummary - } - - impl $struct_name { - #[allow(unused_mut)] - pub(crate) fn new() -> Arc { - let function_name = stringify!($function_name).to_lowercase(); - - let mut arg_types = Vec::new(); - $({ - arg_types.push($arg_ty); - })* - - Arc::new(Self { - summary: FunctionSummary { - name: function_name, - arg_types - } - }) - } - } - - #[typetag::serde] - impl ScalarFunctionImpl for $struct_name { - #[allow(unused_variables, clippy::redundant_closure_call)] - fn eval(&self, args: &[ScalarExpression], tuple: &Tuple, schema: &[ColumnRef]) -> Result { - let mut _index = 0; - - $closure($({ - let mut value = args[_index].eval(tuple, schema)?; - _index += 1; - - if value.logical_type() != $arg_ty { - value = Arc::new(DataValue::clone(&value).cast(&$arg_ty)?); - } - value - }, )*) - } - - fn monotonicity(&self) -> Option { - todo!() - } - - fn return_type(&self) -> &LogicalType { - &$return_ty - } - - fn summary(&self) -> &FunctionSummary { - &self.summary - } - } - }; -} - -/// # Examples -/// -/// ``` -/// table_function!(MyTableFunction::test_numbers(LogicalType::Integer) -> [c1: LogicalType::Integer, c2: LogicalType::Integer] => (|v1: ValueRef| { -/// let num = v1.i32().unwrap(); -/// -/// Ok(Box::new((0..num) -/// .into_iter() -/// .map(|i| Ok(Tuple { -/// id: None, -/// values: vec![ -/// Arc::new(DataValue::Int32(Some(i))), -/// Arc::new(DataValue::Int32(Some(i))), -/// ] -/// }))) as Box>>) -/// })); -/// -/// let fnck_sql = DataBaseBuilder::path("./example") -/// .register_table_function(MyTableFunction::new()) -/// .build() -/// ?; -/// ``` -#[macro_export] -macro_rules! table_function { - ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> [$($output_name:ident: $output_ty:expr),*] => $closure:expr) => { - lazy_static! { - static ref $function_name: TableCatalog = { - let mut columns = Vec::new(); - - $({ - columns.push(ColumnCatalog::new(stringify!($output_name).to_lowercase(), true, ColumnDesc::new($output_ty, false, false, None))); - })* - TableCatalog::new(Arc::new(stringify!($function_name).to_lowercase()), columns).unwrap() - }; - } - - #[derive(Debug, Serialize, Deserialize)] - pub(crate) struct $struct_name { - summary: FunctionSummary - } - - impl $struct_name { - #[allow(unused_mut)] - pub(crate) fn new() -> Arc { - let function_name = stringify!($function_name).to_lowercase(); - - let mut arg_types = Vec::new(); - $({ - arg_types.push($arg_ty); - })* - - Arc::new(Self { - summary: FunctionSummary { - name: function_name, - arg_types - } - }) - } - } - - #[typetag::serde] - impl TableFunctionImpl for $struct_name { - #[allow(unused_variables, clippy::redundant_closure_call)] - fn eval(&self, args: &[ScalarExpression]) -> Result>>, DatabaseError> { - let mut _index = 0; - let tuple = Tuple { - id: None, - values: Vec::new(), - }; - - $closure($({ - let mut value = args[_index].eval(&tuple, &[])?; - _index += 1; - - if value.logical_type() != $arg_ty { - value = Arc::new(DataValue::clone(&value).cast(&$arg_ty)?); - } - value - }, )*) - } - - fn output_schema(&self) -> &SchemaRef { - $function_name.schema_ref() - } - - fn summary(&self) -> &FunctionSummary { - &self.summary - } - - fn table(&self) -> &'static TableCatalog { - &$function_name - } - } - }; -} - -#[cfg(test)] -mod test { - use crate::catalog::ColumnRef; - use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::errors::DatabaseError; - use crate::expression::function::scala::{FuncMonotonicity, ScalarFunctionImpl}; - use crate::expression::function::FunctionSummary; - use crate::expression::BinaryOperator; - use crate::expression::ScalarExpression; - use crate::types::evaluator::EvaluatorFactory; - use crate::types::tuple::{SchemaRef, Tuple}; - use crate::types::value::{DataValue, Utf8Type, ValueRef}; - use crate::types::LogicalType; - use crate::catalog::TableCatalog; - use crate::expression::function::table::TableFunctionImpl; - use lazy_static::lazy_static; - use serde::Deserialize; - use serde::Serialize; - use sqlparser::ast::CharLengthUnits; - use std::sync::Arc; - - fn build_tuple() -> (Tuple, SchemaRef) { - let schema_ref = Arc::new(vec![ - Arc::new(ColumnCatalog::new( - "c1".to_string(), - false, - ColumnDesc::new(LogicalType::Integer, true, false, None), - )), - Arc::new(ColumnCatalog::new( - "c2".to_string(), - false, - ColumnDesc::new( - LogicalType::Varchar(None, CharLengthUnits::Characters), - false, - false, - None, - ), - )), - ]); - let values = vec![ - Arc::new(DataValue::Int32(Some(9))), - Arc::new(DataValue::Utf8 { - value: Some("LOL".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }), - ]; - - (Tuple { id: None, values }, schema_ref) - } - - #[derive(Default, Debug, PartialEq)] - struct MyStruct { - c1: i32, - c2: String, - } - - implement_from_tuple!( - MyStruct, ( - c1: i32 => |inner: &mut MyStruct, value| { - if let DataValue::Int32(Some(val)) = value { - inner.c1 = val; - } - }, - c2: String => |inner: &mut MyStruct, value| { - if let DataValue::Utf8 { value: Some(val), .. } = value { - inner.c2 = val; - } - } - ) - ); - - #[test] - fn test_from_tuple() { - let (tuple, schema_ref) = build_tuple(); - let my_struct = MyStruct::from((&schema_ref, tuple)); - - println!("{:?}", my_struct); - - assert_eq!(my_struct.c1, 9); - assert_eq!(my_struct.c2, "LOL"); - } - - scala_function!(MyScalaFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { - let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; - - Ok(plus_evaluator.0.binary_eval(&v1, &v2)) - })); - - table_function!(MyTableFunction::test_numbers(LogicalType::Integer) -> [c1: LogicalType::Integer, c2: LogicalType::Integer] => (|v1: ValueRef| { - let num = v1.i32().unwrap(); - - Ok(Box::new((0..num) - .into_iter() - .map(|i| Ok(Tuple { - id: None, - values: vec![ - Arc::new(DataValue::Int32(Some(i))), - Arc::new(DataValue::Int32(Some(i))), - ] - }))) as Box>>) - })); - - #[test] - fn test_scala_function() -> Result<(), DatabaseError> { - let function = MyScalaFunction::new(); - let sum = function.eval( - &[ - ScalarExpression::Constant(Arc::new(DataValue::Int8(Some(1)))), - ScalarExpression::Constant(Arc::new(DataValue::Utf8 { - value: Some("1".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - })), - ], - &Tuple { - id: None, - values: vec![], - }, - &vec![], - )?; - - println!("{:?}", function); - - assert_eq!( - function.summary, - FunctionSummary { - name: "sum".to_string(), - arg_types: vec![LogicalType::Integer, LogicalType::Integer], - } - ); - assert_eq!(sum, DataValue::Int32(Some(2))); - Ok(()) - } - - #[test] - fn test_table_function() -> Result<(), DatabaseError> { - let function = MyTableFunction::new(); - let mut numbers = function.eval(&[ScalarExpression::Constant(Arc::new( - DataValue::Int8(Some(2)), - ))])?; - - println!("{:?}", function); - - assert_eq!( - function.summary, - FunctionSummary { - name: "numbers".to_string(), - arg_types: vec![LogicalType::Integer], - } - ); - assert_eq!( - numbers.next().unwrap().unwrap(), - Tuple { - id: None, - values: vec![ - Arc::new(DataValue::Int32(Some(0))), - Arc::new(DataValue::Int32(Some(0))), - ] - } - ); - assert_eq!( - numbers.next().unwrap().unwrap(), - Tuple { - id: None, - values: vec![ - Arc::new(DataValue::Int32(Some(1))), - Arc::new(DataValue::Int32(Some(1))), - ] - } - ); - assert!(numbers.next().is_none()); - - assert_eq!( - function.output_schema(), - &Arc::new(vec![ - Arc::new(ColumnCatalog::new( - "c1".to_string(), - true, - ColumnDesc::new(LogicalType::Integer, false, false, None) - )), - Arc::new(ColumnCatalog::new( - "c2".to_string(), - true, - ColumnDesc::new(LogicalType::Integer, false, false, None) - )) - ]) - ); - - Ok(()) - } -} diff --git a/tests/macros-test/Cargo.toml b/tests/macros-test/Cargo.toml new file mode 100644 index 00000000..74a3173d --- /dev/null +++ b/tests/macros-test/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "macros-test" +version = "0.4.0" +edition = "2021" + +[dev-dependencies] +"fnck_sql" = { path = "../.." } +lazy_static = { version = "1" } +serde = { version = "1", features = ["derive", "rc"] } +sqlparser = { version = "0.34", features = ["serde"] } +typetag = { version = "0.2" } \ No newline at end of file diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs new file mode 100644 index 00000000..c35e0e6b --- /dev/null +++ b/tests/macros-test/src/main.rs @@ -0,0 +1,196 @@ +fn main() {} + +#[cfg(test)] +mod test { + use fnck_sql::catalog::column::{ColumnCatalog, ColumnDesc}; + use fnck_sql::errors::DatabaseError; + use fnck_sql::expression::function::scala::ScalarFunctionImpl; + use fnck_sql::expression::function::table::TableFunctionImpl; + use fnck_sql::expression::function::FunctionSummary; + use fnck_sql::expression::BinaryOperator; + use fnck_sql::expression::ScalarExpression; + use fnck_sql::types::evaluator::EvaluatorFactory; + use fnck_sql::types::tuple::{SchemaRef, Tuple}; + use fnck_sql::types::value::ValueRef; + use fnck_sql::types::value::{DataValue, Utf8Type}; + use fnck_sql::types::LogicalType; + use fnck_sql::{implement_from_tuple, scala_function, table_function}; + use sqlparser::ast::CharLengthUnits; + use std::sync::Arc; + + fn build_tuple() -> (Tuple, SchemaRef) { + let schema_ref = Arc::new(vec![ + Arc::new(ColumnCatalog::new( + "c1".to_string(), + false, + ColumnDesc::new(LogicalType::Integer, true, false, None), + )), + Arc::new(ColumnCatalog::new( + "c2".to_string(), + false, + ColumnDesc::new( + LogicalType::Varchar(None, CharLengthUnits::Characters), + false, + false, + None, + ), + )), + ]); + let values = vec![ + Arc::new(DataValue::Int32(Some(9))), + Arc::new(DataValue::Utf8 { + value: Some("LOL".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }), + ]; + + (Tuple { id: None, values }, schema_ref) + } + + #[derive(Default, Debug, PartialEq)] + struct MyStruct { + c1: i32, + c2: String, + } + + implement_from_tuple!( + MyStruct, ( + c1: i32 => |inner: &mut MyStruct, value| { + if let DataValue::Int32(Some(val)) = value { + inner.c1 = val; + } + }, + c2: String => |inner: &mut MyStruct, value| { + if let DataValue::Utf8 { value: Some(val), .. } = value { + inner.c2 = val; + } + } + ) + ); + + #[test] + fn test_from_tuple() { + let (tuple, schema_ref) = build_tuple(); + let my_struct = MyStruct::from((&schema_ref, tuple)); + + println!("{:?}", my_struct); + + assert_eq!(my_struct.c1, 9); + assert_eq!(my_struct.c2, "LOL"); + } + + scala_function!(MyScalaFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { + let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; + + Ok(plus_evaluator.0.binary_eval(&v1, &v2)) + })); + + table_function!(MyTableFunction::test_numbers(LogicalType::Integer) -> [c1: LogicalType::Integer, c2: LogicalType::Integer] => (|v1: ValueRef| { + let num = v1.i32().unwrap(); + + Ok(Box::new((0..num) + .into_iter() + .map(|i| Ok(Tuple { + id: None, + values: vec![ + Arc::new(DataValue::Int32(Some(i))), + Arc::new(DataValue::Int32(Some(i))), + ] + }))) as Box>>) + })); + + #[test] + fn test_scala_function() -> Result<(), DatabaseError> { + let function = MyScalaFunction::new(); + let sum = function.eval( + &[ + ScalarExpression::Constant(Arc::new(DataValue::Int8(Some(1)))), + ScalarExpression::Constant(Arc::new(DataValue::Utf8 { + value: Some("1".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + })), + ], + &Tuple { + id: None, + values: vec![], + }, + &vec![], + )?; + + println!("{:?}", function); + + assert_eq!( + function.summary, + FunctionSummary { + name: "sum".to_string(), + arg_types: vec![LogicalType::Integer, LogicalType::Integer], + } + ); + assert_eq!(sum, DataValue::Int32(Some(2))); + Ok(()) + } + + #[test] + fn test_table_function() -> Result<(), DatabaseError> { + let function = MyTableFunction::new(); + let mut numbers = function.eval(&[ScalarExpression::Constant(Arc::new( + DataValue::Int8(Some(2)), + ))])?; + + println!("{:?}", function); + + assert_eq!( + function.summary, + FunctionSummary { + name: "test_numbers".to_string(), + arg_types: vec![LogicalType::Integer], + } + ); + assert_eq!( + numbers.next().unwrap().unwrap(), + Tuple { + id: None, + values: vec![ + Arc::new(DataValue::Int32(Some(0))), + Arc::new(DataValue::Int32(Some(0))), + ] + } + ); + assert_eq!( + numbers.next().unwrap().unwrap(), + Tuple { + id: None, + values: vec![ + Arc::new(DataValue::Int32(Some(1))), + Arc::new(DataValue::Int32(Some(1))), + ] + } + ); + assert!(numbers.next().is_none()); + + let table_name = Arc::new("test_numbers".to_string()); + let mut c1 = ColumnCatalog::new( + "c1".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, false, false, None), + ); + c1.set_id(0); + c1.set_table_name(table_name.clone()); + let mut c2 = ColumnCatalog::new( + "c2".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, false, false, None), + ); + c2.set_id(1); + c2.set_table_name(table_name.clone()); + + assert_eq!( + function.output_schema(), + &Arc::new(vec![Arc::new(c1), Arc::new(c2)]) + ); + + Ok(()) + } +}