diff --git a/Cargo.lock b/Cargo.lock index 4bda283428..602382c744 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2780,32 +2780,38 @@ name = "dozer-sql" version = "0.1.39" dependencies = [ "ahash 0.8.3", - "bigdecimal", "dozer-core", + "dozer-sql-expression", "dozer-storage", "dozer-tracing", "dozer-types", "enum_dispatch", - "half 2.3.1", - "jsonpath", - "jsonpath-rust", - "like", "linked-hash-map", "metrics", "multimap", - "ndarray", - "num-traits", - "ort", - "pest", - "pest_derive", "proptest", "regex", - "sqlparser 0.35.0 (git+https://github.com/getdozer/sqlparser-rs.git)", "tempdir", "tokio", "uuid", ] +[[package]] +name = "dozer-sql-expression" +version = "0.1.39" +dependencies = [ + "bigdecimal", + "dozer-types", + "half 2.3.1", + "jsonpath", + "like", + "ndarray", + "num-traits", + "ort", + "proptest", + "sqlparser 0.35.0 (git+https://github.com/getdozer/sqlparser-rs.git)", +] + [[package]] name = "dozer-storage" version = "0.1.39" @@ -4261,18 +4267,6 @@ dependencies = [ "regex", ] -[[package]] -name = "jsonpath-rust" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b55563e28c54b1cc0d7eb92475cf9e210cd58e2fce9fabbc0cb5bb1136b4ab3" -dependencies = [ - "pest", - "pest_derive", - "regex", - "serde_json", -] - [[package]] name = "jsonrpc-core" version = "18.0.0" @@ -5259,9 +5253,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" dependencies = [ "autocfg", "libm", diff --git a/dozer-cli/src/errors.rs b/dozer-cli/src/errors.rs index 4928411452..740a640f7a 100644 --- a/dozer-cli/src/errors.rs +++ b/dozer-cli/src/errors.rs @@ -17,7 +17,7 @@ use dozer_cache::dozer_log::storage; use dozer_cache::errors::CacheError; use dozer_core::errors::ExecutionError; use dozer_ingestion::errors::ConnectorError; -use dozer_sql::pipeline::errors::PipelineError; +use dozer_sql::errors::PipelineError; use dozer_types::{constants::LOCK_FILE, thiserror::Error}; use dozer_types::{errors::internal::BoxedError, serde_json}; use dozer_types::{serde_yaml, thiserror}; diff --git a/dozer-cli/src/lib.rs b/dozer-cli/src/lib.rs index 879e328c43..8ed645d374 100644 --- a/dozer-cli/src/lib.rs +++ b/dozer-cli/src/lib.rs @@ -5,7 +5,7 @@ pub mod pipeline; pub mod shutdown; pub mod simple; use dozer_core::{app::AppPipeline, errors::ExecutionError}; -use dozer_sql::pipeline::{builder::statement_to_pipeline, errors::PipelineError}; +use dozer_sql::{builder::statement_to_pipeline, errors::PipelineError}; use dozer_types::log::debug; use errors::OrchestrationError; use shutdown::ShutdownSender; @@ -58,7 +58,7 @@ pub use dozer_ingestion::{ connectors::{get_connector, TableInfo}, errors::ConnectorError, }; -pub use dozer_sql::pipeline::builder::QueryContext; +pub use dozer_sql::builder::QueryContext; pub fn wrapped_statement_to_pipeline(sql: &str) -> Result { let mut pipeline = AppPipeline::new_with_default_flags(); statement_to_pipeline(sql, &mut pipeline, None, vec![]) diff --git a/dozer-cli/src/live/errors.rs b/dozer-cli/src/live/errors.rs index 523262d911..34496892a5 100644 --- a/dozer-cli/src/live/errors.rs +++ b/dozer-cli/src/live/errors.rs @@ -1,6 +1,6 @@ use crate::errors::{BuildError, CliError, OrchestrationError}; use dozer_core::errors::ExecutionError; -use dozer_sql::pipeline::errors::PipelineError; +use 
dozer_sql::errors::PipelineError; use dozer_types::thiserror; use dozer_types::thiserror::Error; diff --git a/dozer-cli/src/live/state.rs b/dozer-cli/src/live/state.rs index d045eaf105..d941d808e6 100644 --- a/dozer-cli/src/live/state.rs +++ b/dozer-cli/src/live/state.rs @@ -4,7 +4,7 @@ use clap::Parser; use dozer_cache::dozer_log::camino::Utf8Path; use dozer_core::{app::AppPipeline, dag_schemas::DagSchemas, Dag}; -use dozer_sql::pipeline::builder::statement_to_pipeline; +use dozer_sql::builder::statement_to_pipeline; use dozer_tracing::{Labels, LabelsAndProgress}; use dozer_types::{ constants::DEFAULT_DEFAULT_MAX_NUM_RECORDS, diff --git a/dozer-cli/src/pipeline/builder.rs b/dozer-cli/src/pipeline/builder.rs index 50a9e31e03..ce393ae261 100644 --- a/dozer-cli/src/pipeline/builder.rs +++ b/dozer-cli/src/pipeline/builder.rs @@ -9,8 +9,8 @@ use dozer_core::app::PipelineEntryPoint; use dozer_core::node::SinkFactory; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_ingestion::connectors::{get_connector, get_connector_info_table}; -use dozer_sql::pipeline::builder::statement_to_pipeline; -use dozer_sql::pipeline::builder::{OutputNodeInfo, QueryContext}; +use dozer_sql::builder::statement_to_pipeline; +use dozer_sql::builder::{OutputNodeInfo, QueryContext}; use dozer_tracing::LabelsAndProgress; use dozer_types::log::debug; use dozer_types::models::api_endpoint::ApiEndpoint; diff --git a/dozer-cli/src/simple/orchestrator.rs b/dozer-cli/src/simple/orchestrator.rs index 884b7d75ee..2c1d2dde99 100644 --- a/dozer-cli/src/simple/orchestrator.rs +++ b/dozer-cli/src/simple/orchestrator.rs @@ -31,8 +31,8 @@ use crate::console_helper::PURPLE; use crate::console_helper::RED; use dozer_core::errors::ExecutionError; use dozer_ingestion::connectors::{get_connector, SourceSchema, TableInfo}; -use dozer_sql::pipeline::builder::statement_to_pipeline; -use dozer_sql::pipeline::errors::PipelineError; +use dozer_sql::builder::statement_to_pipeline; +use dozer_sql::errors::PipelineError; use dozer_types::log::info; use dozer_types::models::config::Config; use dozer_types::tracing::error; diff --git a/dozer-ingestion/tests/test_suite/connectors/sql.rs b/dozer-ingestion/tests/test_suite/connectors/sql.rs index acd23e5cc6..e1105bed63 100644 --- a/dozer-ingestion/tests/test_suite/connectors/sql.rs +++ b/dozer-ingestion/tests/test_suite/connectors/sql.rs @@ -237,7 +237,7 @@ fn field_to_sql(field: &Field) -> String { Field::Date(d) => format!("'{}'", d), Field::Json(b) => format!("'{b}'::jsonb"), Field::Point(p) => format!("'({},{})'", p.0.x(), p.0.y()), - Field::Duration(d) => d.to_string(), + Field::Duration(_) => field.to_string(), Field::Null => "NULL".to_string(), } } diff --git a/dozer-log-python/src/mapper.rs b/dozer-log-python/src/mapper.rs index 0abe5721f8..49709a74e1 100644 --- a/dozer-log-python/src/mapper.rs +++ b/dozer-log-python/src/mapper.rs @@ -79,7 +79,7 @@ fn map_value(value: Field, py: Python) -> PyResult> { Field::Date(v) => Ok(v.to_string().to_object(py)), Field::Json(v) => map_json_py(v, py), Field::Point(v) => map_point(v, py), - Field::Duration(v) => Ok(v.to_string().to_object(py)), + Field::Duration(_) => Ok(value.to_string().to_object(py)), Field::Null => Ok(py.None()), } } diff --git a/dozer-sql/Cargo.toml b/dozer-sql/Cargo.toml index 883c16e0fb..82248dd3f1 100644 --- a/dozer-sql/Cargo.toml +++ b/dozer-sql/Cargo.toml @@ -11,25 +11,15 @@ dozer-types = { path = "../dozer-types" } dozer-storage = { path = "../dozer-storage" } dozer-core = { path = "../dozer-core" } dozer-tracing = { path = 
"../dozer-tracing" } -jsonpath = { path = "jsonpath" } +dozer-sql-expression = { path = "expression" } ahash = "0.8.3" enum_dispatch = "0.3.11" -jsonpath-rust = "0.3.1" -like = "0.3.1" linked-hash-map = { version = "0.5.6", features = ["serde_impl"] } metrics = "0.21.0" multimap = "0.8.3" -num-traits = "0.2.15" -pest = "2.6.0" -pest_derive = "2.5.6" regex = "1.8.1" -sqlparser = { git = "https://github.com/getdozer/sqlparser-rs.git" } uuid = { version = "1.3.0", features = ["v1", "v4", "fast-rng"] } -bigdecimal = { version = "0.3", features = ["serde"], optional = true } -ort = { version = "1.15.2", optional = true } -ndarray = { version = "0.15", optional = true } -half = { version = "2.3.1", optional = true } [dev-dependencies] tempdir = "0.3.7" @@ -37,6 +27,5 @@ proptest = "1.2.0" tokio = { version = "1", features = ["rt", "macros"] } [features] -python = ["dozer-types/python-auto-initialize"] -bigdecimal = ["dep:bigdecimal", "sqlparser/bigdecimal"] -onnx = ["dep:ort", "dep:ndarray", "dep:half"] +python = ["dozer-sql-expression/python"] +onnx = ["dozer-sql-expression/onnx"] diff --git a/dozer-sql/expression/Cargo.toml b/dozer-sql/expression/Cargo.toml new file mode 100644 index 0000000000..16038f8386 --- /dev/null +++ b/dozer-sql/expression/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "dozer-sql-expression" +version = "0.1.39" +edition = "2021" +authors = ["getdozer/dozer-dev"] + +[dependencies] +dozer-types = { path = "../../dozer-types" } +num-traits = "0.2.16" +sqlparser = { git = "https://github.com/getdozer/sqlparser-rs.git" } +bigdecimal = { version = "0.3", features = ["serde"], optional = true } +ort = { version = "1.15.2", optional = true } +ndarray = { version = "0.15", optional = true } +half = { version = "2.3.1", optional = true } +like = "0.3.1" +jsonpath = { path = "../jsonpath" } + +[dev-dependencies] +proptest = "1.2.0" + +[features] +bigdecimal = ["dep:bigdecimal", "sqlparser/bigdecimal"] +python = ["dozer-types/python-auto-initialize"] +onnx = ["dep:ort", "dep:ndarray", "dep:half"] diff --git a/dozer-sql/expression/proptest-regressions/comparison.txt b/dozer-sql/expression/proptest-regressions/comparison.txt new file mode 100644 index 0000000000..a7f99974df --- /dev/null +++ b/dozer-sql/expression/proptest-regressions/comparison.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc dc65a968a73d59d750e7242dccff455c9e35c2de94328e74635aafea06c80916 # shrinks to u_num1 = 0, u_num2 = 1, i_num1 = -1, i_num2 = 0, f_num1 = 8.565445625875324e-309, f_num2 = 0.0, d_num1 = ArbitraryDecimal(-1), d_num2 = ArbitraryDecimal(0) diff --git a/dozer-sql/expression/proptest-regressions/datetime.txt b/dozer-sql/expression/proptest-regressions/datetime.txt new file mode 100644 index 0000000000..d88d309e16 --- /dev/null +++ b/dozer-sql/expression/proptest-regressions/datetime.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc be943bc26443acef7454b3e70de97c2a02d7557a4e71f904e0724f6db1988b41 # shrinks to datetime = ArbitraryDateTime(0000-01-01T00:00:00+08:00) diff --git a/dozer-sql/expression/proptest-regressions/geo/point.txt b/dozer-sql/expression/proptest-regressions/geo/point.txt new file mode 100644 index 0000000000..dc6eea99fb --- /dev/null +++ b/dozer-sql/expression/proptest-regressions/geo/point.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7c7362809a16e126e93c162f05e1835a1daf3f3b78a9f1750cd53c7a9105e09d # shrinks to x = 0, y = 0 diff --git a/dozer-sql/expression/proptest-regressions/logical.txt b/dozer-sql/expression/proptest-regressions/logical.txt new file mode 100644 index 0000000000..6513b0c3f8 --- /dev/null +++ b/dozer-sql/expression/proptest-regressions/logical.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 095bb1b25ce78cad75e07af1530f4a0fe57497b474a9d6c09e4dd84a611f1ca9 # shrinks to bool1 = false, bool2 = false, u_num = 0, i_num = 0, f_num = 0.0, str = "" diff --git a/dozer-sql/src/pipeline/expression/aggregate.rs b/dozer-sql/expression/src/aggregate.rs similarity index 57% rename from dozer-sql/src/pipeline/expression/aggregate.rs rename to dozer-sql/expression/src/aggregate.rs index 4abea6e3cb..d136358979 100644 --- a/dozer-sql/src/pipeline/expression/aggregate.rs +++ b/dozer-sql/expression/src/aggregate.rs @@ -1,5 +1,3 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::PipelineError::InvalidFunction; use std::fmt::{Display, Formatter}; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -14,16 +12,16 @@ pub enum AggregateFunctionType { } impl AggregateFunctionType { - pub(crate) fn new(name: &str) -> Result { + pub(crate) fn new(name: &str) -> Option { match name { - "avg" => Ok(AggregateFunctionType::Avg), - "count" => Ok(AggregateFunctionType::Count), - "max" => Ok(AggregateFunctionType::Max), - "max_value" => Ok(AggregateFunctionType::MaxValue), - "min" => Ok(AggregateFunctionType::Min), - "min_value" => Ok(AggregateFunctionType::MinValue), - "sum" => Ok(AggregateFunctionType::Sum), - _ => Err(InvalidFunction(name.to_string())), + "avg" => Some(AggregateFunctionType::Avg), + "count" => Some(AggregateFunctionType::Count), + "max" => Some(AggregateFunctionType::Max), + "max_value" => Some(AggregateFunctionType::MaxValue), + "min" => Some(AggregateFunctionType::Min), + "min_value" => Some(AggregateFunctionType::MinValue), + "sum" => Some(AggregateFunctionType::Sum), + _ => None, } } } diff --git a/dozer-sql/expression/src/arg_utils.rs b/dozer-sql/expression/src/arg_utils.rs new file mode 100644 index 0000000000..5c625b91ed --- /dev/null +++ b/dozer-sql/expression/src/arg_utils.rs @@ -0,0 +1,127 @@ +use std::fmt::Display; +use std::ops::Range; + +use crate::error::Error; +use crate::execution::{Expression, ExpressionType}; +use dozer_types::chrono::{DateTime, FixedOffset}; +use dozer_types::types::{DozerPoint, Field, FieldType, Schema}; + +pub fn validate_one_argument( + args: &[Expression], + schema: &Schema, 
+ function_name: impl Display, +) -> Result { + validate_num_arguments(1..2, args.len(), function_name)?; + args[0].get_type(schema) +} + +pub fn validate_two_arguments( + args: &[Expression], + schema: &Schema, + function_name: impl Display, +) -> Result<(ExpressionType, ExpressionType), Error> { + validate_num_arguments(2..3, args.len(), function_name)?; + let arg1 = args[0].get_type(schema)?; + let arg2 = args[1].get_type(schema)?; + Ok((arg1, arg2)) +} + +pub fn validate_num_arguments( + expected: Range, + actual: usize, + function_name: impl Display, +) -> Result<(), Error> { + if !expected.contains(&actual) { + Err(Error::InvalidNumberOfArguments { + function_name: function_name.to_string(), + expected, + actual, + }) + } else { + Ok(()) + } +} + +pub fn validate_arg_type( + arg: &Expression, + expected: Vec, + schema: &Schema, + function_name: impl Display, + argument_index: usize, +) -> Result { + let arg_t = arg.get_type(schema)?; + if !expected.contains(&arg_t.return_type) { + Err(Error::InvalidFunctionArgumentType { + function_name: function_name.to_string(), + argument_index, + actual: arg_t.return_type, + expected, + }) + } else { + Ok(arg_t) + } +} + +pub fn extract_uint( + field: Field, + function_name: impl Display, + argument_index: usize, +) -> Result { + if let Some(value) = field.to_uint() { + Ok(value) + } else { + Err(Error::InvalidFunctionArgument { + function_name: function_name.to_string(), + argument_index, + argument: field, + }) + } +} + +pub fn extract_float( + field: Field, + function_name: impl Display, + argument_index: usize, +) -> Result { + if let Some(value) = field.to_float() { + Ok(value) + } else { + Err(Error::InvalidFunctionArgument { + function_name: function_name.to_string(), + argument_index, + argument: field, + }) + } +} + +pub fn extract_point( + field: Field, + function_name: impl Display, + argument_index: usize, +) -> Result { + if let Some(value) = field.to_point() { + Ok(value) + } else { + Err(Error::InvalidFunctionArgument { + function_name: function_name.to_string(), + argument_index, + argument: field, + }) + } +} + +pub fn extract_timestamp( + field: Field, + function_name: impl Display, + argument_index: usize, +) -> Result, Error> { + if let Some(value) = field.to_timestamp() { + Ok(value) + } else { + Err(Error::InvalidFunctionArgument { + function_name: function_name.to_string(), + argument_index, + argument: field, + }) + } +} diff --git a/dozer-sql/src/pipeline/expression/builder.rs b/dozer-sql/expression/src/builder.rs similarity index 67% rename from dozer-sql/src/pipeline/expression/builder.rs rename to dozer-sql/expression/src/builder.rs index dbf53a5976..d6c3c4c062 100644 --- a/dozer-sql/src/pipeline/expression/builder.rs +++ b/dozer-sql/expression/src/builder.rs @@ -1,12 +1,9 @@ -use crate::pipeline::errors::PipelineError::{ - InvalidArgument, InvalidExpression, InvalidFunction, InvalidNestedAggregationFunction, - InvalidOperator, InvalidValue, -}; -use crate::pipeline::errors::{PipelineError, SqlError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::conditional::ConditionalExpressionType; -use crate::pipeline::expression::datetime::DateTimeFunctionType; +use crate::aggregate::AggregateFunctionType; +use crate::conditional::ConditionalExpressionType; +use crate::datetime::DateTimeFunctionType; +use crate::error::Error; use dozer_types::models::udf_config::{UdfConfig, UdfType}; +use dozer_types::types::FieldType; use dozer_types::{ ordered_float::OrderedFloat, 
types::{Field, FieldDefinition, Schema, SourceDefinition}, @@ -17,27 +14,16 @@ use sqlparser::ast::{ UnaryOperator as SqlUnaryOperator, Value as SqlValue, }; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::execution::Expression::{ - ConditionalExpression, GeoFunction, Now, ScalarFunction, -}; -use crate::pipeline::expression::geo::common::GeoFunctionType; -use crate::pipeline::expression::json_functions::JsonFunctionType; -use crate::pipeline::expression::operator::{BinaryOperatorType, UnaryOperatorType}; -use crate::pipeline::expression::scalar::common::ScalarFunctionType; -use crate::pipeline::expression::scalar::string::TrimType; +use crate::execution::Expression; +use crate::execution::Expression::{ConditionalExpression, GeoFunction, Now, ScalarFunction}; +use crate::geo::common::GeoFunctionType; +use crate::json_functions::JsonFunctionType; +use crate::operator::{BinaryOperatorType, UnaryOperatorType}; +use crate::scalar::common::ScalarFunctionType; +use crate::scalar::string::TrimType; use super::cast::CastOperatorType; -#[cfg(feature = "onnx")] -use crate::pipeline::errors::PipelineError::OnnxError; -#[cfg(feature = "onnx")] -use crate::pipeline::onnx::DozerSession; -#[cfg(feature = "onnx")] -use crate::pipeline::onnx::OnnxError::OnnxOrtErr; -#[cfg(feature = "onnx")] -use dozer_types::models::udf_config::OnnxConfig; - #[derive(Clone, PartialEq, Debug)] pub struct ExpressionBuilder { // Must be an aggregation function @@ -66,17 +52,17 @@ impl ExpressionBuilder { sql_expression: &SqlExpr, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { self.parse_sql_expression(parse_aggregations, sql_expression, schema, udfs) } - pub(crate) fn parse_sql_expression( + pub fn parse_sql_expression( &mut self, parse_aggregations: bool, expression: &SqlExpr, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { match expression { SqlExpr::Trim { expr, @@ -169,11 +155,11 @@ impl ExpressionBuilder { schema, udfs, ), - _ => Err(InvalidExpression(format!("{expression:?}"))), + _ => Err(Error::UnsupportedExpression(expression.clone())), } } - fn parse_sql_column(ident: &[Ident], schema: &Schema) -> Result { + fn parse_sql_column(ident: &[Ident], schema: &Schema) -> Result { let (src_field, src_table_or_alias, src_connection) = match ident.len() { 1 => (&ident[0].value, None, None), 2 => (&ident[1].value, Some(&ident[0].value), None), @@ -183,13 +169,7 @@ impl ExpressionBuilder { Some(&ident[0].value), ), _ => { - return Err(PipelineError::SqlError(SqlError::InvalidColumn( - ident - .iter() - .map(|e| e.value.as_str()) - .collect::>() - .join("."), - ))); + return Err(Error::InvalidIdent(ident.to_vec())); } }; @@ -205,13 +185,7 @@ impl ExpressionBuilder { index: matching_by_field[0].0, }), _ => match src_table_or_alias { - None => Err(PipelineError::SqlError(SqlError::InvalidColumn( - ident - .iter() - .map(|e| e.value.as_str()) - .collect::>() - .join("."), - ))), + None => Err(Error::InvalidIdent(ident.to_vec())), Some(src_table_or_alias) => { let matching_by_table_or_alias: Vec<(usize, &FieldDefinition)> = matching_by_field @@ -231,11 +205,7 @@ impl ExpressionBuilder { index: matching_by_table_or_alias[0].0, }), _ => match src_connection { - None => Err(PipelineError::SqlError(SqlError::InvalidColumn( - ident - .iter() - .fold(String::new(), |a, b| a + "." 
+ b.value.as_str()), - ))), + None => Err(Error::InvalidIdent(ident.to_vec())), Some(src_connection) => { let matching_by_connection: Vec<(usize, &FieldDefinition)> = matching_by_table_or_alias @@ -253,13 +223,7 @@ impl ExpressionBuilder { 1 => Ok(Expression::Column { index: matching_by_connection[0].0, }), - _ => Err(PipelineError::SqlError(SqlError::InvalidColumn( - ident - .iter() - .map(|e| e.value.as_str()) - .collect::>() - .join("."), - ))), + _ => Err(Error::InvalidIdent(ident.to_vec())), } } }, @@ -277,7 +241,7 @@ impl ExpressionBuilder { trim_what: &Option>, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let arg = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?); let what = match trim_what { Some(e) => Some(Box::new(self.parse_sql_expression( @@ -303,40 +267,37 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { - match ( - AggregateFunctionType::new(function_name.as_str()), - parse_aggregations, - ) { - (Ok(aggr), true) => { - let mut arg_expr: Vec = Vec::new(); - for arg in &sql_function.args { - let aggregation = self.parse_sql_function_arg(true, arg, schema, udfs)?; - arg_expr.push(aggregation); - } - let measure = Expression::AggregateFunction { - fun: aggr, - args: arg_expr, - }; - let index = match self - .aggregations - .iter() - .enumerate() - .find(|e| e.1 == &measure) - { - Some((index, _existing)) => index, - _ => { - self.aggregations.push(measure); - self.aggregations.len() - 1 - } - }; - Ok(Expression::Column { - index: self.offset + index, - }) - } - (Ok(_agg), false) => Err(InvalidNestedAggregationFunction(function_name)), - (Err(_), _) => Err(InvalidNestedAggregationFunction(function_name)), + ) -> Option { + if !parse_aggregations { + return None; } + + let aggr = AggregateFunctionType::new(function_name.as_str())?; + + let mut arg_expr: Vec = Vec::new(); + for arg in &sql_function.args { + let aggregation = self.parse_sql_function_arg(true, arg, schema, udfs).ok()?; + arg_expr.push(aggregation); + } + let measure = Expression::AggregateFunction { + fun: aggr, + args: arg_expr, + }; + let index = match self + .aggregations + .iter() + .enumerate() + .find(|e| e.1 == &measure) + { + Some((index, _existing)) => index, + _ => { + self.aggregations.push(measure); + self.aggregations.len() - 1 + } + }; + Some(Expression::Column { + index: self.offset + index, + }) } fn scalar_function_check( @@ -346,24 +307,20 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Option { let mut function_args: Vec = Vec::new(); for arg in &sql_function.args { - function_args.push(self.parse_sql_function_arg( - parse_aggregations, - arg, - schema, - udfs, - )?); + function_args.push( + self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .ok()?, + ); } - match ScalarFunctionType::new(function_name.as_str()) { - Ok(sft) => Ok(ScalarFunction { - fun: sft, - args: function_args.clone(), - }), - Err(_d) => Err(InvalidFunction(function_name)), - } + let sft = ScalarFunctionType::new(function_name.as_str())?; + Some(ScalarFunction { + fun: sft, + args: function_args, + }) } fn geo_expr_check( @@ -373,31 +330,25 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Option { let mut function_args: Vec = Vec::new(); for arg in &sql_function.args { - function_args.push(self.parse_sql_function_arg( - parse_aggregations, - arg, - schema, - udfs, - )?); + 
function_args.push( + self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .ok()?, + ); } - match GeoFunctionType::new(function_name.as_str()) { - Ok(gft) => Ok(GeoFunction { - fun: gft, - args: function_args.clone(), - }), - Err(_e) => Err(InvalidFunction(function_name)), - } + let gft = GeoFunctionType::new(function_name.as_str())?; + Some(GeoFunction { + fun: gft, + args: function_args, + }) } - fn datetime_expr_check(&mut self, function_name: String) -> Result { - match DateTimeFunctionType::new(function_name.as_str()) { - Ok(dtf) => Ok(Now { fun: dtf }), - Err(_e) => Err(InvalidFunction(function_name)), - } + fn datetime_expr_check(&mut self, function_name: String) -> Option { + let dtf = DateTimeFunctionType::new(function_name.as_str())?; + Some(Now { fun: dtf }) } fn json_func_check( @@ -407,24 +358,20 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Option { let mut function_args: Vec = Vec::new(); for arg in &sql_function.args { - function_args.push(self.parse_sql_function_arg( - parse_aggregations, - arg, - schema, - udfs, - )?); + function_args.push( + self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .ok()?, + ); } - match JsonFunctionType::new(function_name.as_str()) { - Ok(jft) => Ok(Expression::Json { - fun: jft, - args: function_args, - }), - Err(_e) => Err(InvalidFunction(function_name)), - } + let jft = JsonFunctionType::new(function_name.as_str())?; + Some(Expression::Json { + fun: jft, + args: function_args, + }) } fn conditional_expr_check( @@ -434,24 +381,20 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Option { let mut function_args: Vec = Vec::new(); for arg in &sql_function.args { - function_args.push(self.parse_sql_function_arg( - parse_aggregations, - arg, - schema, - udfs, - )?); + function_args.push( + self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .ok()?, + ); } - match ConditionalExpressionType::new(function_name.as_str()) { - Ok(cet) => Ok(ConditionalExpression { - fun: cet, - args: function_args.clone(), - }), - Err(_err) => Err(InvalidFunction(function_name)), - } + let cet = ConditionalExpressionType::new(function_name.as_str())?; + Some(ConditionalExpression { + fun: cet, + args: function_args, + }) } fn parse_sql_function( @@ -460,7 +403,7 @@ impl ExpressionBuilder { sql_function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let function_name = sql_function.name.to_string().to_lowercase(); #[cfg(feature = "python")] @@ -470,64 +413,58 @@ impl ExpressionBuilder { return self.parse_python_udf(udf_name, sql_function, schema, udfs); } - let aggr_check = self.aggr_function_check( + if let Some(aggr_check) = self.aggr_function_check( function_name.clone(), parse_aggregations, sql_function, schema, udfs, - ); - if aggr_check.is_ok() { - return aggr_check; + ) { + return Ok(aggr_check); } - let scalar_check = self.scalar_function_check( + if let Some(scalar_check) = self.scalar_function_check( function_name.clone(), parse_aggregations, sql_function, schema, udfs, - ); - if scalar_check.is_ok() { - return scalar_check; + ) { + return Ok(scalar_check); } - let geo_check = self.geo_expr_check( + if let Some(geo_check) = self.geo_expr_check( function_name.clone(), parse_aggregations, sql_function, schema, udfs, - ); - if geo_check.is_ok() { - return geo_check; + ) { + return Ok(geo_check); } - let conditional_check = self.conditional_expr_check( 
+ if let Some(conditional_check) = self.conditional_expr_check( function_name.clone(), parse_aggregations, sql_function, schema, udfs, - ); - if conditional_check.is_ok() { - return conditional_check; + ) { + return Ok(conditional_check); } - let datetime_check = self.datetime_expr_check(function_name.clone()); - if datetime_check.is_ok() { - return datetime_check; + if let Some(datetime_check) = self.datetime_expr_check(function_name.clone()) { + return Ok(datetime_check); } - let json_check = self.json_func_check( + if let Some(json_check) = self.json_func_check( function_name.clone(), parse_aggregations, sql_function, schema, udfs, - ); - if json_check.is_ok() { - return json_check; + ) { + return Ok(json_check); } // config check for udfs @@ -549,14 +486,14 @@ impl ExpressionBuilder { #[cfg(not(feature = "onnx"))] { let _ = config; - Err(PipelineError::OnnxNotEnabled) + Err(Error::OnnxNotEnabled) } } - None => Err(PipelineError::UdfConfigMissing(function_name.clone())), + None => Err(Error::UdfConfigMissing(function_name.clone())), }; } - Err(PipelineError::UnknownFunction(function_name.clone())) + Err(Error::UnknownFunction(function_name.clone())) } fn parse_sql_function_arg( @@ -565,7 +502,7 @@ impl ExpressionBuilder { argument: &FunctionArg, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { match argument { FunctionArg::Named { name: _, @@ -579,13 +516,7 @@ impl ExpressionBuilder { self.parse_sql_expression(parse_aggregations, arg, schema, udfs) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expression::Literal(Field::Null)), - FunctionArg::Named { - name: _, - arg: FunctionArgExpr::QualifiedWildcard(_), - } => Err(InvalidArgument(format!("{argument:?}"))), - FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(_)) => { - Err(InvalidArgument(format!("{argument:?}"))) - } + _ => Err(Error::UnsupportedFunctionArg(argument.clone())), } } @@ -599,7 +530,7 @@ impl ExpressionBuilder { else_result: &Option>, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let op = match operand { Some(o) => Some(Box::new(self.parse_sql_expression( parse_aggregations, @@ -612,11 +543,11 @@ impl ExpressionBuilder { let conds = conditions .iter() .map(|cond| self.parse_sql_expression(parse_aggregations, cond, schema, udfs)) - .collect::, PipelineError>>()?; + .collect::, Error>>()?; let res = results .iter() .map(|r| self.parse_sql_expression(parse_aggregations, r, schema, udfs)) - .collect::, PipelineError>>()?; + .collect::, Error>>()?; let else_res = match else_result { Some(r) => Some(Box::new(self.parse_sql_expression( parse_aggregations, @@ -642,17 +573,17 @@ impl ExpressionBuilder { leading_field: &Option, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let right = self.parse_sql_expression(parse_aggregations, value, schema, udfs)?; - if leading_field.is_some() { + if let Some(leading_field) = leading_field { Ok(Expression::DateTimeFunction { fun: DateTimeFunctionType::Interval { - field: leading_field.unwrap(), + field: *leading_field, }, arg: Box::new(right), }) } else { - Err(InvalidExpression(format!("INTERVAL for {leading_field:?}"))) + Err(Error::MissingLeadingFieldInInterval) } } @@ -663,13 +594,13 @@ impl ExpressionBuilder { expr: &SqlExpr, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let arg = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?); let operator = match op { SqlUnaryOperator::Not => UnaryOperatorType::Not, SqlUnaryOperator::Plus => UnaryOperatorType::Plus, 
SqlUnaryOperator::Minus => UnaryOperatorType::Minus, - _ => return Err(InvalidOperator(format!("{op:?}"))), + _ => return Err(Error::UnsupportedUnaryOperator(*op)), }; Ok(Expression::UnaryOperator { operator, arg }) @@ -683,7 +614,7 @@ impl ExpressionBuilder { right: &SqlExpr, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let left_op = self.parse_sql_expression(parse_aggregations, left, schema, udfs)?; let right_op = self.parse_sql_expression(parse_aggregations, right, schema, udfs)?; @@ -701,7 +632,7 @@ impl ExpressionBuilder { SqlBinaryOperator::Modulo => BinaryOperatorType::Mod, SqlBinaryOperator::And => BinaryOperatorType::And, SqlBinaryOperator::Or => BinaryOperatorType::Or, - _ => return Err(InvalidOperator(format!("{op:?}"))), + _ => return Err(Error::UnsupportedBinaryOperator(op.clone())), }; Ok(Expression::BinaryOperator { @@ -712,25 +643,25 @@ impl ExpressionBuilder { } #[cfg(not(feature = "bigdecimal"))] - fn parse_sql_number(n: &str) -> Result { + fn parse_sql_number(n: &str) -> Result { match n.parse::() { Ok(n) => Ok(Expression::Literal(Field::Int(n))), Err(_) => match n.parse::() { Ok(f) => Ok(Expression::Literal(Field::Float(OrderedFloat(f)))), - Err(_) => Err(InvalidValue(n.to_string())), + Err(_) => Err(Error::NotANumber(n.to_string())), }, } } #[cfg(feature = "bigdecimal")] - fn parse_sql_number(n: &bigdecimal::BigDecimal) -> Result { + fn parse_sql_number(n: &bigdecimal::BigDecimal) -> Result { use bigdecimal::ToPrimitive; if n.is_integer() { Ok(Expression::Literal(Field::Int(n.to_i64().unwrap()))) } else { match n.to_f64() { Some(f) => Ok(Expression::Literal(Field::Float(OrderedFloat(f)))), - None => Err(InvalidValue(n.to_string())), + None => Err(Error::NotANumber(n.to_string())), } } } @@ -745,7 +676,7 @@ impl ExpressionBuilder { escape_char: &Option, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let arg = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; let pattern = self.parse_sql_expression(parse_aggregations, pattern, schema, udfs)?; let like_expression = Expression::Like { @@ -770,7 +701,7 @@ impl ExpressionBuilder { expr: &Expr, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let right = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; Ok(Expression::DateTimeFunction { fun: DateTimeFunctionType::Extract { field: *field }, @@ -785,38 +716,34 @@ impl ExpressionBuilder { data_type: &DataType, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let expression = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; let cast_to = match data_type { - DataType::Decimal(_) => CastOperatorType::Decimal, - DataType::Binary(_) => CastOperatorType::Binary, - DataType::Float(_) => CastOperatorType::Float, - DataType::Int(_) => CastOperatorType::Int, - DataType::Integer(_) => CastOperatorType::Int, - DataType::UnsignedInt(_) => CastOperatorType::UInt, - DataType::UnsignedInteger(_) => CastOperatorType::UInt, - DataType::Boolean => CastOperatorType::Boolean, - DataType::Date => CastOperatorType::Date, - DataType::Timestamp(..) 
=> CastOperatorType::Timestamp, - DataType::Text => CastOperatorType::Text, - DataType::String => CastOperatorType::String, - DataType::JSON => CastOperatorType::Json, + DataType::Decimal(_) => CastOperatorType(FieldType::Decimal), + DataType::Binary(_) => CastOperatorType(FieldType::Binary), + DataType::Float(_) => CastOperatorType(FieldType::Float), + DataType::Int(_) => CastOperatorType(FieldType::Int), + DataType::Integer(_) => CastOperatorType(FieldType::Int), + DataType::UnsignedInt(_) => CastOperatorType(FieldType::UInt), + DataType::UnsignedInteger(_) => CastOperatorType(FieldType::UInt), + DataType::Boolean => CastOperatorType(FieldType::Boolean), + DataType::Date => CastOperatorType(FieldType::Date), + DataType::Timestamp(..) => CastOperatorType(FieldType::Timestamp), + DataType::Text => CastOperatorType(FieldType::Text), + DataType::String => CastOperatorType(FieldType::String), + DataType::JSON => CastOperatorType(FieldType::Json), DataType::Custom(name, ..) => { if name.to_string().to_lowercase() == "uint" { - CastOperatorType::UInt + CastOperatorType(FieldType::UInt) } else if name.to_string().to_lowercase() == "u128" { - CastOperatorType::U128 + CastOperatorType(FieldType::U128) } else if name.to_string().to_lowercase() == "i128" { - CastOperatorType::I128 + CastOperatorType(FieldType::I128) } else { - Err(PipelineError::InvalidFunction(format!( - "Unsupported Cast type {name}" - )))? + return Err(Error::UnsupportedDataType(data_type.clone())); } } - _ => Err(PipelineError::InvalidFunction(format!( - "Unsupported Cast type {data_type}" - )))?, + _ => Err(Error::UnsupportedDataType(data_type.clone()))?, }; Ok(Expression::Cast { arg: Box::new(expression), @@ -824,7 +751,7 @@ impl ExpressionBuilder { }) } - fn parse_sql_string(s: &str) -> Result { + fn parse_sql_string(s: &str) -> Result { Ok(Expression::Literal(Field::String(s.to_owned()))) } @@ -836,7 +763,7 @@ impl ExpressionBuilder { ident_tokens.join(".") } - pub(crate) fn normalize_ident(id: &Ident) -> String { + pub fn normalize_ident(id: &Ident) -> String { match id.quote_style { Some(_) => id.value.clone(), None => id.value.clone(), @@ -850,26 +777,24 @@ impl ExpressionBuilder { function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { + use crate::python_udf::Error::{FailedToParseReturnType, MissingReturnType}; + // First, get python function define by name. // Then, transfer python function to Expression::PythonUDF - use dozer_types::types::FieldType; - use PipelineError::InvalidQuery; - let args = function .args .iter() .map(|argument| self.parse_sql_function_arg(false, argument, schema, udfs)) - .collect::, PipelineError>>()?; + .collect::, Error>>()?; let return_type = { let ident = function .return_type .as_ref() - .ok_or_else(|| InvalidQuery("Python UDF must have a return type. The syntax is: function_name(arguments)".to_string()))?; + .ok_or_else(|| MissingReturnType)?; - FieldType::try_from(ident.value.as_str()) - .map_err(|e| InvalidQuery(format!("Failed to parse Python UDF return type: {e}")))? + FieldType::try_from(ident.value.as_str()).map_err(FailedToParseReturnType)? }; Ok(Expression::PythonUDF { @@ -883,16 +808,17 @@ impl ExpressionBuilder { fn parse_onnx_udf( &mut self, name: String, - config: &OnnxConfig, + config: &dozer_types::models::udf_config::OnnxConfig, function: &Function, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { + use crate::error::Error::Onnx; + use crate::onnx::error::Error::OnnxOrtErr; + // First, get onnx function define by name. 
// Then, transfer onnx function to Expression::OnnxUDF - use crate::pipeline::expression::onnx::onnx_utils::{ - onnx_input_validation, onnx_output_validation, - }; + use crate::onnx::utils::{onnx_input_validation, onnx_output_validation}; use ort::{Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder}; use std::path::Path; @@ -900,23 +826,23 @@ impl ExpressionBuilder { .args .iter() .map(|argument| self.parse_sql_function_arg(false, argument, schema, udfs)) - .collect::, PipelineError>>()?; + .collect::, Error>>()?; let environment = Environment::builder() .with_name("dozer_onnx") .with_log_level(LoggingLevel::Verbose) .build() - .map_err(|e| OnnxError(OnnxOrtErr(e)))? + .map_err(|e| Onnx(OnnxOrtErr(e)))? .into_arc(); let session = SessionBuilder::new(&environment) - .map_err(|e| OnnxError(OnnxOrtErr(e)))? + .map_err(|e| Onnx(OnnxOrtErr(e)))? .with_optimization_level(GraphOptimizationLevel::Level1) - .map_err(|e| OnnxError(OnnxOrtErr(e)))? + .map_err(|e| Onnx(OnnxOrtErr(e)))? .with_intra_threads(1) - .map_err(|e| OnnxError(OnnxOrtErr(e)))? + .map_err(|e| Onnx(OnnxOrtErr(e)))? .with_model_from_file(Path::new(config.path.as_str())) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; // input number, type, shape validation onnx_input_validation(schema, &args, &session.inputs)?; @@ -925,7 +851,7 @@ impl ExpressionBuilder { Ok(Expression::OnnxUDF { name, - session: DozerSession(session.into()), + session: crate::onnx::DozerSession(session.into()), args, }) } @@ -938,12 +864,12 @@ impl ExpressionBuilder { negated: bool, schema: &Schema, udfs: &[UdfConfig], - ) -> Result { + ) -> Result { let expr = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; let list = list .iter() .map(|expr| self.parse_sql_expression(parse_aggregations, expr, schema, udfs)) - .collect::, PipelineError>>()?; + .collect::, Error>>()?; let in_list_expression = Expression::InList { expr: Box::new(expr), list, diff --git a/dozer-sql/src/pipeline/expression/case.rs b/dozer-sql/expression/src/case.rs similarity index 86% rename from dozer-sql/src/pipeline/expression/case.rs rename to dozer-sql/expression/src/case.rs index ed867bc297..24adcf82c4 100644 --- a/dozer-sql/src/pipeline/expression/case.rs +++ b/dozer-sql/expression/src/case.rs @@ -1,9 +1,10 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::execution::Expression; use dozer_types::types::Record; use dozer_types::types::{Field, Schema}; use std::iter::zip; +use crate::error::Error; +use crate::execution::Expression; + pub fn evaluate_case( schema: &Schema, _operand: &Option>, @@ -11,7 +12,7 @@ pub fn evaluate_case( results: &Vec, else_result: &Option>, record: &Record, -) -> Result { +) -> Result { let iter = zip(conditions, results); for (cond, res) in iter { let field = cond.evaluate(record, schema)?; diff --git a/dozer-sql/expression/src/cast.rs b/dozer-sql/expression/src/cast.rs new file mode 100644 index 0000000000..9c242707b4 --- /dev/null +++ b/dozer-sql/expression/src/cast.rs @@ -0,0 +1,355 @@ +use std::fmt::{Display, Formatter}; + +use dozer_types::types::Record; +use dozer_types::{ + ordered_float::OrderedFloat, + types::{Field, FieldType, Schema}, +}; + +use crate::arg_utils::validate_arg_type; +use crate::error::Error; + +use super::execution::{Expression, ExpressionType}; + +#[allow(dead_code)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct CastOperatorType(pub FieldType); + +impl Display for CastOperatorType { + fn fmt(&self, f: &mut 
Formatter<'_>) -> std::fmt::Result { + match self.0 { + FieldType::UInt => f.write_str("CAST AS UINT"), + FieldType::U128 => f.write_str("CAST AS U128"), + FieldType::Int => f.write_str("CAST AS INT"), + FieldType::I128 => f.write_str("CAST AS I128"), + FieldType::Float => f.write_str("CAST AS FLOAT"), + FieldType::Boolean => f.write_str("CAST AS BOOLEAN"), + FieldType::String => f.write_str("CAST AS STRING"), + FieldType::Text => f.write_str("CAST AS TEXT"), + FieldType::Binary => f.write_str("CAST AS BINARY"), + FieldType::Decimal => f.write_str("CAST AS DECIMAL"), + FieldType::Timestamp => f.write_str("CAST AS TIMESTAMP"), + FieldType::Date => f.write_str("CAST AS DATE"), + FieldType::Json => f.write_str("CAST AS JSON"), + FieldType::Point => f.write_str("CAST AS POINT"), + FieldType::Duration => f.write_str("CAST AS DURATION"), + } + } +} + +impl CastOperatorType { + pub(crate) fn evaluate( + &self, + schema: &Schema, + arg: &Expression, + record: &Record, + ) -> Result { + let field = arg.evaluate(record, schema)?; + cast_field(&field, self.0) + } + + pub(crate) fn get_return_type( + &self, + schema: &Schema, + arg: &Expression, + ) -> Result { + let (expected_input_type, return_type) = match self.0 { + FieldType::UInt => ( + vec![ + FieldType::Int, + FieldType::String, + FieldType::UInt, + FieldType::I128, + FieldType::U128, + FieldType::Json, + ], + FieldType::UInt, + ), + FieldType::U128 => ( + vec![ + FieldType::Int, + FieldType::String, + FieldType::UInt, + FieldType::I128, + FieldType::U128, + FieldType::Json, + ], + FieldType::U128, + ), + FieldType::Int => ( + vec![ + FieldType::Int, + FieldType::String, + FieldType::UInt, + FieldType::I128, + FieldType::U128, + FieldType::Json, + ], + FieldType::Int, + ), + FieldType::I128 => ( + vec![ + FieldType::Int, + FieldType::String, + FieldType::UInt, + FieldType::I128, + FieldType::U128, + FieldType::Json, + ], + FieldType::I128, + ), + FieldType::Float => ( + vec![ + FieldType::Decimal, + FieldType::Float, + FieldType::Int, + FieldType::I128, + FieldType::String, + FieldType::UInt, + FieldType::U128, + FieldType::Json, + ], + FieldType::Float, + ), + FieldType::Boolean => ( + vec![ + FieldType::Boolean, + FieldType::Decimal, + FieldType::Float, + FieldType::Int, + FieldType::I128, + FieldType::UInt, + FieldType::U128, + FieldType::Json, + ], + FieldType::Boolean, + ), + FieldType::String => ( + vec![ + FieldType::Binary, + FieldType::Boolean, + FieldType::Date, + FieldType::Decimal, + FieldType::Float, + FieldType::Int, + FieldType::I128, + FieldType::String, + FieldType::Text, + FieldType::Timestamp, + FieldType::UInt, + FieldType::U128, + FieldType::Json, + ], + FieldType::String, + ), + FieldType::Text => ( + vec![ + FieldType::Binary, + FieldType::Boolean, + FieldType::Date, + FieldType::Decimal, + FieldType::Float, + FieldType::Int, + FieldType::I128, + FieldType::String, + FieldType::Text, + FieldType::Timestamp, + FieldType::UInt, + FieldType::U128, + FieldType::Json, + ], + FieldType::Text, + ), + FieldType::Binary => (vec![FieldType::Binary], FieldType::Binary), + FieldType::Decimal => ( + vec![ + FieldType::Decimal, + FieldType::Float, + FieldType::Int, + FieldType::I128, + FieldType::String, + FieldType::UInt, + FieldType::U128, + ], + FieldType::Decimal, + ), + FieldType::Timestamp => ( + vec![FieldType::String, FieldType::Timestamp], + FieldType::Timestamp, + ), + FieldType::Date => (vec![FieldType::Date, FieldType::String], FieldType::Date), + FieldType::Json => ( + vec![ + FieldType::Boolean, + FieldType::Float, + 
FieldType::Int, + FieldType::I128, + FieldType::String, + FieldType::Text, + FieldType::UInt, + FieldType::U128, + FieldType::Json, + ], + FieldType::Json, + ), + FieldType::Point => (vec![FieldType::Point], FieldType::Point), + FieldType::Duration => ( + vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Duration, + FieldType::String, + FieldType::Text, + ], + FieldType::Duration, + ), + }; + + let expression_type = validate_arg_type(arg, expected_input_type, schema, self, 0)?; + Ok(ExpressionType { + return_type, + nullable: expression_type.nullable, + source: expression_type.source, + is_primary_key: expression_type.is_primary_key, + }) + } +} + +pub fn cast_field(input: &Field, output_type: FieldType) -> Result { + match output_type { + FieldType::UInt => { + if let Some(value) = input.to_uint() { + Ok(Field::UInt(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::UInt, + }) + } + } + FieldType::U128 => { + if let Some(value) = input.to_u128() { + Ok(Field::U128(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::U128, + }) + } + } + FieldType::Int => { + if let Some(value) = input.to_int() { + Ok(Field::Int(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Int, + }) + } + } + FieldType::I128 => { + if let Some(value) = input.to_i128() { + Ok(Field::I128(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::I128, + }) + } + } + FieldType::Float => { + if let Some(value) = input.to_float() { + Ok(Field::Float(OrderedFloat(value))) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Float, + }) + } + } + FieldType::Boolean => { + if let Some(value) = input.to_boolean() { + Ok(Field::Boolean(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Boolean, + }) + } + } + FieldType::String => Ok(Field::String(input.to_string())), + FieldType::Text => Ok(Field::Text(input.to_text())), + FieldType::Binary => { + if let Some(value) = input.to_binary() { + Ok(Field::Binary(value.to_vec())) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Binary, + }) + } + } + FieldType::Decimal => { + if let Some(value) = input.to_decimal() { + Ok(Field::Decimal(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Decimal, + }) + } + } + FieldType::Timestamp => { + if let Some(value) = input.to_timestamp() { + Ok(Field::Timestamp(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Timestamp, + }) + } + } + FieldType::Date => { + if let Some(value) = input.to_date() { + Ok(Field::Date(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Date, + }) + } + } + FieldType::Json => { + if let Some(value) = input.to_json() { + Ok(Field::Json(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Json, + }) + } + } + FieldType::Point => { + if let Some(value) = input.to_point() { + Ok(Field::Point(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Point, + }) + } + } + FieldType::Duration => { + if let Some(value) = input.to_duration() { + Ok(Field::Duration(value)) + } else { + Err(Error::InvalidCast { + from: input.clone(), + to: FieldType::Duration, + }) + } + } + } +} diff --git a/dozer-sql/src/pipeline/expression/comparison.rs b/dozer-sql/expression/src/comparison/mod.rs similarity index 99% 
rename from dozer-sql/src/pipeline/expression/comparison.rs rename to dozer-sql/expression/src/comparison/mod.rs index 90c327daf8..4c7c17da08 100644 --- a/dozer-sql/src/pipeline/expression/comparison.rs +++ b/dozer-sql/expression/src/comparison/mod.rs @@ -1,14 +1,14 @@ -use crate::pipeline::errors::PipelineError; +use crate::error::Error as PipelineError; +use crate::execution::Expression; use dozer_types::chrono::{DateTime, NaiveDate}; use dozer_types::rust_decimal::Decimal; use dozer_types::types::Record; +use dozer_types::types::DATE_FORMAT; use dozer_types::types::{DozerDuration, DozerPoint, Field, Schema, TimeUnit}; use num_traits::cast::*; use std::str::FromStr; use std::time::Duration; -pub const DATE_FORMAT: &str = "%Y-%m-%d"; - macro_rules! define_comparison { ($id:ident, $op:expr, $function:expr) => { pub fn $id( @@ -625,8 +625,6 @@ macro_rules! define_comparison { }; } -use crate::pipeline::expression::execution::Expression; - pub fn evaluate_lt( schema: &Schema, left: &Expression, @@ -1718,3 +1716,6 @@ define_comparison!(evaluate_eq, "=", eq); define_comparison!(evaluate_ne, "!=", ne); define_comparison!(evaluate_lte, "<=", le); define_comparison!(evaluate_gte, ">=", ge); + +#[cfg(test)] +mod tests; diff --git a/dozer-sql/src/pipeline/expression/tests/comparison.rs b/dozer-sql/expression/src/comparison/tests.rs similarity index 81% rename from dozer-sql/src/pipeline/expression/tests/comparison.rs rename to dozer-sql/expression/src/comparison/tests.rs index 4c6026f7b2..d04e2d4018 100644 --- a/dozer-sql/src/pipeline/expression/tests/comparison.rs +++ b/dozer-sql/expression/src/comparison/tests.rs @@ -1,19 +1,11 @@ -use crate::pipeline::expression::comparison::{ - evaluate_eq, evaluate_gt, evaluate_gte, evaluate_lt, evaluate_lte, evaluate_ne, DATE_FORMAT, -}; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::execution::Expression::Literal; -use crate::pipeline::expression::tests::test_common::*; -use dozer_types::chrono::{DateTime, NaiveDate}; -use dozer_types::types::Record; -use dozer_types::types::{FieldDefinition, FieldType, SourceDefinition}; -use dozer_types::{ - ordered_float::OrderedFloat, - rust_decimal::Decimal, - types::{Field, Schema}, -}; +use crate::tests::ArbitraryDecimal; + +use super::*; + +use dozer_types::{ordered_float::OrderedFloat, rust_decimal::Decimal}; use num_traits::FromPrimitive; use proptest::prelude::*; +use Expression::Literal; #[test] fn test_comparison() { @@ -524,172 +516,3 @@ fn test_lte(exp1: &Expression, exp2: &Expression, row: &Record, result: Option= '124'", - schema.clone(), - record.clone(), - ); - assert_eq!(f, Field::Int(124)); - - let f = run_fct( - "SELECT id = '124' FROM users", - schema.clone(), - record.clone(), - ); - assert_eq!(f, Field::Boolean(true)); - - let f = run_fct( - "SELECT id < '124' FROM users", - schema.clone(), - record.clone(), - ); - assert_eq!(f, Field::Boolean(false)); - - let f = run_fct( - "SELECT id > '124' FROM users", - schema.clone(), - record.clone(), - ); - assert_eq!(f, Field::Boolean(false)); - - let f = run_fct( - "SELECT id <= '124' FROM users", - schema.clone(), - record.clone(), - ); - assert_eq!(f, Field::Boolean(true)); - - let f = run_fct("SELECT id >= '124' FROM users", schema, record); - assert_eq!(f, Field::Boolean(true)); -} - -#[test] -fn test_comparison_logical_timestamp() { - let f = run_fct( - "SELECT time = '2020-01-01T00:00:00Z' FROM users", - Schema::default() - .field( - FieldDefinition::new( - String::from("time"), - FieldType::Timestamp, - 
false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(), - vec![Field::Timestamp( - DateTime::parse_from_rfc3339("2020-01-01T00:00:00Z").unwrap(), - )], - ); - assert_eq!(f, Field::Boolean(true)); - - let f = run_fct( - "SELECT time < '2020-01-01T00:00:01Z' FROM users", - Schema::default() - .field( - FieldDefinition::new( - String::from("time"), - FieldType::Timestamp, - false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(), - vec![Field::Timestamp( - DateTime::parse_from_rfc3339("2020-01-01T00:00:00Z").unwrap(), - )], - ); - assert_eq!(f, Field::Boolean(true)); -} - -#[test] -fn test_comparison_logical_date() { - let f = run_fct( - "SELECT date = '2020-01-01' FROM users", - Schema::default() - .field( - FieldDefinition::new( - String::from("date"), - FieldType::Int, - false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(), - vec![Field::Date( - NaiveDate::parse_from_str("2020-01-01", DATE_FORMAT).unwrap(), - )], - ); - assert_eq!(f, Field::Boolean(true)); - - let f = run_fct( - "SELECT date != '2020-01-01' FROM users", - Schema::default() - .field( - FieldDefinition::new( - String::from("date"), - FieldType::Int, - false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(), - vec![Field::Date( - NaiveDate::parse_from_str("2020-01-01", DATE_FORMAT).unwrap(), - )], - ); - assert_eq!(f, Field::Boolean(false)); - - let f = run_fct( - "SELECT date > '2020-01-01' FROM users", - Schema::default() - .field( - FieldDefinition::new( - String::from("date"), - FieldType::Int, - false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(), - vec![Field::Date( - NaiveDate::parse_from_str("2020-01-02", DATE_FORMAT).unwrap(), - )], - ); - assert_eq!(f, Field::Boolean(true)); -} diff --git a/dozer-sql/expression/src/conditional.rs b/dozer-sql/expression/src/conditional.rs new file mode 100644 index 0000000000..2e0cd7062f --- /dev/null +++ b/dozer-sql/expression/src/conditional.rs @@ -0,0 +1,334 @@ +use crate::error::Error; +use crate::execution::{Expression, ExpressionType}; +use dozer_types::types::Record; +use dozer_types::types::{Field, FieldType, Schema}; +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ConditionalExpressionType { + Coalesce, + NullIf, +} + +pub(crate) fn get_conditional_expr_type( + function: &ConditionalExpressionType, + args: &[Expression], + schema: &Schema, +) -> Result { + match function { + ConditionalExpressionType::Coalesce => validate_coalesce(args, schema), + ConditionalExpressionType::NullIf => todo!(), + } +} + +impl ConditionalExpressionType { + pub(crate) fn new(name: &str) -> Option { + match name { + "coalesce" => Some(ConditionalExpressionType::Coalesce), + "nullif" => Some(ConditionalExpressionType::NullIf), + _ => None, + } + } + + pub(crate) fn evaluate( + &self, + schema: &Schema, + args: &[Expression], + record: &Record, + ) -> Result { + match self { + ConditionalExpressionType::Coalesce => evaluate_coalesce(schema, args, record), + ConditionalExpressionType::NullIf => todo!(), + } + } +} + +pub(crate) fn validate_coalesce( + args: &[Expression], + schema: &Schema, +) -> Result { + if args.is_empty() { + return Err(Error::EmptyCoalesceArguments); + } + + let return_types = args + .iter() + .map(|expr| expr.get_type(schema).unwrap().return_type) + .collect::>(); + let return_type = return_types[0]; + + Ok(ExpressionType::new( + return_type, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) +} + +pub(crate) fn evaluate_coalesce( + schema: 
&Schema, + args: &[Expression], + record: &Record, +) -> Result { + // The COALESCE function returns the first of its arguments that is not null. + for expr in args { + let field = expr.evaluate(record, schema)?; + if field != Field::Null { + return Ok(field); + } + } + // Null is returned only if all arguments are null. + Ok(Field::Null) +} + +impl Display for ConditionalExpressionType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ConditionalExpressionType::Coalesce => f.write_str("COALESCE"), + ConditionalExpressionType::NullIf => f.write_str("NULLIF"), + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{ArbitraryDateTime, ArbitraryDecimal}; + + use super::*; + + use dozer_types::{ + ordered_float::OrderedFloat, + types::{FieldDefinition, SourceDefinition}, + }; + use proptest::prelude::*; + + #[test] + fn test_coalesce() { + proptest!(ProptestConfig::with_cases(1000), move |( + u_num1: u64, u_num2: u64, i_num1: i64, i_num2: i64, f_num1: f64, f_num2: f64, + d_num1: ArbitraryDecimal, d_num2: ArbitraryDecimal, + s_val1: String, s_val2: String, + dt_val1: ArbitraryDateTime, dt_val2: ArbitraryDateTime)| { + let uint1 = Expression::Literal(Field::UInt(u_num1)); + let uint2 = Expression::Literal(Field::UInt(u_num2)); + let int1 = Expression::Literal(Field::Int(i_num1)); + let int2 = Expression::Literal(Field::Int(i_num2)); + let float1 = Expression::Literal(Field::Float(OrderedFloat(f_num1))); + let float2 = Expression::Literal(Field::Float(OrderedFloat(f_num2))); + let dec1 = Expression::Literal(Field::Decimal(d_num1.0)); + let dec2 = Expression::Literal(Field::Decimal(d_num2.0)); + let str1 = Expression::Literal(Field::String(s_val1.clone())); + let str2 = Expression::Literal(Field::String(s_val2)); + let t1 = Expression::Literal(Field::Timestamp(dt_val1.0)); + let t2 = Expression::Literal(Field::Timestamp(dt_val1.0)); + let dt1 = Expression::Literal(Field::Date(dt_val1.0.date_naive())); + let dt2 = Expression::Literal(Field::Date(dt_val2.0.date_naive())); + let null = Expression::Column{ index: 0usize }; + + // UInt + let typ = FieldType::UInt; + let f = Field::UInt(u_num1); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), uint1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), uint1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), uint1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), uint1, uint2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Int + let typ = FieldType::Int; + let f = Field::Int(i_num1); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), int1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), int1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), int1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), int1, int2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Float + 
let typ = FieldType::Float; + let f = Field::Float(OrderedFloat(f_num1)); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), float1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), float1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), float1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), float1, float2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Decimal + let typ = FieldType::Decimal; + let f = Field::Decimal(d_num1.0); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), dec1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dec1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dec1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dec1, dec2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // String + let typ = FieldType::String; + let f = Field::String(s_val1); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), str1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), str1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), str1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), str1, str2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Timestamp + let typ = FieldType::Timestamp; + let f = Field::Timestamp(dt_val1.0); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), t1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), t1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), t1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, 
f.clone()); + + let args = vec![null.clone(), null.clone(), t1, t2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Date + let typ = FieldType::Date; + let f = Field::Date(dt_val1.0.date_naive()); + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone(), dt1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dt1.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dt1.clone(), null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null.clone(), dt1, dt2]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + + // Null + let typ = FieldType::Date; + let f = Field::Null; + let row = Record::new(vec![f.clone()]); + + let args = vec![null.clone()]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f.clone()); + + let args = vec![null.clone(), null]; + test_validate_coalesce(&args, typ); + test_evaluate_coalesce(&args, &row, typ, f); + }); + } + + fn test_validate_coalesce(args: &[Expression], typ: FieldType) { + let schema = Schema::default() + .field( + FieldDefinition::new(String::from("field"), typ, false, SourceDefinition::Dynamic), + false, + ) + .clone(); + + let result = validate_coalesce(args, &schema).unwrap().return_type; + assert_eq!(result, typ); + } + + fn test_evaluate_coalesce(args: &[Expression], row: &Record, typ: FieldType, result: Field) { + let schema = Schema::default() + .field( + FieldDefinition::new(String::from("field"), typ, false, SourceDefinition::Dynamic), + false, + ) + .clone(); + + let res = evaluate_coalesce(&schema, args, row).unwrap(); + assert_eq!(res, result); + } +} diff --git a/dozer-sql/src/pipeline/expression/datetime.rs b/dozer-sql/expression/src/datetime.rs similarity index 66% rename from dozer-sql/src/pipeline/expression/datetime.rs rename to dozer-sql/expression/src/datetime.rs index f75e31c255..e17b74352d 100644 --- a/dozer-sql/src/pipeline/expression/datetime.rs +++ b/dozer-sql/expression/src/datetime.rs @@ -1,9 +1,6 @@ -use crate::pipeline::errors::PipelineError::{ - InvalidFunction, InvalidFunctionArgument, InvalidFunctionArgumentType, -}; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::datetime::PipelineError::InvalidValue; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; +use crate::arg_utils::{extract_timestamp, extract_uint, validate_arg_type}; +use crate::error::Error; +use crate::execution::{Expression, ExpressionType}; use dozer_types::chrono::{DateTime, Datelike, FixedOffset, Offset, Timelike, Utc}; use dozer_types::types::Record; @@ -41,25 +38,20 @@ pub(crate) fn get_datetime_function_type( function: &DateTimeFunctionType, arg: &Expression, schema: &Schema, -) -> Result<ExpressionType, PipelineError> { - let return_type = arg.get_type(schema)?.return_type; - if return_type != FieldType::Date - && return_type != FieldType::Timestamp - && return_type != FieldType::Duration - && return_type != FieldType::String - { - return Err(InvalidFunctionArgumentType( - function.to_string(), - return_type, - FieldTypes::new(vec![ - FieldType::Date, - FieldType::Timestamp, - FieldType::Duration, - FieldType::String, - ]), - 0, - )); - } +) -> Result<ExpressionType, Error> { + validate_arg_type( + arg, + vec![ + FieldType::Date, 
+ FieldType::Timestamp, + FieldType::Duration, + FieldType::String, + FieldType::Text, + ], + schema, + function, + 0, + )?; match function { DateTimeFunctionType::Extract { field: _ } => Ok(ExpressionType::new( FieldType::Int, @@ -83,10 +75,10 @@ } impl DateTimeFunctionType { - pub(crate) fn new(name: &str) -> Result<DateTimeFunctionType, PipelineError> { + pub(crate) fn new(name: &str) -> Option<DateTimeFunctionType> { match name { - "now" => Ok(DateTimeFunctionType::Now), - _ => Err(InvalidFunction(name.to_string())), + "now" => Some(DateTimeFunctionType::Now), + _ => None, } } @@ -95,7 +87,7 @@ schema: &Schema, arg: &Expression, record: &Record, - ) -> Result<Field, PipelineError> { + ) -> Result<Field, Error> { match self { DateTimeFunctionType::Extract { field } => { evaluate_date_part(schema, field, arg, record) @@ -107,7 +99,7 @@ } } - pub(crate) fn evaluate_now(&self) -> Result<Field, PipelineError> { + pub(crate) fn evaluate_now(&self) -> Result<Field, Error> { Ok(Field::Timestamp(DateTime::<FixedOffset>::from(Utc::now()))) } } @@ -117,38 +109,10 @@ pub(crate) fn evaluate_date_part( field: &sqlparser::ast::DateTimeField, arg: &Expression, record: &Record, -) -> Result<Field, PipelineError> { +) -> Result<Field, Error> { let value = arg.evaluate(record, schema)?; - let ts = match value { - Field::Timestamp(ts) => Ok(ts), - Field::Date(d) => d - .and_hms_milli_opt(0, 0, 0, 0) - .map(|ts| DateTime::from_utc(ts, Utc.fix())) - .ok_or(InvalidValue(format!( - "Unable to cast date {d} to timestamp" - ))), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::Boolean(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) - | Field::Null => { - return Err(InvalidFunctionArgument( - DateTimeFunctionType::Extract { field: *field }.to_string(), - value, - 0, - )) - } - }?; + let ts = extract_timestamp(value, DateTimeFunctionType::Extract { field: *field }, 0)?; match field { DateTimeField::Dow => ts.weekday().num_days_from_monday().to_i64(), @@ -178,9 +142,7 @@ | DateTimeField::Date | DateTimeField::NoDateTime => None, } - .ok_or(PipelineError::InvalidOperandType(format!( - "Unable to extract date part {field} from {value}" - ))) + .ok_or(Error::UnsupportedExtract(*field)) .map(Field::Int) } @@ -189,33 +151,30 @@ pub(crate) fn evaluate_interval( field: &sqlparser::ast::DateTimeField, arg: &Expression, record: &Record, -) -> Result<Field, PipelineError> { +) -> Result<Field, Error> { let value = arg.evaluate(record, schema)?; - let dur = value.to_duration()?.unwrap().0.as_nanos(); + let dur = extract_uint(value, DateTimeFunctionType::Interval { field: *field }, 0)?; match field { DateTimeField::Second => Ok(Field::Duration(DozerDuration( - std::time::Duration::from_secs(dur as u64), + std::time::Duration::from_secs(dur), TimeUnit::Seconds, ))), DateTimeField::Millisecond | DateTimeField::Milliseconds => { Ok(Field::Duration(DozerDuration( - std::time::Duration::from_millis(dur as u64), + std::time::Duration::from_millis(dur), TimeUnit::Milliseconds, ))) } DateTimeField::Microsecond | DateTimeField::Microseconds => { Ok(Field::Duration(DozerDuration( - std::time::Duration::from_micros(dur as u64), + std::time::Duration::from_micros(dur), TimeUnit::Microseconds, ))) } - DateTimeField::Nanoseconds | DateTimeField::Nanosecond => { - Ok(Field::Duration(DozerDuration( - std::time::Duration::from_nanos(dur as u64), - TimeUnit::Nanoseconds, - ))) - } + DateTimeField::Nanoseconds | DateTimeField::Nanosecond => Ok(Field::Duration( + 
DozerDuration(std::time::Duration::from_nanos(dur), TimeUnit::Nanoseconds), + )), DateTimeField::Isodow | DateTimeField::Timezone | DateTimeField::Dow @@ -237,8 +196,55 @@ | DateTimeField::Week | DateTimeField::Century | DateTimeField::Decade - | DateTimeField::Doy => Err(PipelineError::InvalidOperandType(format!( - "Unable to extract date part {field} from {value}" - ))), + | DateTimeField::Doy => Err(Error::UnsupportedInterval(*field)), + } +} + +#[cfg(test)] +mod tests { + use crate::tests::ArbitraryDateTime; + + use super::*; + + use proptest::prelude::*; + + #[test] + fn test_time() { + proptest!( + ProptestConfig::with_cases(1000), + move |(datetime: ArbitraryDateTime)| { + test_date_parts(datetime) + }); + } + + fn test_date_parts(datetime: ArbitraryDateTime) { + let row = Record::new(vec![]); + + let date_parts = vec![ + ( + DateTimeField::Dow, + datetime + .0 + .weekday() + .num_days_from_monday() + .to_i64() + .unwrap(), + ), + (DateTimeField::Year, datetime.0.year().to_i64().unwrap()), + (DateTimeField::Month, datetime.0.month().to_i64().unwrap()), + (DateTimeField::Hour, 0), + (DateTimeField::Second, 0), + ( + DateTimeField::Quarter, + datetime.0.month0().to_i64().map(|m| m / 3 + 1).unwrap(), + ), + ]; + + let v = Expression::Literal(Field::Date(datetime.0.date_naive())); + + for (part, value) in date_parts { + let result = evaluate_date_part(&Schema::default(), &part, &v, &row).unwrap(); + assert_eq!(result, Field::Int(value)); + } } } diff --git a/dozer-sql/expression/src/error.rs b/dozer-sql/expression/src/error.rs new file mode 100644 index 0000000000..0576d32776 --- /dev/null +++ b/dozer-sql/expression/src/error.rs @@ -0,0 +1,128 @@ +use std::ops::Range; + +use dozer_types::{ + thiserror::{self, Error}, + types::{Field, FieldType}, +}; +use sqlparser::ast::{ + BinaryOperator, DataType, DateTimeField, Expr, FunctionArg, Ident, UnaryOperator, +}; + +use crate::{aggregate::AggregateFunctionType, operator::BinaryOperatorType}; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Unsupported SQL expression: {0:?}")] + UnsupportedExpression(Expr), + #[error("Unsupported SQL function arg: {0:?}")] + UnsupportedFunctionArg(FunctionArg), + #[error("Invalid ident: {}", .0.iter().map(|ident| ident.value.as_str()).collect::<Vec<_>>().join("."))] + InvalidIdent(Vec<Ident>), + #[error("UDF {0} is defined but its configuration is missing")] + UdfConfigMissing(String), + #[error("Unknown function: {0}")] + UnknownFunction(String), + #[error("Missing leading field in interval")] + MissingLeadingFieldInInterval, + #[error("Unsupported SQL unary operator: {0:?}")] + UnsupportedUnaryOperator(UnaryOperator), + #[error("Unsupported SQL binary operator: {0:?}")] + UnsupportedBinaryOperator(BinaryOperator), + #[error("Not a number: {0}")] + NotANumber(String), + #[error("Unsupported data type: {0}")] + UnsupportedDataType(DataType), + + #[error("Aggregate Function {0:?} should not be executed at this point")] + UnexpectedAggregationExecution(AggregateFunctionType), + #[error("Literal expression cannot be null")] + LiteralExpressionIsNull, + #[error("Cannot apply NOT to {0:?}")] + CannotApplyNotTo(FieldType), + #[error("Cannot apply {operator:?} to {left_field_type:?} and {right_field_type:?}")] + CannotApplyBinaryOperator { + operator: BinaryOperatorType, + left_field_type: FieldType, + right_field_type: FieldType, + }, + #[error("Expected {expected:?} arguments for function {function_name}, got {actual}")] + InvalidNumberOfArguments { + function_name: String, + expected: Range<usize>, 
actual: usize, + }, + #[error("Empty coalesce arguments")] + EmptyCoalesceArguments, + #[error( + "Invalid argument type for function {function_name}: type: {actual}, expected types: {expected:?}, index: {argument_index}" + )] + InvalidFunctionArgumentType { + function_name: String, + argument_index: usize, + expected: Vec, + actual: FieldType, + }, + #[error("Invalid cast: from: {from}, to: {to}")] + InvalidCast { from: Field, to: FieldType }, + #[error("Invalid argument for function {function_name}(): argument: {argument}, index: {argument_index}")] + InvalidFunctionArgument { + function_name: String, + argument_index: usize, + argument: Field, + }, + + #[error("Invalid distance algorithm: {0}")] + InvalidDistanceAlgorithm(String), + #[error("Failed to calculate vincenty distance: {0}")] + FailedToCalculateVincentyDistance( + #[from] dozer_types::geo::vincenty_distance::FailedToConvergeError, + ), + + #[error("Invalid like escape: {0}")] + InvalidLikeEscape(#[from] like::InvalidEscapeError), + #[error("Invalid like pattern: {0}")] + InvalidLikePattern(#[from] like::InvalidPatternError), + + #[error("Unsupported extract: {0}")] + UnsupportedExtract(DateTimeField), + #[error("Unsupported interval: {0}")] + UnsupportedInterval(DateTimeField), + + #[error("Invalid json path: {0}")] + InvalidJsonPath(String), + + #[cfg(feature = "python")] + #[error("Python UDF error: {0}")] + PythonUdf(#[from] crate::python_udf::Error), + + #[cfg(feature = "onnx")] + #[error("ONNX UDF error: {0}")] + Onnx(#[from] crate::onnx::error::Error), + #[cfg(not(feature = "onnx"))] + #[error("ONNX UDF is not enabled")] + OnnxNotEnabled, + + // Legacy error types. + #[error("Sql error: {0}")] + SqlError(#[source] OperationError), + #[error("Invalid types on {0} and {1} for {2} operand")] + InvalidTypeComparison(Field, Field, String), + #[error("Unable to cast {0} to {1}")] + UnableToCast(String, String), + #[error("Invalid types on {0} for {1} operand")] + InvalidType(Field, String), +} + +#[derive(Error, Debug)] +pub enum OperationError { + #[error("SQL Error: Addition operation cannot be done due to overflow.")] + AdditionOverflow, + #[error("SQL Error: Subtraction operation cannot be done due to overflow.")] + SubtractionOverflow, + #[error("SQL Error: Multiplication operation cannot be done due to overflow.")] + MultiplicationOverflow, + #[error("SQL Error: Division operation cannot be done.")] + DivisionByZeroOrOverflow, + #[error("SQL Error: Modulo operation cannot be done.")] + ModuloByZeroOrOverflow, +} diff --git a/dozer-sql/src/pipeline/expression/execution.rs b/dozer-sql/expression/src/execution.rs similarity index 69% rename from dozer-sql/src/pipeline/expression/execution.rs rename to dozer-sql/expression/src/execution.rs index feee750e7a..2499d1a0af 100644 --- a/dozer-sql/src/pipeline/expression/execution.rs +++ b/dozer-sql/expression/src/execution.rs @@ -1,34 +1,21 @@ -use crate::pipeline::aggregation::avg::validate_avg; -use crate::pipeline::aggregation::count::validate_count; -use crate::pipeline::aggregation::max::validate_max; -use crate::pipeline::aggregation::min::validate_min; -use crate::pipeline::aggregation::sum::validate_sum; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::case::evaluate_case; -use crate::pipeline::expression::conditional::{ - get_conditional_expr_type, ConditionalExpressionType, -}; -use crate::pipeline::expression::datetime::{get_datetime_function_type, DateTimeFunctionType}; -use 
crate::pipeline::expression::geo::common::{get_geo_function_type, GeoFunctionType}; -use crate::pipeline::expression::json_functions::JsonFunctionType; -use crate::pipeline::expression::operator::{BinaryOperatorType, UnaryOperatorType}; -use crate::pipeline::expression::scalar::common::{get_scalar_function_type, ScalarFunctionType}; -use crate::pipeline::expression::scalar::string::{evaluate_trim, validate_trim, TrimType}; +use crate::arg_utils::{validate_one_argument, validate_two_arguments}; +use crate::case::evaluate_case; +use crate::conditional::{get_conditional_expr_type, ConditionalExpressionType}; +use crate::datetime::{get_datetime_function_type, DateTimeFunctionType}; +use crate::error::Error; +use crate::geo::common::{get_geo_function_type, GeoFunctionType}; +use crate::json_functions::JsonFunctionType; +use crate::operator::{BinaryOperatorType, UnaryOperatorType}; +use crate::scalar::common::{get_scalar_function_type, ScalarFunctionType}; +use crate::scalar::string::{evaluate_trim, validate_trim, TrimType}; use std::iter::zip; use super::aggregate::AggregateFunctionType; use super::cast::CastOperatorType; use super::in_list::evaluate_in_list; use super::scalar::string::{evaluate_like, get_like_operator_type}; -use crate::pipeline::aggregation::max_value::validate_max_value; -use crate::pipeline::aggregation::min_value::validate_min_value; -#[cfg(feature = "onnx")] -use crate::pipeline::expression::onnx::onnx_udf::evaluate_onnx_udf; -#[cfg(feature = "onnx")] -use crate::pipeline::onnx::DozerSession; use dozer_types::types::Record; use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; -use uuid::Uuid; #[derive(Clone, Debug, PartialEq)] pub enum Expression { @@ -106,7 +93,7 @@ pub enum Expression { #[cfg(feature = "onnx")] OnnxUDF { name: String, - session: DozerSession, + session: crate::onnx::DozerSession, args: Vec, }, } @@ -115,9 +102,7 @@ impl Expression { pub fn to_string(&self, schema: &Schema) -> String { match &self { Expression::Column { index } => schema.fields[*index].name.clone(), - Expression::Literal(value) => value - .to_string() - .unwrap_or_else(|| Uuid::new_v4().to_string()), + Expression::Literal(value) => format!("{}", value), Expression::UnaryOperator { operator, arg } => { operator.to_string() + arg.to_string(schema).as_str() } @@ -317,7 +302,7 @@ impl ExpressionType { } impl Expression { - pub fn evaluate(&self, record: &Record, schema: &Schema) -> Result { + pub fn evaluate(&self, record: &Record, schema: &Schema) -> Result { match self { Expression::Literal(field) => Ok(field.clone()), Expression::Column { index } => Ok(record.values[*index].clone()), @@ -335,7 +320,7 @@ impl Expression { return_type, .. } => { - use crate::pipeline::expression::python_udf::evaluate_py_udf; + use crate::python_udf::evaluate_py_udf; evaluate_py_udf(schema, name, args, return_type, record) } #[cfg(feature = "onnx")] @@ -346,14 +331,12 @@ impl Expression { .. 
} => { use std::borrow::Borrow; - evaluate_onnx_udf(schema, session.0.borrow(), args, record) + crate::onnx::udf::evaluate_onnx_udf(schema, session.0.borrow(), args, record) } Expression::UnaryOperator { operator, arg } => operator.evaluate(schema, arg, record), Expression::AggregateFunction { fun, args: _ } => { - Err(PipelineError::InvalidExpression(format!( - "Aggregate Function {fun:?} should not be executed at this point" - ))) + Err(Error::UnexpectedAggregationExecution(fun.clone())) } Expression::Trim { typ, what, arg } => evaluate_trim(schema, arg, what, typ, record), Expression::Like { @@ -381,7 +364,7 @@ impl Expression { } } - pub fn get_type(&self, schema: &Schema) -> Result { + pub fn get_type(&self, schema: &Schema) -> Result { match self { Expression::Literal(field) => { let field_type = get_field_type(field); @@ -392,9 +375,7 @@ impl Expression { SourceDefinition::Dynamic, false, )), - None => Err(PipelineError::InvalidExpression( - "literal expression cannot be null".to_string(), - )), + None => Err(Error::LiteralExpressionIsNull), } } Expression::Column { index } => { @@ -491,7 +472,7 @@ impl Expression { } } -fn get_field_type(field: &Field) -> Option { +pub fn get_field_type(field: &Field) -> Option { match field { Field::UInt(_) => Some(FieldType::UInt), Field::U128(_) => Some(FieldType::U128), @@ -516,14 +497,12 @@ fn get_unary_operator_type( operator: &UnaryOperatorType, expression: &Expression, schema: &Schema, -) -> Result { +) -> Result { let field_type = expression.get_type(schema)?; match operator { UnaryOperatorType::Not => match field_type.return_type { FieldType::Boolean => Ok(field_type), - field_type => Err(PipelineError::InvalidExpression(format!( - "cannot apply NOT to {field_type:?}" - ))), + field_type => Err(Error::CannotApplyNotTo(field_type)), }, UnaryOperatorType::Plus => Ok(field_type), UnaryOperatorType::Minus => Ok(field_type), @@ -535,7 +514,7 @@ fn get_binary_operator_type( operator: &BinaryOperatorType, right: &Expression, schema: &Schema, -) -> Result { +) -> Result { let left_field_type = left.get_type(schema)?; let right_field_type = right.get_type(schema)?; match operator { @@ -587,11 +566,11 @@ fn get_binary_operator_type( SourceDefinition::Dynamic, false, )), - (left_field_type, right_field_type) => { - Err(PipelineError::InvalidExpression(format!( - "cannot apply {operator:?} to {left_field_type:?} and {right_field_type:?}" - ))) - } + (left_field_type, right_field_type) => Err(Error::CannotApplyBinaryOperator { + operator: operator.clone(), + left_field_type, + right_field_type, + }), } } @@ -688,11 +667,11 @@ fn get_binary_operator_type( SourceDefinition::Dynamic, false, )), - (left_field_type, right_field_type) => { - Err(PipelineError::InvalidExpression(format!( - "cannot apply {operator:?} to {left_field_type:?} and {right_field_type:?}" - ))) - } + (left_field_type, right_field_type) => Err(Error::CannotApplyBinaryOperator { + operator: operator.clone(), + left_field_type, + right_field_type, + }), } } @@ -744,11 +723,11 @@ fn get_binary_operator_type( SourceDefinition::Dynamic, false, )), - (left_field_type, right_field_type) => { - Err(PipelineError::InvalidExpression(format!( - "cannot apply {operator:?} to {left_field_type:?} and {right_field_type:?}" - ))) - } + (left_field_type, right_field_type) => Err(Error::CannotApplyBinaryOperator { + operator: operator.clone(), + left_field_type, + right_field_type, + }), } } } @@ -758,7 +737,7 @@ fn get_aggregate_function_type( function: &AggregateFunctionType, args: &[Expression], 
schema: &Schema, -) -> Result { +) -> Result { match function { AggregateFunctionType::Avg => validate_avg(args, schema), AggregateFunctionType::Count => validate_count(args, schema), @@ -769,3 +748,281 @@ fn get_aggregate_function_type( AggregateFunctionType::Sum => validate_sum(args, schema), } } + +fn validate_avg(args: &[Expression], schema: &Schema) -> Result { + let arg = validate_one_argument(args, schema, AggregateFunctionType::Avg)?; + + let ret_type = match arg.return_type { + FieldType::UInt => FieldType::Decimal, + FieldType::U128 => FieldType::Decimal, + FieldType::Int => FieldType::Decimal, + FieldType::I128 => FieldType::Decimal, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Date + | FieldType::Timestamp + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::Avg.to_string(), + argument_index: 0, + actual: arg.return_type, + expected: vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Float, + FieldType::Decimal, + FieldType::Duration, + ], + }); + } + }; + + Ok(ExpressionType::new( + ret_type, + true, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_count(_args: &[Expression], _schema: &Schema) -> Result { + Ok(ExpressionType::new( + FieldType::Int, + false, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_max(args: &[Expression], schema: &Schema) -> Result { + let arg = validate_one_argument(args, schema, AggregateFunctionType::Max)?; + + let ret_type = match arg.return_type { + FieldType::UInt => FieldType::UInt, + FieldType::U128 => FieldType::U128, + FieldType::Int => FieldType::Int, + FieldType::I128 => FieldType::I128, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Timestamp => FieldType::Timestamp, + FieldType::Date => FieldType::Date, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::Max.to_string(), + argument_index: 0, + actual: arg.return_type, + expected: vec![ + FieldType::Decimal, + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Float, + FieldType::Timestamp, + FieldType::Date, + FieldType::Duration, + ], + }); + } + }; + Ok(ExpressionType::new( + ret_type, + true, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_min(args: &[Expression], schema: &Schema) -> Result { + let arg = validate_one_argument(args, schema, AggregateFunctionType::Min)?; + + let ret_type = match arg.return_type { + FieldType::UInt => FieldType::UInt, + FieldType::U128 => FieldType::U128, + FieldType::Int => FieldType::Int, + FieldType::I128 => FieldType::I128, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Timestamp => FieldType::Timestamp, + FieldType::Date => FieldType::Date, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::Min.to_string(), + argument_index: 0, + actual: arg.return_type, 
+ expected: vec![ + FieldType::Decimal, + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Float, + FieldType::Timestamp, + FieldType::Date, + FieldType::Duration, + ], + }); + } + }; + Ok(ExpressionType::new( + ret_type, + true, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_sum(args: &[Expression], schema: &Schema) -> Result { + let arg = validate_one_argument(args, schema, AggregateFunctionType::Sum)?; + + let ret_type = match arg.return_type { + FieldType::UInt => FieldType::UInt, + FieldType::U128 => FieldType::U128, + FieldType::Int => FieldType::Int, + FieldType::I128 => FieldType::I128, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Date + | FieldType::Timestamp + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::Sum.to_string(), + argument_index: 0, + actual: arg.return_type, + expected: vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Float, + FieldType::Decimal, + FieldType::Duration, + ], + }); + } + }; + Ok(ExpressionType::new( + ret_type, + true, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_max_value(args: &[Expression], schema: &Schema) -> Result { + let (base_arg, arg) = validate_two_arguments(args, schema, AggregateFunctionType::MaxValue)?; + + match base_arg.return_type { + FieldType::UInt => FieldType::UInt, + FieldType::U128 => FieldType::U128, + FieldType::Int => FieldType::Int, + FieldType::I128 => FieldType::I128, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Timestamp => FieldType::Timestamp, + FieldType::Date => FieldType::Date, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::MaxValue.to_string(), + argument_index: 0, + actual: base_arg.return_type, + expected: vec![ + FieldType::Decimal, + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + FieldType::Float, + FieldType::Timestamp, + FieldType::Date, + FieldType::Duration, + ], + }); + } + }; + + Ok(ExpressionType::new( + arg.return_type, + true, + SourceDefinition::Dynamic, + false, + )) +} + +fn validate_min_value(args: &[Expression], schema: &Schema) -> Result { + let (base_arg, arg) = validate_two_arguments(args, schema, AggregateFunctionType::MinValue)?; + + match base_arg.return_type { + FieldType::UInt => FieldType::UInt, + FieldType::U128 => FieldType::U128, + FieldType::Int => FieldType::Int, + FieldType::I128 => FieldType::I128, + FieldType::Float => FieldType::Float, + FieldType::Decimal => FieldType::Decimal, + FieldType::Timestamp => FieldType::Timestamp, + FieldType::Date => FieldType::Date, + FieldType::Duration => FieldType::Duration, + FieldType::Boolean + | FieldType::String + | FieldType::Text + | FieldType::Binary + | FieldType::Json + | FieldType::Point => { + return Err(Error::InvalidFunctionArgumentType { + function_name: AggregateFunctionType::MinValue.to_string(), + argument_index: 0, + actual: base_arg.return_type, + expected: vec![ + FieldType::Decimal, + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + 
FieldType::Float, + FieldType::Timestamp, + FieldType::Date, + FieldType::Duration, + ], + }); + } + }; + + Ok(ExpressionType::new( + arg.return_type, + true, + SourceDefinition::Dynamic, + false, + )) +} diff --git a/dozer-sql/src/pipeline/expression/geo/common.rs b/dozer-sql/expression/src/geo/common.rs similarity index 64% rename from dozer-sql/src/pipeline/expression/geo/common.rs rename to dozer-sql/expression/src/geo/common.rs index 4b3558b5bf..180cda9e05 100644 --- a/dozer-sql/src/pipeline/expression/geo/common.rs +++ b/dozer-sql/expression/src/geo/common.rs @@ -1,8 +1,8 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; +use crate::error::Error; +use crate::execution::{Expression, ExpressionType}; -use crate::pipeline::expression::geo::distance::{evaluate_distance, validate_distance}; -use crate::pipeline::expression::geo::point::{evaluate_point, validate_point}; +use crate::geo::distance::{evaluate_distance, validate_distance}; +use crate::geo::point::{evaluate_point, validate_point}; use dozer_types::types::Record; use dozer_types::types::{Field, Schema}; use std::fmt::{Display, Formatter}; @@ -26,7 +26,7 @@ pub(crate) fn get_geo_function_type( function: &GeoFunctionType, args: &[Expression], schema: &Schema, -) -> Result<ExpressionType, PipelineError> { +) -> Result<ExpressionType, Error> { match function { GeoFunctionType::Point => validate_point(args, schema), GeoFunctionType::Distance => validate_distance(args, schema), @@ -34,11 +34,11 @@ } impl GeoFunctionType { - pub fn new(name: &str) -> Result<GeoFunctionType, PipelineError> { + pub fn new(name: &str) -> Option<GeoFunctionType> { match name { - "point" => Ok(GeoFunctionType::Point), - "distance" => Ok(GeoFunctionType::Distance), - _ => Err(PipelineError::InvalidFunction(name.to_string())), + "point" => Some(GeoFunctionType::Point), + "distance" => Some(GeoFunctionType::Distance), + _ => None, } } @@ -47,7 +47,7 @@ schema: &Schema, args: &[Expression], record: &Record, - ) -> Result<Field, PipelineError> { + ) -> Result<Field, Error> { match self { GeoFunctionType::Point => evaluate_point(schema, args, record), GeoFunctionType::Distance => evaluate_distance(schema, args, record), diff --git a/dozer-sql/expression/src/geo/distance.rs b/dozer-sql/expression/src/geo/distance.rs new file mode 100644 index 0000000000..d1b58e0f4a --- /dev/null +++ b/dozer-sql/expression/src/geo/distance.rs @@ -0,0 +1,268 @@ +use std::str::FromStr; + +use crate::arg_utils::{extract_point, validate_num_arguments}; +use crate::error::Error; +use dozer_types::types::Record; +use dozer_types::types::{Field, FieldType, Schema}; + +use crate::execution::{Expression, ExpressionType}; +use crate::geo::common::GeoFunctionType; +use dozer_types::geo::GeodesicDistance; +use dozer_types::geo::HaversineDistance; +use dozer_types::geo::VincentyDistance; + +use dozer_types::ordered_float::OrderedFloat; + +const EXPECTED_ARGS_TYPES: &[FieldType] = &[FieldType::Point, FieldType::Point, FieldType::String]; + +pub enum Algorithm { + Geodesic, + Haversine, + Vincenty, +} + +impl FromStr for Algorithm { + type Err = Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "GEODESIC" => Ok(Algorithm::Geodesic), + "HAVERSINE" => Ok(Algorithm::Haversine), + "VINCENTY" => Ok(Algorithm::Vincenty), + &_ => Err(Error::InvalidDistanceAlgorithm(s.to_string())), + } + } +} + +const DEFAULT_ALGORITHM: Algorithm = Algorithm::Geodesic; + +pub(crate) fn validate_distance( + args: &[Expression], + schema: &Schema, +) -> Result<ExpressionType, Error> { + let ret_type = FieldType::Float; + 
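// DISTANCE takes two points plus an optional algorithm name ("GEODESIC", "HAVERSINE" or "VINCENTY"), hence the 2..4 argument range below. +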
validate_num_arguments(2..4, args.len(), GeoFunctionType::Distance)?; + + for (argument_index, exp) in args.iter().enumerate() { + let return_type = exp.get_type(schema)?.return_type; + let expected_arg_type_option = EXPECTED_ARGS_TYPES.get(argument_index); + if let Some(expected_arg_type) = expected_arg_type_option { + if &return_type != expected_arg_type { + return Err(Error::InvalidFunctionArgumentType { + function_name: GeoFunctionType::Distance.to_string(), + argument_index, + actual: return_type, + expected: vec![*expected_arg_type], + }); + } + } + } + + Ok(ExpressionType::new( + ret_type, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) +} + +pub(crate) fn evaluate_distance( + schema: &Schema, + args: &[Expression], + record: &Record, +) -> Result<Field, Error> { + validate_num_arguments(2..4, args.len(), GeoFunctionType::Distance)?; + let f_from = args[0].evaluate(record, schema)?; + + let f_to = args[1].evaluate(record, schema)?; + + if f_from == Field::Null || f_to == Field::Null { + Ok(Field::Null) + } else { + let from = extract_point(f_from, GeoFunctionType::Distance, 0)?; + let to = extract_point(f_to, GeoFunctionType::Distance, 1)?; + let calculation_type = args.get(2).map_or_else( + || Ok(DEFAULT_ALGORITHM), + |arg| { + let f = arg.evaluate(record, schema)?; + let t = f.to_string(); + Algorithm::from_str(&t) + }, + )?; + + let distance: OrderedFloat<f64> = match calculation_type { + Algorithm::Geodesic => Ok(from.geodesic_distance(&to)), + Algorithm::Haversine => Ok(from.0.haversine_distance(&to.0)), + Algorithm::Vincenty => from + .0 + .vincenty_distance(&to.0) + .map_err(Error::FailedToCalculateVincentyDistance), + }?; + + Ok(Field::Float(distance)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use dozer_types::types::{DozerPoint, FieldDefinition, SourceDefinition}; + use proptest::prelude::*; + use Expression::Literal; + + #[test] + fn test_geo() { + proptest!(ProptestConfig::with_cases(1000), move |(x1: f64, x2: f64, y1: f64, y2: f64)| { + let row = Record::new(vec![]); + let from = Field::Point(DozerPoint::from((x1, y1))); + let to = Field::Point(DozerPoint::from((x2, y2))); + let null = Field::Null; + + test_distance(&from, &to, None, &row, None); + test_distance(&from, &null, None, &row, Some(Ok(Field::Null))); + test_distance(&null, &to, None, &row, Some(Ok(Field::Null))); + + test_distance(&from, &to, Some(Algorithm::Geodesic), &row, None); + test_distance(&from, &null, Some(Algorithm::Geodesic), &row, Some(Ok(Field::Null))); + test_distance(&null, &to, Some(Algorithm::Geodesic), &row, Some(Ok(Field::Null))); + + test_distance(&from, &to, Some(Algorithm::Haversine), &row, None); + test_distance(&from, &null, Some(Algorithm::Haversine), &row, Some(Ok(Field::Null))); + test_distance(&null, &to, Some(Algorithm::Haversine), &row, Some(Ok(Field::Null))); + + // test_distance(&from, &to, Some(Algorithm::Vincenty), &row, None); + // test_distance(&from, &null, Some(Algorithm::Vincenty), &row, Some(Ok(Field::Null))); + // test_distance(&null, &to, Some(Algorithm::Vincenty), &row, Some(Ok(Field::Null))); + }); + } + + fn test_distance( + from: &Field, + to: &Field, + typ: Option<Algorithm>, + row: &Record, + result: Option<Result<Field, Error>>, + ) { + // Pass the algorithm name as the optional third argument, so that + // `evaluate_distance` uses the same algorithm the expected value is + // computed with. + let mut args = vec![Literal(from.clone()), Literal(to.clone())]; + if let Some(algorithm) = &typ { + let name = match algorithm { + Algorithm::Geodesic => "GEODESIC", + Algorithm::Haversine => "HAVERSINE", + Algorithm::Vincenty => "VINCENTY", + }; + args.push(Literal(Field::String(name.to_string()))); + } + let args = &args; + if validate_distance(args, &Schema::default()).is_ok() { + match result { + None => { + let from_f = from.to_owned(); + let to_f = to.to_owned(); + let f = extract_point(from_f, GeoFunctionType::Distance, 0).unwrap(); + let t = extract_point(to_f, GeoFunctionType::Distance, 1).unwrap(); + let dist = match typ { + None | Some(Algorithm::Geodesic) => f.geodesic_distance(&t), + Some(Algorithm::Haversine) => f.0.haversine_distance(&t.0), + Some(Algorithm::Vincenty) => OrderedFloat(0.0), + // Some(Algorithm::Vincenty) => f.0.vincenty_distance(&t.0).unwrap(), + }; + assert_eq!( + evaluate_distance(&Schema::default(), args, row).unwrap(), + Field::Float(dist) + ); + } + Some(expected) => { + assert_eq!( + evaluate_distance(&Schema::default(), args, row).unwrap(), + expected.unwrap() + ); + } + } + } + } + + #[test] + fn test_validate_distance() { + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("from"), + FieldType::Point, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("to"), + FieldType::Point, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let result = validate_distance(&[], &schema); + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = validate_distance(&[Expression::Column { index: 0 }], &schema); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = validate_distance( + &[ + Expression::Column { index: 0 }, + Expression::Column { index: 1 }, + ], + &schema, + ); + + assert!(result.is_ok()); + + let result = validate_distance( + &[ + Expression::Column { index: 0 }, + Expression::Column { index: 1 }, + Expression::Literal(Field::String("GEODESIC".to_string())), + ], + &schema, + ); + + assert!(result.is_ok()); + + let result = validate_distance( + &[ + Expression::Column { index: 0 }, + Expression::Column { index: 1 }, + Expression::Literal(Field::String("GEODESIC".to_string())), + Expression::Column { index: 2 }, + ], + &schema, + ); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = validate_distance( + &[ + Expression::Column { index: 0 }, + Expression::Literal(Field::String("GEODESIC".to_string())), + Expression::Column { index: 2 }, + ], + &schema, + ); + + let _expected_types = [FieldType::Point]; + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidFunctionArgumentType { .. 
}) + )); + } +} diff --git a/dozer-sql/src/pipeline/expression/geo/mod.rs b/dozer-sql/expression/src/geo/mod.rs similarity index 100% rename from dozer-sql/src/pipeline/expression/geo/mod.rs rename to dozer-sql/expression/src/geo/mod.rs diff --git a/dozer-sql/expression/src/geo/point.rs b/dozer-sql/expression/src/geo/point.rs new file mode 100644 index 0000000000..55df4bb282 --- /dev/null +++ b/dozer-sql/expression/src/geo/point.rs @@ -0,0 +1,234 @@ +use crate::arg_utils::{extract_float, validate_num_arguments}; +use crate::error::Error; +use dozer_types::types::Record; +use dozer_types::types::{DozerPoint, Field, FieldType, Schema}; + +use crate::execution::{Expression, ExpressionType}; +use crate::geo::common::GeoFunctionType; + +pub fn validate_point(args: &[Expression], schema: &Schema) -> Result { + let ret_type = FieldType::Point; + let expected_arg_type = FieldType::Float; + + validate_num_arguments(2..3, args.len(), GeoFunctionType::Point)?; + + for (argument_index, exp) in args.iter().enumerate() { + let return_type = exp.get_type(schema)?.return_type; + if return_type != expected_arg_type { + return Err(Error::InvalidFunctionArgumentType { + function_name: GeoFunctionType::Point.to_string(), + argument_index, + actual: return_type, + expected: vec![expected_arg_type], + }); + } + } + + Ok(ExpressionType::new( + ret_type, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) +} + +pub fn evaluate_point( + schema: &Schema, + args: &[Expression], + record: &Record, +) -> Result { + validate_num_arguments(2..3, args.len(), GeoFunctionType::Point)?; + let f_x = args[0].evaluate(record, schema)?; + let f_y = args[1].evaluate(record, schema)?; + + if f_x == Field::Null || f_y == Field::Null { + Ok(Field::Null) + } else { + let x = extract_float(f_x, GeoFunctionType::Point, 0)?; + let y = extract_float(f_y, GeoFunctionType::Point, 1)?; + + Ok(Field::Point(DozerPoint::from((x, y)))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use dozer_types::types::{FieldDefinition, SourceDefinition}; + use proptest::prelude::*; + + #[test] + fn test_point() { + proptest!( + ProptestConfig::with_cases(1000), move |(x: i64, y: i64)| { + test_validate_point(x, y); + test_evaluate_point(x, y); + }); + } + + fn test_validate_point(x: i64, y: i64) { + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("x"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("y"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let result = validate_point(&[], &schema); + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = validate_point(&[Expression::Column { index: 0 }], &schema); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = validate_point( + &[ + Expression::Column { index: 0 }, + Expression::Column { index: 1 }, + ], + &schema, + ); + + assert!(result.is_ok()); + + let result = validate_point( + &[ + Expression::Column { index: 0 }, + Expression::Column { index: 1 }, + Expression::Column { index: 2 }, + ], + &schema, + ); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. 
}) + )); + + let result = validate_point( + &[ + Expression::Column { index: 0 }, + Expression::Literal(Field::Int(y)), + ], + &schema, + ); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidFunctionArgumentType { .. }) + )); + + let result = validate_point( + &[ + Expression::Literal(Field::Int(x)), + Expression::Column { index: 0 }, + ], + &schema, + ); + + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidFunctionArgumentType { .. }) + )); + } + + fn test_evaluate_point(x: i64, y: i64) { + let row = Record::new(vec![]); + + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("x"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("y"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let result = evaluate_point(&schema, &[], &row); + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = evaluate_point(&schema, &[Expression::Literal(Field::Int(x))], &row); + assert!(result.is_err()); + assert!(matches!( + result, + Err(Error::InvalidNumberOfArguments { .. }) + )); + + let result = evaluate_point( + &schema, + &[ + Expression::Literal(Field::Int(x)), + Expression::Literal(Field::Int(y)), + ], + &row, + ); + + assert!(result.is_ok()); + + let result = evaluate_point( + &schema, + &[ + Expression::Literal(Field::Int(x)), + Expression::Literal(Field::Null), + ], + &row, + ); + + assert!(result.is_ok()); + assert!(matches!(result, Ok(Field::Null))); + + let result = evaluate_point( + &schema, + &[ + Expression::Literal(Field::Null), + Expression::Literal(Field::Int(y)), + ], + &row, + ); + + assert!(result.is_ok()); + assert!(matches!(result, Ok(Field::Null))); + } +} diff --git a/dozer-sql/src/pipeline/expression/in_list.rs b/dozer-sql/expression/src/in_list.rs similarity index 81% rename from dozer-sql/src/pipeline/expression/in_list.rs rename to dozer-sql/expression/src/in_list.rs index 76fb642835..ad8909fd22 100644 --- a/dozer-sql/src/pipeline/expression/in_list.rs +++ b/dozer-sql/expression/src/in_list.rs @@ -1,8 +1,8 @@ use dozer_types::types::Record; use dozer_types::types::{Field, Schema}; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::execution::Expression; +use crate::error::Error; +use crate::execution::Expression; pub(crate) fn evaluate_in_list( schema: &Schema, @@ -10,7 +10,7 @@ pub(crate) fn evaluate_in_list( list: &[Expression], negated: bool, record: &Record, -) -> Result { +) -> Result { let field = expr.evaluate(record, schema)?; let mut result = false; for item in list { diff --git a/dozer-sql/src/pipeline/expression/json_functions.rs b/dozer-sql/expression/src/json_functions.rs similarity index 56% rename from dozer-sql/src/pipeline/expression/json_functions.rs rename to dozer-sql/expression/src/json_functions.rs index 4538cec3d0..c4b6872ca1 100644 --- a/dozer-sql/src/pipeline/expression/json_functions.rs +++ b/dozer-sql/expression/src/json_functions.rs @@ -1,8 +1,6 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::PipelineError::{ - InvalidArgument, InvalidFunction, InvalidFunctionArgument, InvalidValue, -}; -use crate::pipeline::expression::execution::Expression; +use crate::arg_utils::validate_num_arguments; +use crate::error::Error; +use crate::execution::Expression; use dozer_types::json_types::JsonValue; use dozer_types::types::Record; @@ 
-27,11 +25,11 @@ impl Display for JsonFunctionType { } impl JsonFunctionType { - pub(crate) fn new(name: &str) -> Result { + pub(crate) fn new(name: &str) -> Option { match name { - "json_value" => Ok(JsonFunctionType::JsonValue), - "json_query" => Ok(JsonFunctionType::JsonQuery), - _ => Err(InvalidFunction(name.to_string())), + "json_value" => Some(JsonFunctionType::JsonValue), + "json_query" => Some(JsonFunctionType::JsonQuery), + _ => None, } } @@ -40,7 +38,7 @@ impl JsonFunctionType { schema: &Schema, args: &Vec, record: &Record, - ) -> Result { + ) -> Result { match self { JsonFunctionType::JsonValue => self.evaluate_json_value(schema, args, record), JsonFunctionType::JsonQuery => self.evaluate_json_query(schema, args, record), @@ -52,19 +50,10 @@ impl JsonFunctionType { schema: &Schema, args: &Vec, record: &Record, - ) -> Result { - if args.len() > 2 { - return Err(InvalidFunctionArgument( - self.to_string(), - args[2].evaluate(record, schema)?, - 2, - )); - } + ) -> Result { + validate_num_arguments(2..3, args.len(), self)?; let json_input = args[0].evaluate(record, schema)?; - let path = args[1] - .evaluate(record, schema)? - .to_string() - .ok_or(InvalidArgument(args[1].to_string(schema)))?; + let path = args[1].evaluate(record, schema)?.to_string(); Ok(Field::Json(self.evaluate_json(json_input, path)?)) } @@ -74,26 +63,18 @@ impl JsonFunctionType { schema: &Schema, args: &Vec, record: &Record, - ) -> Result { - let mut path = String::from("$"); - if args.len() < 2 && !args.is_empty() { - Ok(Field::Json( - self.evaluate_json(args[0].evaluate(record, schema)?, path)?, - )) - } else if args.len() == 2 { + ) -> Result { + validate_num_arguments(1..3, args.len(), self)?; + if args.len() == 1 { + Ok(Field::Json(self.evaluate_json( + args[0].evaluate(record, schema)?, + String::from("$"), + )?)) + } else { let json_input = args[0].evaluate(record, schema)?; - path = args[1] - .evaluate(record, schema)? 
- .to_string() - .ok_or(InvalidArgument(args[1].to_string(schema)))?; + let path = args[1].evaluate(record, schema)?.to_string(); Ok(Field::Json(self.evaluate_json(json_input, path)?)) - } else { - Err(InvalidFunctionArgument( - self.to_string(), - args[2].evaluate(record, schema)?, - 2, - )) } } @@ -101,7 +82,7 @@ &self, json_input: Field, path: String, - ) -> Result<JsonValue, PipelineError> { + ) -> Result<JsonValue, Error> { let json_val = match json_input.to_json() { Some(json) => json, None => JsonValue::Null, @@ -109,20 +90,16 @@ let finder = JsonPathFinder::new( Box::from(json_val), - Box::from(JsonPathInst::from_str(path.as_str()).map_err(InvalidArgument)?), + Box::from(JsonPathInst::from_str(path.as_str()).map_err(Error::InvalidJsonPath)?), ); match finder.find() { JsonValue::Null => Ok(JsonValue::Null), - JsonValue::Array(a) => { + JsonValue::Array(mut a) => { if a.is_empty() { Ok(JsonValue::Array(vec![])) } else if a.len() == 1 { - let item = match a.first() { - Some(i) => i, - None => return Err(InvalidValue("Invalid length of array".to_string())), - }; - Ok(item.to_owned()) + Ok(a.remove(0)) } else { let mut array_val = vec![]; for item in a { @@ -131,7 +108,7 @@ Ok(JsonValue::Array(array_val)) } } - _ => Err(InvalidValue(path)), + other => Ok(other), } } } diff --git a/dozer-sql/expression/src/lib.rs b/dozer-sql/expression/src/lib.rs new file mode 100644 index 0000000000..8a2895779b --- /dev/null +++ b/dozer-sql/expression/src/lib.rs @@ -0,0 +1,84 @@ +pub mod aggregate; +mod arg_utils; +pub mod builder; +mod case; +mod cast; +mod comparison; +mod conditional; +mod datetime; +pub mod error; +pub mod execution; +mod geo; +mod in_list; +mod json_functions; +mod logical; +mod mathematical; +pub mod operator; +pub mod scalar; + +#[cfg(feature = "onnx")] +mod onnx; +#[cfg(feature = "python")] +mod python_udf; + +pub use num_traits; +pub use sqlparser; + +#[cfg(test)] +mod tests { + use dozer_types::{ + chrono::{DateTime, Datelike, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}, + rust_decimal::Decimal, + }; + use proptest::{ + prelude::Arbitrary, + strategy::{BoxedStrategy, Strategy}, + }; + + #[derive(Debug)] + pub struct ArbitraryDecimal(pub Decimal); + + impl Arbitrary for ArbitraryDecimal { + type Parameters = (); + type Strategy = BoxedStrategy<Self>; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + (i64::MIN..i64::MAX, u32::MIN..29u32) + .prop_map(|(num, scale)| ArbitraryDecimal(Decimal::new(num, scale))) + .boxed() + } + } + + #[derive(Debug)] + pub struct ArbitraryDateTime(pub DateTime<FixedOffset>); + + impl Arbitrary for ArbitraryDateTime { + type Parameters = (); + type Strategy = BoxedStrategy<Self>; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + ( + NaiveDateTime::MIN.year()..NaiveDateTime::MAX.year(), + 1..13u32, + 1..32u32, + 0..NaiveDateTime::MAX.second(), + 0..NaiveDateTime::MAX.nanosecond(), + ) + .prop_map(|(year, month, day, secs, nano)| { + let timezone_east = FixedOffset::east_opt(8 * 60 * 60).unwrap(); + let date = NaiveDate::from_ymd_opt(year, month, day); + // Some dates cannot be constructed, e.g. day 29, 30 or 31 in months that are too short. + if date.is_none() { + return ArbitraryDateTime(DateTime::default()); + } + let time = NaiveTime::from_num_seconds_from_midnight_opt(secs, nano).unwrap(); + let datetime = DateTime::<FixedOffset>::from_local( + NaiveDateTime::new(date.unwrap(), time), + timezone_east, + ); + ArbitraryDateTime(datetime) + }) + .boxed() + } + } +} diff --git 
a/dozer-sql/expression/src/logical.rs b/dozer-sql/expression/src/logical.rs new file mode 100644 index 0000000000..14a07311a8 --- /dev/null +++ b/dozer-sql/expression/src/logical.rs @@ -0,0 +1,292 @@ +use dozer_types::types::Record; +use dozer_types::types::{Field, Schema}; + +use crate::error::Error; +use crate::execution::Expression; + +pub fn evaluate_and( + schema: &Schema, + left: &Expression, + right: &Expression, + record: &Record, +) -> Result { + let l_field = left.evaluate(record, schema)?; + let r_field = right.evaluate(record, schema)?; + match l_field { + Field::Boolean(true) => match r_field { + Field::Boolean(true) => Ok(Field::Boolean(true)), + Field::Boolean(false) => Ok(Field::Boolean(false)), + Field::Null => Ok(Field::Boolean(false)), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(r_field, "AND".to_string())), + }, + Field::Boolean(false) => match r_field { + Field::Boolean(true) => Ok(Field::Boolean(false)), + Field::Boolean(false) => Ok(Field::Boolean(false)), + Field::Null => Ok(Field::Boolean(false)), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(r_field, "AND".to_string())), + }, + Field::Null => Ok(Field::Boolean(false)), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(l_field, "AND".to_string())), + } +} + +pub fn evaluate_or( + schema: &Schema, + left: &Expression, + right: &Expression, + record: &Record, +) -> Result { + let l_field = left.evaluate(record, schema)?; + let r_field = right.evaluate(record, schema)?; + match l_field { + Field::Boolean(true) => match r_field { + Field::Boolean(false) => Ok(Field::Boolean(true)), + Field::Boolean(true) => Ok(Field::Boolean(true)), + Field::Null => Ok(Field::Boolean(true)), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(r_field, "OR".to_string())), + }, + Field::Boolean(false) | Field::Null => match right.evaluate(record, schema)? 
{ + Field::Boolean(false) => Ok(Field::Boolean(false)), + Field::Boolean(true) => Ok(Field::Boolean(true)), + Field::Null => Ok(Field::Boolean(false)), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(r_field, "OR".to_string())), + }, + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(l_field, "OR".to_string())), + } +} + +pub fn evaluate_not(schema: &Schema, value: &Expression, record: &Record) -> Result { + let value_p = value.evaluate(record, schema)?; + + match value_p { + Field::Boolean(value_v) => Ok(Field::Boolean(!value_v)), + Field::Null => Ok(Field::Null), + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Timestamp(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidType(value_p, "NOT".to_string())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use dozer_types::types::Record; + use dozer_types::types::{Field, Schema}; + use dozer_types::{ordered_float::OrderedFloat, rust_decimal::Decimal}; + use proptest::prelude::*; + use Expression::Literal; + + #[test] + fn test_logical() { + proptest!( + ProptestConfig::with_cases(1000), + move |(bool1: bool, bool2: bool, u_num: u64, i_num: i64, f_num: f64, str in ".*")| { + _test_bool_bool_and(bool1, bool2); + _test_bool_null_and(Field::Boolean(bool1), Field::Null); + _test_bool_null_and(Field::Null, Field::Boolean(bool1)); + + _test_bool_bool_or(bool1, bool2); + _test_bool_null_or(bool1); + _test_null_bool_or(bool2); + + _test_bool_not(bool2); + + _test_bool_non_bool_and(Field::UInt(u_num), Field::Boolean(bool1)); + _test_bool_non_bool_and(Field::Int(i_num), Field::Boolean(bool1)); + _test_bool_non_bool_and(Field::Float(OrderedFloat(f_num)), Field::Boolean(bool1)); + _test_bool_non_bool_and(Field::Decimal(Decimal::from(u_num)), Field::Boolean(bool1)); + _test_bool_non_bool_and(Field::String(str.clone()), Field::Boolean(bool1)); + _test_bool_non_bool_and(Field::Text(str.clone()), Field::Boolean(bool1)); + + _test_bool_non_bool_and(Field::Boolean(bool2), Field::UInt(u_num)); + _test_bool_non_bool_and(Field::Boolean(bool2), Field::Int(i_num)); + _test_bool_non_bool_and(Field::Boolean(bool2), Field::Float(OrderedFloat(f_num))); + _test_bool_non_bool_and(Field::Boolean(bool2), Field::Decimal(Decimal::from(u_num))); + _test_bool_non_bool_and(Field::Boolean(bool2), Field::String(str.clone())); + _test_bool_non_bool_and(Field::Boolean(bool2), Field::Text(str.clone())); + + _test_bool_non_bool_or(Field::UInt(u_num), Field::Boolean(bool1)); + _test_bool_non_bool_or(Field::Int(i_num), Field::Boolean(bool1)); + _test_bool_non_bool_or(Field::Float(OrderedFloat(f_num)), Field::Boolean(bool1)); + _test_bool_non_bool_or(Field::Decimal(Decimal::from(u_num)), Field::Boolean(bool1)); + _test_bool_non_bool_or(Field::String(str.clone()), Field::Boolean(bool1)); + _test_bool_non_bool_or(Field::Text(str.clone()), Field::Boolean(bool1)); + + 
_test_bool_non_bool_or(Field::Boolean(bool2), Field::UInt(u_num)); + _test_bool_non_bool_or(Field::Boolean(bool2), Field::Int(i_num)); + _test_bool_non_bool_or(Field::Boolean(bool2), Field::Float(OrderedFloat(f_num))); + _test_bool_non_bool_or(Field::Boolean(bool2), Field::Decimal(Decimal::from(u_num))); + _test_bool_non_bool_or(Field::Boolean(bool2), Field::String(str.clone())); + _test_bool_non_bool_or(Field::Boolean(bool2), Field::Text(str)); + }); + } + + fn _test_bool_bool_and(bool1: bool, bool2: bool) { + let row = Record::new(vec![]); + let l = Box::new(Literal(Field::Boolean(bool1))); + let r = Box::new(Literal(Field::Boolean(bool2))); + assert!(matches!( + evaluate_and(&Schema::default(), &l, &r, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(_ans) + )); + } + + fn _test_bool_null_and(f1: Field, f2: Field) { + let row = Record::new(vec![]); + let l = Box::new(Literal(f1)); + let r = Box::new(Literal(f2)); + assert!(matches!( + evaluate_and(&Schema::default(), &l, &r, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(false) + )); + } + + fn _test_bool_bool_or(bool1: bool, bool2: bool) { + let row = Record::new(vec![]); + let l = Box::new(Literal(Field::Boolean(bool1))); + let r = Box::new(Literal(Field::Boolean(bool2))); + assert!(matches!( + evaluate_or(&Schema::default(), &l, &r, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(_ans) + )); + } + + fn _test_bool_null_or(_bool: bool) { + let row = Record::new(vec![]); + let l = Box::new(Literal(Field::Boolean(_bool))); + let r = Box::new(Literal(Field::Null)); + assert!(matches!( + evaluate_or(&Schema::default(), &l, &r, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(_bool) + )); + } + + fn _test_null_bool_or(_bool: bool) { + let row = Record::new(vec![]); + let l = Box::new(Literal(Field::Null)); + let r = Box::new(Literal(Field::Boolean(_bool))); + assert!(matches!( + evaluate_or(&Schema::default(), &l, &r, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(_bool) + )); + } + + fn _test_bool_not(bool: bool) { + let row = Record::new(vec![]); + let v = Box::new(Literal(Field::Boolean(bool))); + assert!(matches!( + evaluate_not(&Schema::default(), &v, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Boolean(_ans) + )); + } + + fn _test_bool_non_bool_and(f1: Field, f2: Field) { + let row = Record::new(vec![]); + let l = Box::new(Literal(f1)); + let r = Box::new(Literal(f2)); + assert!(evaluate_and(&Schema::default(), &l, &r, &row).is_err()); + } + + fn _test_bool_non_bool_or(f1: Field, f2: Field) { + let row = Record::new(vec![]); + let l = Box::new(Literal(f1)); + let r = Box::new(Literal(f2)); + assert!(evaluate_or(&Schema::default(), &l, &r, &row).is_err()); + } +} diff --git a/dozer-sql/src/pipeline/expression/mathematical.rs b/dozer-sql/expression/src/mathematical/mod.rs similarity index 93% rename from dozer-sql/src/pipeline/expression/mathematical.rs rename to dozer-sql/expression/src/mathematical/mod.rs index f52da53b67..a120bcc4f3 100644 --- a/dozer-sql/src/pipeline/expression/mathematical.rs +++ b/dozer-sql/expression/src/mathematical/mod.rs @@ -1,7 +1,3 @@ -use crate::pipeline::errors::OperationError; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::SqlError::Operation; -use crate::pipeline::expression::execution::Expression; use dozer_types::rust_decimal::Decimal; use dozer_types::types::Record; use dozer_types::types::Schema; @@ -11,6 +7,10 @@ use 
num_traits::{FromPrimitive, Zero}; use std::num::Wrapping; use std::ops::Neg; +use crate::execution::Expression; + +use crate::error::{Error as PipelineError, OperationError}; + macro_rules! define_math_operator { ($id:ident, $op:expr, $fct:expr, $t: expr) => { pub fn $id( @@ -23,22 +23,18 @@ macro_rules! define_math_operator { let right_p = right.evaluate(&record, schema)?; match left_p { - Field::Duration(left_v) => { - match right_p { - Field::Duration(right_v) => match $op { + Field::Duration(left_v) => match right_p { + Field::Duration(right_v) => { + match $op { "-" => { let duration = left_v.0.checked_sub(right_v.0).ok_or( - PipelineError::SqlError(Operation( - OperationError::AdditionOverflow, - )), + PipelineError::SqlError(OperationError::AdditionOverflow), )?; Ok(Field::from(DozerDuration(duration, TimeUnit::Nanoseconds))) } "+" => { let duration = left_v.0.checked_add(right_v.0).ok_or( - PipelineError::SqlError(Operation( - OperationError::SubtractionOverflow, - )), + PipelineError::SqlError(OperationError::SubtractionOverflow), )?; Ok(Field::from(DozerDuration(duration, TimeUnit::Nanoseconds))) } @@ -52,49 +48,47 @@ macro_rules! define_math_operator { right_p, $op.to_string(), )), - }, - Field::Timestamp(right_v) => match $op { - "+" => { - let duration = right_v - .checked_add_signed(chrono::Duration::nanoseconds( - left_v.0.as_nanos() as i64, - )) - .ok_or(PipelineError::SqlError(Operation( - OperationError::AdditionOverflow, - )))?; - Ok(Field::Timestamp(duration)) - } - "-" | "*" | "/" | "%" => Err(PipelineError::InvalidTypeComparison( - left_p, - right_p, - $op.to_string(), - )), - &_ => Err(PipelineError::InvalidTypeComparison( - left_p, - right_p, - $op.to_string(), - )), - }, - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::Boolean(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Null => Err(PipelineError::InvalidTypeComparison( + } + } + Field::Timestamp(right_v) => match $op { + "+" => { + let duration = right_v + .checked_add_signed(chrono::Duration::nanoseconds( + left_v.0.as_nanos() as i64, + )) + .ok_or(PipelineError::SqlError(OperationError::AdditionOverflow))?; + Ok(Field::Timestamp(duration)) + } + "-" | "*" | "/" | "%" => Err(PipelineError::InvalidTypeComparison( left_p, right_p, $op.to_string(), )), - } - } + &_ => Err(PipelineError::InvalidTypeComparison( + left_p, + right_p, + $op.to_string(), + )), + }, + Field::UInt(_) + | Field::U128(_) + | Field::Int(_) + | Field::I128(_) + | Field::Float(_) + | Field::Boolean(_) + | Field::String(_) + | Field::Text(_) + | Field::Binary(_) + | Field::Decimal(_) + | Field::Date(_) + | Field::Json(_) + | Field::Point(_) + | Field::Null => Err(PipelineError::InvalidTypeComparison( + left_p, + right_p, + $op.to_string(), + )), + }, Field::Timestamp(left_v) => match right_p { Field::Duration(right_v) => match $op { "-" => { @@ -102,9 +96,7 @@ macro_rules! define_math_operator { .checked_sub_signed(chrono::Duration::nanoseconds( right_v.0.as_nanos() as i64, )) - .ok_or(PipelineError::SqlError(Operation( - OperationError::AdditionOverflow, - )))?; + .ok_or(PipelineError::SqlError(OperationError::AdditionOverflow))?; Ok(Field::Timestamp(duration)) } "+" => { @@ -112,9 +104,9 @@ macro_rules! 
define_math_operator { .checked_add_signed(chrono::Duration::nanoseconds( right_v.0.as_nanos() as i64, )) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::SubtractionOverflow, - )))?; + ))?; Ok(Field::Timestamp(duration)) } "*" | "/" | "%" => Err(PipelineError::InvalidTypeComparison( @@ -185,9 +177,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( left_v, @@ -216,9 +208,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( left_v, @@ -247,9 +239,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( left_v, @@ -278,9 +270,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( left_v, @@ -309,9 +301,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct(left_v, right_v))) } @@ -329,9 +321,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_div(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )), "%" => Ok(Field::Decimal( Decimal::from_f64(*left_v) @@ -340,9 +332,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_rem(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), "*" => Ok(Field::Decimal( Decimal::from_f64(*left_v) @@ -351,9 +343,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_mul(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), "+" | "-" => Ok(Field::Decimal($fct( Decimal::from_f64(*left_v).ok_or(PipelineError::UnableToCast( @@ -391,9 +383,9 @@ macro_rules! define_math_operator { // When Int / Int division happens "/" => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i64(left_v).ok_or( @@ -428,9 +420,9 @@ macro_rules! define_math_operator { // When Int / I128 division happens "/" => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i64(left_v).ok_or( @@ -465,9 +457,9 @@ macro_rules! 
define_math_operator { // When Int / UInt division happens "/" => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i64(left_v).ok_or( @@ -502,9 +494,9 @@ macro_rules! define_math_operator { // When Int / U128 division happens "/" => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i64(left_v).ok_or( @@ -538,9 +530,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i64(left_v).ok_or( @@ -575,9 +567,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_div(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )), "%" => Ok(Field::Decimal( Decimal::from_i64(left_v) @@ -586,9 +578,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_rem(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), "*" => Ok(Field::Decimal( Decimal::from_i64(left_v) @@ -597,9 +589,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_mul(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), "+" | "-" => Ok(Field::Decimal($fct( Decimal::from_i64(left_v).ok_or(PipelineError::UnableToCast( @@ -638,9 +630,9 @@ macro_rules! define_math_operator { // When I128 / Int division happens "/" => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i128(left_v).ok_or( @@ -675,9 +667,9 @@ macro_rules! define_math_operator { // When I128 / I128 division happens "/" => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i128(left_v).ok_or( @@ -712,9 +704,9 @@ macro_rules! define_math_operator { // When I128 / UInt division happens "/" => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i128(left_v).ok_or( @@ -749,9 +741,9 @@ macro_rules! define_math_operator { // When Int / U128 division happens "/" => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i128(left_v).ok_or( @@ -785,9 +777,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_i128(left_v).ok_or( @@ -816,9 +808,9 @@ macro_rules! 
define_math_operator { return match $op { "/" | "%" => { if right_v == dozer_types::rust_decimal::Decimal::zero() { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal($fct( Decimal::from_i128(left_v).ok_or( @@ -863,9 +855,9 @@ macro_rules! define_math_operator { // When UInt / Int division happens "/" => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float(OrderedFloat($fct( f64::from_u64(left_v).ok_or( @@ -900,9 +892,9 @@ macro_rules! define_math_operator { // When UInt / I128 division happens "/" => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float(OrderedFloat($fct( f64::from_u64(left_v).ok_or( @@ -937,9 +929,9 @@ macro_rules! define_math_operator { // When UInt / UInt division happens "/" => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float( OrderedFloat::::from_f64($fct( @@ -982,9 +974,9 @@ macro_rules! define_math_operator { // When UInt / UInt division happens "/" => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float( OrderedFloat::::from_f64($fct( @@ -1026,9 +1018,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_u64(left_v).ok_or( @@ -1057,9 +1049,9 @@ macro_rules! define_math_operator { return match $op { "/" => { if right_v == dozer_types::rust_decimal::Decimal::zero() { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( Decimal::from_u64(left_v) @@ -1068,9 +1060,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_div(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } @@ -1081,9 +1073,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_rem(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), "*" => Ok(Field::Decimal( Decimal::from_u64(left_v) @@ -1092,9 +1084,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ))? .checked_mul(right_v) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), "+" | "-" => Ok(Field::Decimal($fct( Decimal::from_u64(left_v).ok_or(PipelineError::UnableToCast( @@ -1133,9 +1125,9 @@ macro_rules! define_math_operator { // When U128 / Int division happens "/" => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float(OrderedFloat($fct( f64::from_u128(left_v).ok_or( @@ -1170,9 +1162,9 @@ macro_rules! 
define_math_operator { // When U128 / I128 division happens "/" => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float(OrderedFloat($fct( f64::from_u128(left_v).ok_or( @@ -1207,9 +1199,9 @@ macro_rules! define_math_operator { // When U128 / UInt division happens "/" => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float( OrderedFloat::::from_f64($fct( @@ -1252,9 +1244,9 @@ macro_rules! define_math_operator { // When U128 / U128 division happens "/" => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float( OrderedFloat::::from_f64($fct( @@ -1296,9 +1288,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Float($fct( OrderedFloat::::from_u128(left_v).ok_or( @@ -1327,9 +1319,9 @@ macro_rules! define_math_operator { return match $op { "/" | "%" => { if right_v == dozer_types::rust_decimal::Decimal::zero() { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal($fct( Decimal::from_u128(left_v).ok_or( @@ -1374,9 +1366,9 @@ macro_rules! define_math_operator { // left: Decimal, right: Int Field::Int(right_v) => { if right_v == 0_i64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( left_v @@ -1386,18 +1378,18 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } // left: Decimal, right: I128 Field::I128(right_v) => { if right_v == 0_i128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( left_v @@ -1407,18 +1399,18 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } // left: Decimal, right: UInt Field::UInt(right_v) => { if right_v == 0_u64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( left_v @@ -1428,18 +1420,18 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } // left: Decimal, right: U128 Field::U128(right_v) => { if right_v == 0_u128 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( left_v @@ -1449,18 +1441,18 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) 
- .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } // left: Decimal, right: Float Field::Float(right_v) => { if right_v == 0_f64 { - Err(PipelineError::SqlError(Operation( + Err(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - ))) + )) } else { Ok(Field::Decimal( left_v @@ -1470,9 +1462,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::DivisionByZeroOrOverflow, - )))?, + ))?, )) } } @@ -1481,7 +1473,7 @@ macro_rules! define_math_operator { // left: Decimal, right: Decimal Field::Decimal(right_v) => Ok(Field::Decimal( left_v.checked_div(right_v).ok_or(PipelineError::SqlError( - Operation(OperationError::DivisionByZeroOrOverflow), + OperationError::DivisionByZeroOrOverflow, ))?, )), Field::Boolean(_) @@ -1510,9 +1502,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), // left: Decimal, right: I128 Field::I128(right_v) => Ok(Field::Decimal( @@ -1523,9 +1515,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), // left: Decimal, right: UInt Field::UInt(right_v) => Ok(Field::Decimal( @@ -1536,9 +1528,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), // left: Decimal, right: U128 Field::U128(right_v) => Ok(Field::Decimal( @@ -1549,9 +1541,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), // left: Decimal, right: Float Field::Float(right_v) => Ok(Field::Decimal( @@ -1562,16 +1554,16 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::ModuloByZeroOrOverflow, - )))?, + ))?, )), // left: Decimal, right: Null Field::Null => Ok(Field::Null), // left: Decimal, right: Decimal Field::Decimal(right_v) => Ok(Field::Decimal( left_v.checked_rem(right_v).ok_or(PipelineError::SqlError( - Operation(OperationError::ModuloByZeroOrOverflow), + OperationError::ModuloByZeroOrOverflow, ))?, )), Field::Boolean(_) @@ -1600,9 +1592,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), // left: Decimal, right: I128 Field::I128(right_v) => Ok(Field::Decimal( @@ -1613,9 +1605,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), // left: Decimal, right: UInt Field::UInt(right_v) => Ok(Field::Decimal( @@ -1626,9 +1618,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) 
- .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), // left: Decimal, right: U128 Field::U128(right_v) => Ok(Field::Decimal( @@ -1639,9 +1631,9 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), // left: Decimal, right: Float Field::Float(right_v) => Ok(Field::Decimal( @@ -1652,16 +1644,16 @@ macro_rules! define_math_operator { "Decimal".to_string(), ), )?) - .ok_or(PipelineError::SqlError(Operation( + .ok_or(PipelineError::SqlError( OperationError::MultiplicationOverflow, - )))?, + ))?, )), // left: Decimal, right: Null Field::Null => Ok(Field::Null), // left: Decimal, right: Decimal Field::Decimal(right_v) => Ok(Field::Decimal( left_v.checked_mul(right_v).ok_or(PipelineError::SqlError( - Operation(OperationError::MultiplicationOverflow), + OperationError::MultiplicationOverflow, ))?, )), Field::Boolean(_) @@ -1841,3 +1833,6 @@ pub fn evaluate_minus( )), } } + +#[cfg(test)] +mod tests; diff --git a/dozer-sql/src/pipeline/expression/tests/mathematical.rs b/dozer-sql/expression/src/mathematical/tests.rs similarity index 87% rename from dozer-sql/src/pipeline/expression/tests/mathematical.rs rename to dozer-sql/expression/src/mathematical/tests.rs index 86a205c852..80b4d72f93 100644 --- a/dozer-sql/src/pipeline/expression/tests/mathematical.rs +++ b/dozer-sql/expression/src/mathematical/tests.rs @@ -1,19 +1,17 @@ -use crate::pipeline::errors::SqlError::Operation; -use crate::pipeline::errors::{OperationError, PipelineError}; -use crate::pipeline::expression::execution::Expression::Literal; -use crate::pipeline::expression::mathematical::{ - evaluate_add, evaluate_div, evaluate_mod, evaluate_mul, evaluate_sub, -}; -use crate::pipeline::expression::tests::test_common::*; -use dozer_types::types::Record; +use crate::tests::{ArbitraryDateTime, ArbitraryDecimal}; + +use super::*; + +use dozer_types::chrono::DateTime; +use dozer_types::types::{FieldDefinition, FieldType, Record, SourceDefinition}; use dozer_types::{ ordered_float::OrderedFloat, rust_decimal::Decimal, types::{Field, Schema}, }; -use num_traits::FromPrimitive; use proptest::prelude::*; use std::num::Wrapping; +use Expression::Literal; #[test] fn test_uint_math() { @@ -222,7 +220,7 @@ fn test_uint_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // UInt / Decimal = Decimal @@ -231,7 +229,7 @@ fn test_uint_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -242,7 +240,7 @@ fn test_uint_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // UInt % Decimal = Decimal @@ -251,7 +249,7 @@ fn test_uint_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -262,7 +260,7 @@ fn test_uint_math() { 
assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } @@ -525,7 +523,7 @@ fn test_u128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } } @@ -536,7 +534,7 @@ fn test_u128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } } @@ -549,7 +547,7 @@ fn test_u128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } } @@ -560,7 +558,7 @@ fn test_u128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -573,7 +571,7 @@ fn test_u128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -824,7 +822,7 @@ fn test_int_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // Int / Decimal = Decimal @@ -833,7 +831,7 @@ fn test_int_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -844,7 +842,7 @@ fn test_int_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // Int % Decimal = Decimal @@ -853,7 +851,7 @@ fn test_int_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -864,7 +862,7 @@ fn test_int_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } @@ -1138,7 +1136,7 @@ fn test_i128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } } @@ -1149,7 +1147,7 @@ fn test_i128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); 
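// The -/+ pairs in these test hunks change only the error's shape, not the
// behavior: the old PipelineError::SqlError(Operation(e)) nesting is
// flattened to PipelineError::SqlError(e). A reduced, hypothetical sketch of
// the checked-division pattern those errors guard (stand-in error string,
// not the crate's own type):
fn _checked_div_sketch(left: i64, right: i64) -> Result<i64, &'static str> {
    // mirrors the `if right_v == 0 { Err(SqlError(DivisionByZeroOrOverflow)) }`
    // arms in define_math_operator; checked_div also catches i64::MIN / -1
    left.checked_div(right).ok_or("DivisionByZeroOrOverflow")
}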
} } @@ -1162,7 +1160,7 @@ fn test_i128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } } @@ -1173,7 +1171,7 @@ fn test_i128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1186,7 +1184,7 @@ fn test_i128_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1465,7 +1463,7 @@ fn test_float_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // Float / Decimal = Decimal @@ -1474,7 +1472,7 @@ fn test_float_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1485,7 +1483,7 @@ fn test_float_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // Float % Decimal = Decimal @@ -1494,7 +1492,7 @@ fn test_float_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1505,7 +1503,7 @@ fn test_float_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1587,7 +1585,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // Decimal / UInt = Decimal @@ -1596,7 +1594,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1607,7 +1605,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // Decimal % UInt = Decimal @@ -1616,7 +1614,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1627,7 +1625,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + 
Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } @@ -1661,7 +1659,7 @@ fn test_decimal_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } } @@ -1672,7 +1670,7 @@ fn test_decimal_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } } @@ -1685,7 +1683,7 @@ fn test_decimal_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } } @@ -1696,7 +1694,7 @@ fn test_decimal_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1709,7 +1707,7 @@ fn test_decimal_math() { if !matches!(res, Err(PipelineError::UnableToCast(_, _))) { assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1822,7 +1820,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // Decimal / Float = Decimal @@ -1831,7 +1829,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1842,7 +1840,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // Decimal % Float = Decimal @@ -1851,7 +1849,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1862,7 +1860,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } } @@ -1891,7 +1889,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::MultiplicationOverflow))) + Err(PipelineError::SqlError(OperationError::MultiplicationOverflow)) )); } // Decimal / Decimal = Decimal @@ -1900,7 +1898,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1911,7 +1909,7 @@ fn test_decimal_math() { assert!(res.is_err()); 
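// The Decimal arms exercised by these assertions lean on rust_decimal's
// checked operations, which return None on overflow, division by zero, or
// modulo by zero instead of panicking. A standalone sketch of that behavior,
// assuming only the rust_decimal re-export already visible in this file:
fn _checked_decimal_sketch() {
    use dozer_types::rust_decimal::Decimal;
    let two = Decimal::from(2);
    assert!(Decimal::MAX.checked_mul(two).is_none()); // -> MultiplicationOverflow
    assert!(two.checked_div(Decimal::ZERO).is_none()); // -> DivisionByZeroOrOverflow
    assert!(two.checked_rem(Decimal::ZERO).is_none()); // -> ModuloByZeroOrOverflow
    assert_eq!(two.checked_mul(two), Some(Decimal::from(4)));
}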
assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::DivisionByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::DivisionByZeroOrOverflow)) )); } // Decimal % Decimal = Decimal @@ -1920,7 +1918,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } else if res.is_ok() { @@ -1931,7 +1929,7 @@ fn test_decimal_math() { assert!(res.is_err()); assert!(matches!( res, - Err(PipelineError::SqlError(Operation(OperationError::ModuloByZeroOrOverflow))) + Err(PipelineError::SqlError(OperationError::ModuloByZeroOrOverflow)) )); } @@ -2214,3 +2212,247 @@ fn test_null_math() { ); }) } + +#[test] +fn test_timestamp_difference() { + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("a"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + true, + ) + .field( + FieldDefinition::new( + String::from("b"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let record = Record::new(vec![ + Field::Timestamp(DateTime::parse_from_rfc3339("2020-01-01T00:13:00Z").unwrap()), + Field::Timestamp(DateTime::parse_from_rfc3339("2020-01-01T00:12:10Z").unwrap()), + ]); + + let result = evaluate_sub( + &schema, + &Expression::Column { index: 0 }, + &Expression::Column { index: 1 }, + &record, + ) + .unwrap(); + assert_eq!( + result, + Field::Duration(DozerDuration( + std::time::Duration::from_nanos(50000 * 1000 * 1000), + TimeUnit::Nanoseconds + )) + ); + + let result = evaluate_sub( + &schema, + &Expression::Column { index: 1 }, + &Expression::Column { index: 0 }, + &record, + ); + assert!(result.is_err()); +} + +#[test] +fn test_duration() { + proptest!( + ProptestConfig::with_cases(1000), + move |(d1: u64, d2: u64, dt1: ArbitraryDateTime)| { + test_duration_math(d1, d2, dt1) + }); +} + +fn test_duration_math(d1: u64, d2: u64, dt1: ArbitraryDateTime) { + let row = Record::new(vec![]); + + let v = Expression::Literal(Field::Date(dt1.0.date_naive())); + let dur1 = Expression::Literal(Field::Duration(DozerDuration( + std::time::Duration::from_nanos(d1), + TimeUnit::Nanoseconds, + ))); + let dur2 = Expression::Literal(Field::Duration(DozerDuration( + std::time::Duration::from_nanos(d2), + TimeUnit::Nanoseconds, + ))); + + // Duration + Duration = Duration + let result = evaluate_add(&Schema::default(), &dur1, &dur2, &row); + let sum = std::time::Duration::from_nanos(d1).checked_add(std::time::Duration::from_nanos(d2)); + if result.is_ok() && sum.is_some() { + assert_eq!( + result.unwrap(), + Field::Duration(DozerDuration(sum.unwrap(), TimeUnit::Nanoseconds)) + ); + } + // Duration - Duration = Duration + let result = evaluate_sub(&Schema::default(), &dur1, &dur2, &row); + let diff = std::time::Duration::from_nanos(d1).checked_sub(std::time::Duration::from_nanos(d2)); + if result.is_ok() && diff.is_some() { + assert_eq!( + result.unwrap(), + Field::Duration(DozerDuration(diff.unwrap(), TimeUnit::Nanoseconds)) + ); + } + // Duration * Duration = Error + let result = evaluate_mul(&Schema::default(), &dur1, &dur2, &row); + assert!(result.is_err()); + // Duration / Duration = Error + let result = evaluate_div(&Schema::default(), &dur1, &dur2, &row); + assert!(result.is_err()); + // Duration % Duration = Error + let result = evaluate_mod(&Schema::default(), &dur1, &dur2, &row); + assert!(result.is_err()); + + // Duration + 
Timestamp = Error + let result = evaluate_add(&Schema::default(), &dur1, &v, &row); + assert!(result.is_err()); + // Duration - Timestamp = Error + let result = evaluate_sub(&Schema::default(), &dur1, &v, &row); + assert!(result.is_err()); + // Duration * Timestamp = Error + let result = evaluate_mul(&Schema::default(), &dur1, &v, &row); + assert!(result.is_err()); + // Duration / Timestamp = Error + let result = evaluate_div(&Schema::default(), &dur1, &v, &row); + assert!(result.is_err()); + // Duration % Timestamp = Error + let result = evaluate_mod(&Schema::default(), &dur1, &v, &row); + assert!(result.is_err()); + + // Timestamp + Duration = Timestamp + let result = evaluate_add(&Schema::default(), &v, &dur1, &row); + let sum = dt1 + .0 + .checked_add_signed(chrono::Duration::nanoseconds(d1 as i64)); + if result.is_ok() && sum.is_some() { + assert_eq!(result.unwrap(), Field::Timestamp(sum.unwrap())); + } + // Timestamp - Duration = Timestamp + let result = evaluate_sub(&Schema::default(), &v, &dur2, &row); + let diff = dt1 + .0 + .checked_sub_signed(chrono::Duration::nanoseconds(d2 as i64)); + if result.is_ok() && diff.is_some() { + assert_eq!(result.unwrap(), Field::Timestamp(diff.unwrap())); + } + // Timestamp * Duration = Error + let result = evaluate_mul(&Schema::default(), &v, &dur1, &row); + assert!(result.is_err()); + // Timestamp / Duration = Error + let result = evaluate_div(&Schema::default(), &v, &dur1, &row); + assert!(result.is_err()); + // Timestamp % Duration = Error + let result = evaluate_mod(&Schema::default(), &v, &dur1, &row); + assert!(result.is_err()); +} + +#[test] +fn test_decimal() { + let dec1 = Box::new(Literal(Field::Decimal(Decimal::from_i64(1_i64).unwrap()))); + let dec2 = Box::new(Literal(Field::Decimal(Decimal::from_i64(2_i64).unwrap()))); + let float1 = Box::new(Literal(Field::Float( + OrderedFloat::<f64>::from_i64(1_i64).unwrap(), + ))); + let float2 = Box::new(Literal(Field::Float( + OrderedFloat::<f64>::from_i64(2_i64).unwrap(), + ))); + let int1 = Box::new(Literal(Field::Int(1_i64))); + let int2 = Box::new(Literal(Field::Int(2_i64))); + let uint1 = Box::new(Literal(Field::UInt(1_u64))); + let uint2 = Box::new(Literal(Field::UInt(2_u64))); + + let row = Record::new(vec![]); + + // left: Int, right: Decimal + assert_eq!( + evaluate_add(&Schema::default(), &int1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_sub(&Schema::default(), &int1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); + assert_eq!( + evaluate_mul(&Schema::default(), &int2, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_div(&Schema::default(), &int1, dec2.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_f64(0.5).unwrap()) + ); + assert_eq!( + evaluate_mod(&Schema::default(), &int1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); + + // left: UInt, right: Decimal + assert_eq!( + evaluate_add(&Schema::default(), &uint1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_sub(&Schema::default(), &uint1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())),
+ Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); + assert_eq!( + evaluate_mul(&Schema::default(), &uint2, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_div(&Schema::default(), &uint1, dec2.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_f64(0.5).unwrap()) + ); + assert_eq!( + evaluate_mod(&Schema::default(), &uint1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); + + // left: Float, right: Decimal + assert_eq!( + evaluate_add(&Schema::default(), &float1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_sub(&Schema::default(), &float1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); + assert_eq!( + evaluate_mul(&Schema::default(), &float2, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(2_i64).unwrap()) + ); + assert_eq!( + evaluate_div(&Schema::default(), &float1, dec2.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_f64(0.5).unwrap()) + ); + assert_eq!( + evaluate_mod(&Schema::default(), &float1, dec1.as_ref(), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Decimal(Decimal::from_i64(0_i64).unwrap()) + ); +} diff --git a/dozer-sql/src/pipeline/onnx.rs b/dozer-sql/expression/src/onnx/error.rs similarity index 65% rename from dozer-sql/src/pipeline/onnx.rs rename to dozer-sql/expression/src/onnx/error.rs index ed90ad972b..2765f25fcb 100644 --- a/dozer-sql/src/pipeline/onnx.rs +++ b/dozer-sql/expression/src/onnx/error.rs @@ -1,27 +1,18 @@ -use crate::pipeline::expression::execution::Expression; -use dozer_types::thiserror; -use dozer_types::thiserror::Error; -use dozer_types::types::{Field, FieldType}; +use dozer_types::{ + thiserror::{self, Error}, + types::{Field, FieldType}, +}; use ndarray::ShapeError; -use ort::tensor::TensorElementDataType; -use ort::OrtError; +use ort::{tensor::TensorElementDataType, OrtError}; -#[derive(Clone, Debug)] -pub struct DozerSession(pub std::sync::Arc); - -#[cfg(feature = "onnx")] -impl PartialEq for DozerSession { - fn eq(&self, other: &Self) -> bool { - std::ptr::eq(self as *const _, other as *const _) - } -} +use crate::execution::Expression; #[derive(Error, Debug)] -pub enum OnnxError { +pub enum Error { #[error("Onnx Ndarray Shape Error: {0}")] - OnnxShapeErr(ShapeError), + OnnxShapeErr(#[from] ShapeError), #[error("Onnx Runtime Error: {0}")] - OnnxOrtErr(OrtError), + OnnxOrtErr(#[from] OrtError), #[error("Dozer expect onnx model to ingest single 1d input tensor: size of input {0}")] OnnxInputSizeErr(usize), #[error("Expected model input shape {0} doesn't match with actual input shape {1}")] @@ -37,7 +28,9 @@ pub enum OnnxError { #[error("Dozer doesn't support following output datatype {0:?}")] OnnxNotSupportedDataTypeErr(TensorElementDataType), #[error("Dozer can't find following column in the input schema {0:?}")] - ColumnNotFoundError(Expression), + ColumnNotFound(Expression), #[error("Dozer doesn't support non-column for onnx arguments {0:?}")] - NonColumnArgFoundError(Expression), + NonColumnArgFound(Expression), + #[error("Input argument overflow for {1:?}: {0}")] + 
InputArgumentOverflow(Field, TensorElementDataType), } diff --git a/dozer-sql/expression/src/onnx/mod.rs b/dozer-sql/expression/src/onnx/mod.rs new file mode 100644 index 0000000000..ce74b51382 --- /dev/null +++ b/dozer-sql/expression/src/onnx/mod.rs @@ -0,0 +1,12 @@ +pub mod error; +pub mod udf; +pub mod utils; + +#[derive(Clone, Debug)] +pub struct DozerSession(pub std::sync::Arc<ort::Session>); + +impl PartialEq for DozerSession { + fn eq(&self, other: &Self) -> bool { + std::ptr::eq(self as *const _, other as *const _) + } +} diff --git a/dozer-sql/src/pipeline/expression/onnx/onnx_udf.rs b/dozer-sql/expression/src/onnx/udf.rs similarity index 70% rename from dozer-sql/src/pipeline/expression/onnx/onnx_udf.rs rename to dozer-sql/expression/src/onnx/udf.rs index 6d9d978c75..5a01b4ace0 100644 --- a/dozer-sql/src/pipeline/expression/onnx/onnx_udf.rs +++ b/dozer-sql/expression/src/onnx/udf.rs @@ -1,10 +1,9 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::PipelineError::{InvalidType, OnnxError}; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::onnx::OnnxError::{ - OnnxInputDataMismatchErr, OnnxInvalidInputShapeErr, OnnxNotSupportedDataTypeErr, OnnxOrtErr, - OnnxShapeErr, +use super::error::Error::{ + InputArgumentOverflow, OnnxInputDataMismatchErr, OnnxInvalidInputShapeErr, + OnnxNotSupportedDataTypeErr, OnnxOrtErr, OnnxShapeErr, }; +use crate::error::Error::{self, Onnx}; +use crate::execution::Expression; use dozer_types::log::warn; use dozer_types::ordered_float::OrderedFloat; use dozer_types::types::{Field, Record, Schema}; @@ -21,24 +20,24 @@ pub fn evaluate_onnx_udf( session: &Session, args: &[Expression], record: &Record, -) -> Result<Field, PipelineError> { +) -> Result<Field, Error> { let input_values = args .iter() .map(|arg| arg.evaluate(record, schema)) - .collect::<Result<Vec<Field>, PipelineError>>()?; + .collect::<Result<Vec<Field>, Error>>()?; let mut input_shape = vec![]; for d in session.inputs[0].dimensions() { match d { Some(v) => input_shape.push(v), - None => return Err(OnnxError(OnnxInvalidInputShapeErr)), + None => return Err(Onnx(OnnxInvalidInputShapeErr)), } } let mut output_shape = vec![]; for d in session.outputs[0].dimensions() { match d { Some(v) => output_shape.push(v), - None => return Err(OnnxError(OnnxInvalidInputShapeErr)), + None => return Err(Onnx(OnnxInvalidInputShapeErr)), } } let input_type = session.inputs[0].input_type; @@ -53,24 +52,25 @@ pub fn evaluate_onnx_udf( let num = match f32::from_f64(*v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))?
.into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -83,19 +83,20 @@ pub fn evaluate_onnx_udf( if let Field::Float(v) = field { input_array.push(*v); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -110,7 +111,7 @@ pub fn evaluate_onnx_udf( let num = match u8::from_u64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -119,24 +120,25 @@ pub fn evaluate_onnx_udf( let num = match u8::from_u128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -151,7 +153,7 @@ pub fn evaluate_onnx_udf( let num = match u16::from_u64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -160,24 +162,25 @@ pub fn evaluate_onnx_udf( let num = match u16::from_u128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? 
+ .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -192,7 +195,7 @@ pub fn evaluate_onnx_udf( let num = match u32::from_u64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -201,24 +204,25 @@ pub fn evaluate_onnx_udf( let num = match u32::from_u128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -235,24 +239,25 @@ pub fn evaluate_onnx_udf( let num = match u64::from_u128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? 
.into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -267,7 +272,7 @@ pub fn evaluate_onnx_udf( let num = match i8::from_i64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -276,24 +281,25 @@ pub fn evaluate_onnx_udf( let num = match i8::from_i128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -308,7 +314,7 @@ pub fn evaluate_onnx_udf( let num = match i16::from_i64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -317,24 +323,25 @@ pub fn evaluate_onnx_udf( let num = match i16::from_i128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? 
.into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -349,7 +356,7 @@ pub fn evaluate_onnx_udf( let num = match i32::from_i64(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); @@ -358,24 +365,25 @@ pub fn evaluate_onnx_udf( let num = match i32::from_i128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -392,24 +400,25 @@ pub fn evaluate_onnx_udf( let num = match i64::from_i128(v) { Some(val) => val, None => { - return Err(InvalidType(field.clone(), format!("{:?}", return_type))) + return Err(Onnx(InputArgumentOverflow(field.clone(), return_type))) } }; input_array.push(num); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? .into_dyn(), ); - let input_tensor_values = vec![Value::from_array(session.allocator(), &array) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?]; + let input_tensor_values = + vec![Value::from_array(session.allocator(), &array) + .map_err(|e| Onnx(OnnxOrtErr(e)))?]; let outputs: Vec = session .run(input_tensor_values) - .map_err(|e| OnnxError(OnnxOrtErr(e)))?; + .map_err(|e| Onnx(OnnxOrtErr(e)))?; let output = outputs[0].borrow(); // number of output validation @@ -424,19 +433,20 @@ pub fn evaluate_onnx_udf( } else if let Field::Text(v) = field { input_array.push(v); } else { - return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field))); + return Err(Onnx(OnnxInputDataMismatchErr(input_type, field))); } } let array = ndarray::CowArray::from( Array::from_shape_vec(input_shape.clone(), input_array) - .map_err(|e| OnnxError(OnnxShapeErr(e)))? + .map_err(|e| Onnx(OnnxShapeErr(e)))? 
.into_dyn(),
             );
-            let input_tensor_values = vec![Value::from_array(session.allocator(), &array)
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?];
+            let input_tensor_values =
+                vec![Value::from_array(session.allocator(), &array)
+                    .map_err(|e| Onnx(OnnxOrtErr(e)))?];
             let outputs: Vec<Value> = session
                 .run(input_tensor_values)
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?;
+                .map_err(|e| Onnx(OnnxOrtErr(e)))?;
             let output = outputs[0].borrow();

             // number of output validation
@@ -449,26 +459,27 @@ pub fn evaluate_onnx_udf(
                 if let Field::Boolean(v) = field {
                     input_array.push(v);
                 } else {
-                    return Err(OnnxError(OnnxInputDataMismatchErr(input_type, field)));
+                    return Err(Onnx(OnnxInputDataMismatchErr(input_type, field)));
                 }
             }
             let array = ndarray::CowArray::from(
                 Array::from_shape_vec(input_shape.clone(), input_array)
-                    .map_err(|e| OnnxError(OnnxShapeErr(e)))?
+                    .map_err(|e| Onnx(OnnxShapeErr(e)))?
                     .into_dyn(),
             );
-            let input_tensor_values = vec![Value::from_array(session.allocator(), &array)
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?];
+            let input_tensor_values =
+                vec![Value::from_array(session.allocator(), &array)
+                    .map_err(|e| Onnx(OnnxOrtErr(e)))?];
             let outputs: Vec<Value> = session
                 .run(input_tensor_values)
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?;
+                .map_err(|e| Onnx(OnnxOrtErr(e)))?;
             let output = outputs[0].borrow();

             // number of output validation
             assert_eq!(outputs.len(), 1);

             onnx_output_to_dozer(return_type, output, output_shape)
         }
-        _ => Err(OnnxError(OnnxNotSupportedDataTypeErr(input_type))),
+        _ => Err(Onnx(OnnxNotSupportedDataTypeErr(input_type))),
     }
 }

@@ -476,12 +487,12 @@ fn onnx_output_to_dozer(
     return_type: TensorElementDataType,
     output: &Value,
     output_shape: Vec<usize>,
-) -> Result<Field, PipelineError> {
+) -> Result<Field, Error> {
     match return_type {
         TensorElementDataType::Float16 => {
             let output_array_view = output
                 .try_extract::<half::f16>()
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?;
+                .map_err(|e| Onnx(OnnxOrtErr(e)))?;
             assert_eq!(output_array_view.view().shape(), output_shape);
             Ok(Field::Float(OrderedFloat(
                 output_array_view.view().deref()[0].into(),
@@ -490,7 +501,7 @@ fn onnx_output_to_dozer(
         TensorElementDataType::Float32 => {
             let output_array_view = output
                 .try_extract::<f32>()
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?;
+                .map_err(|e| Onnx(OnnxOrtErr(e)))?;
             assert_eq!(output_array_view.view().shape(), output_shape);
             let view = output_array_view.view();
             let result = view.deref()[0].into();
@@ -499,12 +510,12 @@ fn onnx_output_to_dozer(
         TensorElementDataType::Float64 => {
             let output_array_view = output
                 .try_extract::<f64>()
-                .map_err(|e| OnnxError(OnnxOrtErr(e)))?;
+                .map_err(|e| Onnx(OnnxOrtErr(e)))?;
             assert_eq!(output_array_view.view().shape(), output_shape);
             Ok(Field::Float(OrderedFloat(
                 output_array_view.view().deref()[0],
             )))
         }
-        _ => Err(OnnxError(OnnxNotSupportedDataTypeErr(return_type))),
+        _ => Err(Onnx(OnnxNotSupportedDataTypeErr(return_type))),
     }
 }
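// The OnnxError -> Onnx rename running through this file suggests the new
// expression crate nests a dedicated ONNX error enum inside its top-level
// Error. A sketch of that shape with thiserror, mirroring the re-export
// pattern used elsewhere in this diff (enum names and variant payloads here
// are assumptions, not copied from the crate):
use dozer_types::thiserror::{self, Error};

#[derive(Debug, Error)]
enum OnnxSketchError {
    #[error("onnx shape error: {0}")]
    OnnxShapeErr(#[from] ndarray::ShapeError),
    #[error("unsupported data type: {0}")]
    OnnxNotSupportedDataTypeErr(String),
}

#[derive(Debug, Error)]
enum SketchError {
    #[error("onnx error: {0}")]
    Onnx(#[from] OnnxSketchError),
}
// With #[from], each layer gains a From impl, which is what the explicit
// `.map_err(|e| Onnx(OnnxShapeErr(e)))` calls above construct by hand.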
diff --git a/dozer-sql/src/pipeline/expression/onnx/onnx_utils.rs b/dozer-sql/expression/src/onnx/utils.rs
similarity index 75%
rename from dozer-sql/src/pipeline/expression/onnx/onnx_utils.rs
rename to dozer-sql/expression/src/onnx/utils.rs
index 705768d25a..d87f13dd30 100644
--- a/dozer-sql/src/pipeline/expression/onnx/onnx_utils.rs
+++ b/dozer-sql/expression/src/onnx/utils.rs
@@ -1,10 +1,9 @@
-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::errors::PipelineError::OnnxError;
-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::onnx::OnnxError::{
-    ColumnNotFoundError, NonColumnArgFoundError, OnnxInputDataTypeMismatchErr, OnnxInputShapeErr,
+use super::error::Error::{
+    ColumnNotFound, NonColumnArgFound, OnnxInputDataTypeMismatchErr, OnnxInputShapeErr,
     OnnxInputSizeErr, OnnxNotSupportedDataTypeErr, OnnxOutputShapeErr,
 };
+use crate::error::Error::{self, Onnx};
+use crate::execution::Expression;
 use dozer_types::arrow::datatypes::ArrowNativeTypeOp;
 use dozer_types::types::{FieldType, Schema};
 use ort::session::{Input, Output};
@@ -14,10 +13,10 @@ pub fn onnx_input_validation(
     schema: &Schema,
     args: &Vec<Expression>,
     inputs: &Vec<Input>,
-) -> Result<(), PipelineError> {
+) -> Result<(), Error> {
     // 1. number of input & input shape check
     if inputs.len() != 1 {
-        return Err(OnnxError(OnnxInputSizeErr(inputs.len())));
+        return Err(Onnx(OnnxInputSizeErr(inputs.len())));
     }
     let mut flattened = 1_u32;
     let dim = inputs[0].dimensions.clone();
@@ -30,7 +29,7 @@ pub fn onnx_input_validation(
         }
     }
     if flattened as usize != args.len() || inputs.len() != 1 {
-        return Err(OnnxError(OnnxInputShapeErr(flattened as usize, args.len())));
+        return Err(Onnx(OnnxInputShapeErr(flattened as usize, args.len())));
     }
     // 2. input datatype check
     for (input, arg) in inputs.iter().zip(args) {
@@ -39,7 +38,7 @@ pub fn onnx_input_validation(
             Some(def) => match input.input_type {
                 TensorElementDataType::Float32 | TensorElementDataType::Float64 => {
                     if def.typ != FieldType::Float {
-                        return Err(OnnxError(OnnxInputDataTypeMismatchErr(
+                        return Err(Onnx(OnnxInputDataTypeMismatchErr(
                             input.input_type,
                             def.typ,
                         )));
@@ -50,7 +49,7 @@
                 | TensorElementDataType::Uint32
                 | TensorElementDataType::Uint64 => {
                     if def.typ != FieldType::UInt && def.typ != FieldType::U128 {
-                        return Err(OnnxError(OnnxInputDataTypeMismatchErr(
+                        return Err(Onnx(OnnxInputDataTypeMismatchErr(
                             input.input_type,
                             def.typ,
                         )));
@@ -61,7 +60,7 @@
                 | TensorElementDataType::Int32
                 | TensorElementDataType::Int64 => {
                     if def.typ != FieldType::Int && def.typ != FieldType::I128 {
-                        return Err(OnnxError(OnnxInputDataTypeMismatchErr(
+                        return Err(Onnx(OnnxInputDataTypeMismatchErr(
                             input.input_type,
                             def.typ,
                         )));
@@ -69,7 +68,7 @@
                 }
                 TensorElementDataType::String => {
                     if def.typ != FieldType::String && def.typ != FieldType::Text {
-                        return Err(OnnxError(OnnxInputDataTypeMismatchErr(
+                        return Err(Onnx(OnnxInputDataTypeMismatchErr(
                             input.input_type,
                             def.typ,
                         )));
@@ -77,23 +76,23 @@
                 }
                 TensorElementDataType::Bool => {
                     if def.typ != FieldType::Boolean {
-                        return Err(OnnxError(OnnxInputDataTypeMismatchErr(
+                        return Err(Onnx(OnnxInputDataTypeMismatchErr(
                             input.input_type,
                             def.typ,
                         )));
                     }
                 }
-                _ => return Err(OnnxError(OnnxNotSupportedDataTypeErr(input.input_type))),
+                _ => return Err(Onnx(OnnxNotSupportedDataTypeErr(input.input_type))),
             },
-            None => return Err(OnnxError(ColumnNotFoundError(arg.clone()))),
+            None => return Err(Onnx(ColumnNotFound(arg.clone()))),
         },
-        _ => return Err(OnnxError(NonColumnArgFoundError(arg.clone()))),
+        _ => return Err(Onnx(NonColumnArgFound(arg.clone()))),
         }
     }
     Ok(())
 }

-pub fn onnx_output_validation(outputs: &Vec<Output>) -> Result<(), PipelineError> {
+pub fn onnx_output_validation(outputs: &Vec<Output>) -> Result<(), Error> {
     // 1.
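// onnx_input_validation above accepts exactly one model input and requires
// its flattened dimension count to equal the number of SQL arguments, so an
// input declared as [1, 3] needs three column expressions. The fold, as a
// sketch assuming ORT-style dimensions of Vec<Option<u32>> with None (a
// dynamic axis) counted as 1:
fn flattened_len(dims: &[Option<u32>]) -> u32 {
    dims.iter().copied().map(|d| d.unwrap_or(1)).product()
}
// flattened_len(&[Some(1), Some(3)]) == 3, so ONNX_UDF(a, b, c) passes the
// shape check while a two-argument call is rejected with OnnxInputShapeErr.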
number of output & output shape check let mut flattened = 1_u32; for output_shape in outputs { @@ -109,7 +108,7 @@ pub fn onnx_output_validation(outputs: &Vec) -> Result<(), PipelineError } // output needs to be 1d single dim tensor if flattened as usize != 1_usize { - return Err(OnnxError(OnnxOutputShapeErr(flattened as usize, 1_usize))); + return Err(Onnx(OnnxOutputShapeErr(flattened as usize, 1_usize))); } // 2. output datatype check for output in outputs { @@ -126,7 +125,7 @@ pub fn onnx_output_validation(outputs: &Vec) -> Result<(), PipelineError | TensorElementDataType::Int64 | TensorElementDataType::String | TensorElementDataType::Bool => continue, - _ => return Err(OnnxError(OnnxNotSupportedDataTypeErr(output.output_type))), + _ => return Err(Onnx(OnnxNotSupportedDataTypeErr(output.output_type))), } } Ok(()) diff --git a/dozer-sql/src/pipeline/expression/operator.rs b/dozer-sql/expression/src/operator.rs similarity index 91% rename from dozer-sql/src/pipeline/expression/operator.rs rename to dozer-sql/expression/src/operator.rs index cc76d4cb02..55a422acd3 100644 --- a/dozer-sql/src/pipeline/expression/operator.rs +++ b/dozer-sql/expression/src/operator.rs @@ -1,8 +1,8 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::comparison::*; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::logical::*; -use crate::pipeline::expression::mathematical::*; +use crate::comparison::*; +use crate::error::Error; +use crate::execution::Expression; +use crate::logical::*; +use crate::mathematical::*; use dozer_types::types::Record; use dozer_types::types::{Field, Schema}; use std::fmt::{Display, Formatter}; @@ -30,7 +30,7 @@ impl UnaryOperatorType { schema: &Schema, value: &Expression, record: &Record, - ) -> Result { + ) -> Result { match self { UnaryOperatorType::Not => evaluate_not(schema, value, record), UnaryOperatorType::Plus => evaluate_plus(schema, value, record), @@ -88,7 +88,7 @@ impl BinaryOperatorType { left: &Expression, right: &Expression, record: &Record, - ) -> Result { + ) -> Result { match self { BinaryOperatorType::Eq => evaluate_eq(schema, left, right, record), BinaryOperatorType::Ne => evaluate_ne(schema, left, right, record), diff --git a/dozer-sql/src/pipeline/expression/python_udf.rs b/dozer-sql/expression/src/python_udf.rs similarity index 69% rename from dozer-sql/src/pipeline/expression/python_udf.rs rename to dozer-sql/expression/src/python_udf.rs index e93e40963a..63785beb3d 100644 --- a/dozer-sql/src/pipeline/expression/python_udf.rs +++ b/dozer-sql/expression/src/python_udf.rs @@ -1,10 +1,8 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::PipelineError::UnsupportedSqlError; -use crate::pipeline::errors::UnsupportedSqlError::GenericError; -use crate::pipeline::expression::execution::Expression; +use crate::execution::Expression; use dozer_types::ordered_float::OrderedFloat; use dozer_types::pyo3::types::PyTuple; use dozer_types::pyo3::Python; +use dozer_types::thiserror::{self, Error}; use dozer_types::types::Record; use dozer_types::types::{Field, FieldType, Schema}; use std::env; @@ -12,27 +10,41 @@ use std::path::PathBuf; const MODULE_NAME: &str = "python_udf"; +#[derive(Debug, Error)] +pub enum Error { + #[error( + "Python UDF must have a return type. 
The syntax is: function_name(arguments)" + )] + MissingReturnType, + #[error("Missing 'VIRTUAL_ENV' environment var")] + MissingVirtualEnv, + #[error("PyO3 error: {0}")] + PyO3(#[from] dozer_types::pyo3::PyErr), + #[error("Unsupported return type: {0}")] + UnsupportedReturnType(FieldType), + #[error("Failed to parse return type: {0}")] + FailedToParseReturnType(String), +} + pub fn evaluate_py_udf( schema: &Schema, name: &str, args: &[Expression], return_type: &FieldType, record: &Record, -) -> Result { +) -> Result { let values = args .iter() .map(|arg| arg.evaluate(record, schema)) - .collect::, PipelineError>>()?; + .collect::, crate::error::Error>>()?; // Get the path of the Python interpreter in your virtual environment - let env_path = env::var("VIRTUAL_ENV").map_err(|_| { - PipelineError::InvalidFunction("Missing 'VIRTUAL_ENV' environment var".to_string()) - })?; + let env_path = env::var("VIRTUAL_ENV").map_err(|_| Error::MissingVirtualEnv)?; let py_path = format!("{env_path}/bin/python"); // Set the `PYTHON_SYS_EXECUTABLE` environment variable env::set_var("PYTHON_SYS_EXECUTABLE", py_path); - Python::with_gil(|py| -> Result { + Python::with_gil(|py| -> Result { // Get the directory containing the module let module_dir = PathBuf::from(env_path); // Import the `sys` module and append the module directory to the system path @@ -61,11 +73,8 @@ pub fn evaluate_py_udf( | FieldType::Timestamp | FieldType::Point | FieldType::Duration - | FieldType::Json => { - return Err(UnsupportedSqlError(GenericError( - "Unsupported return type for python udf".to_string(), - ))) - } + | FieldType::Json => return Err(Error::UnsupportedReturnType(*return_type)), }) }) + .map_err(Into::into) } diff --git a/dozer-sql/expression/src/scalar/common.rs b/dozer-sql/expression/src/scalar/common.rs new file mode 100644 index 0000000000..2a256557c9 --- /dev/null +++ b/dozer-sql/expression/src/scalar/common.rs @@ -0,0 +1,122 @@ +use crate::arg_utils::{validate_num_arguments, validate_one_argument, validate_two_arguments}; +use crate::error::Error; +use crate::execution::{Expression, ExpressionType}; +use crate::scalar::number::{evaluate_abs, evaluate_round}; +use crate::scalar::string::{ + evaluate_concat, evaluate_length, evaluate_to_char, evaluate_ucase, validate_concat, + validate_ucase, +}; +use dozer_types::types::Record; +use dozer_types::types::{Field, FieldType, Schema}; +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ScalarFunctionType { + Abs, + Round, + Ucase, + Concat, + Length, + ToChar, +} + +impl Display for ScalarFunctionType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ScalarFunctionType::Abs => f.write_str("ABS"), + ScalarFunctionType::Round => f.write_str("ROUND"), + ScalarFunctionType::Ucase => f.write_str("UCASE"), + ScalarFunctionType::Concat => f.write_str("CONCAT"), + ScalarFunctionType::Length => f.write_str("LENGTH"), + ScalarFunctionType::ToChar => f.write_str("TO_CHAR"), + } + } +} + +pub(crate) fn get_scalar_function_type( + function: &ScalarFunctionType, + args: &[Expression], + schema: &Schema, +) -> Result { + match function { + ScalarFunctionType::Abs => validate_one_argument(args, schema, ScalarFunctionType::Abs), + ScalarFunctionType::Round => { + let return_type = if args.len() == 1 { + validate_one_argument(args, schema, ScalarFunctionType::Round)?.return_type + } else { + validate_two_arguments(args, schema, ScalarFunctionType::Round)? 
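// evaluate_py_udf above resolves the interpreter from VIRTUAL_ENV and then
// runs the UDF under the GIL. A condensed sketch of the same pyo3 call
// pattern -- import a module, fetch a function, call it with a tuple of
// arguments ("my_udf" is an illustrative name, not one from this diff):
use dozer_types::pyo3::types::PyTuple;
use dozer_types::pyo3::{PyResult, Python};

fn call_python_udf(py: Python, values: Vec<f64>) -> PyResult<f64> {
    let module = py.import("python_udf")?; // MODULE_NAME in the code above
    let function = module.getattr("my_udf")?; // hypothetical UDF
    let args = PyTuple::new(py, values);
    function.call1(args)?.extract()
}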
+ .0 + .return_type + }; + Ok(ExpressionType::new( + return_type, + true, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) + } + ScalarFunctionType::Ucase => { + validate_num_arguments(1..2, args.len(), ScalarFunctionType::Ucase)?; + validate_ucase(&args[0], schema) + } + ScalarFunctionType::Concat => validate_concat(args, schema), + ScalarFunctionType::Length => Ok(ExpressionType::new( + FieldType::UInt, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )), + ScalarFunctionType::ToChar => { + if args.len() == 1 { + validate_one_argument(args, schema, ScalarFunctionType::ToChar) + } else { + Ok(validate_two_arguments(args, schema, ScalarFunctionType::ToChar)?.0) + } + } + } +} + +impl ScalarFunctionType { + pub fn new(name: &str) -> Option { + match name { + "abs" => Some(ScalarFunctionType::Abs), + "round" => Some(ScalarFunctionType::Round), + "ucase" => Some(ScalarFunctionType::Ucase), + "concat" => Some(ScalarFunctionType::Concat), + "length" => Some(ScalarFunctionType::Length), + "to_char" => Some(ScalarFunctionType::ToChar), + _ => None, + } + } + + pub(crate) fn evaluate( + &self, + schema: &Schema, + args: &[Expression], + record: &Record, + ) -> Result { + match self { + ScalarFunctionType::Abs => { + validate_num_arguments(1..2, args.len(), ScalarFunctionType::Abs)?; + evaluate_abs(schema, &args[0], record) + } + ScalarFunctionType::Round => { + validate_num_arguments(1..3, args.len(), ScalarFunctionType::Round)?; + evaluate_round(schema, &args[0], args.get(1), record) + } + ScalarFunctionType::Ucase => { + validate_num_arguments(1..2, args.len(), ScalarFunctionType::Ucase)?; + evaluate_ucase(schema, &args[0], record) + } + ScalarFunctionType::Concat => evaluate_concat(schema, args, record), + ScalarFunctionType::Length => { + validate_num_arguments(1..2, args.len(), ScalarFunctionType::Length)?; + evaluate_length(schema, &args[0], record) + } + ScalarFunctionType::ToChar => { + validate_num_arguments(2..3, args.len(), ScalarFunctionType::ToChar)?; + evaluate_to_char(schema, &args[0], &args[1], record) + } + } + } +} diff --git a/dozer-sql/src/pipeline/expression/scalar/mod.rs b/dozer-sql/expression/src/scalar/mod.rs similarity index 100% rename from dozer-sql/src/pipeline/expression/scalar/mod.rs rename to dozer-sql/expression/src/scalar/mod.rs diff --git a/dozer-sql/expression/src/scalar/number.rs b/dozer-sql/expression/src/scalar/number.rs new file mode 100644 index 0000000000..2e4c204211 --- /dev/null +++ b/dozer-sql/expression/src/scalar/number.rs @@ -0,0 +1,192 @@ +use crate::error::Error; +use crate::execution::Expression; +use crate::scalar::common::ScalarFunctionType; +use dozer_types::ordered_float::OrderedFloat; +use dozer_types::types::Record; +use dozer_types::types::{Field, FieldType, Schema}; +use num_traits::{Float, ToPrimitive}; + +pub(crate) fn evaluate_abs( + schema: &Schema, + arg: &Expression, + record: &Record, +) -> Result { + let value = arg.evaluate(record, schema)?; + match value { + Field::UInt(u) => Ok(Field::UInt(u)), + Field::U128(u) => Ok(Field::U128(u)), + Field::Int(i) => Ok(Field::Int(i.abs())), + Field::I128(i) => Ok(Field::I128(i.abs())), + Field::Float(f) => Ok(Field::Float(f.abs())), + Field::Decimal(d) => Ok(Field::Decimal(d.abs())), + Field::Boolean(_) + | Field::String(_) + | Field::Text(_) + | Field::Date(_) + | Field::Timestamp(_) + | Field::Binary(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) + | Field::Null => Err(Error::InvalidFunctionArgument { + function_name: 
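// ScalarFunctionType::new above is the name-to-function lookup used while
// planning, and evaluate dispatches after an arity check. A usage sketch
// from inside the crate, built the way the crate's own tests build literals:
use dozer_types::ordered_float::OrderedFloat;
use dozer_types::types::{Field, Record, Schema};

fn demo_round() {
    let func = ScalarFunctionType::new("round").expect("registered name");
    let args = [Expression::Literal(Field::Float(OrderedFloat(2.71828)))];
    let row = Record::new(vec![]);
    // One-argument ROUND rounds to zero decimal places (see evaluate_round).
    let result = func.evaluate(&Schema::default(), &args, &row);
    assert_eq!(result.unwrap(), Field::Float(OrderedFloat(3.0)));
}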
ScalarFunctionType::Abs.to_string(), + argument_index: 0, + argument: value, + }), + } +} + +pub(crate) fn evaluate_round( + schema: &Schema, + arg: &Expression, + decimals: Option<&Expression>, + record: &Record, +) -> Result { + let value = arg.evaluate(record, schema)?; + let mut places = 0; + if let Some(expression) = decimals { + let field = expression.evaluate(record, schema)?; + match field { + Field::UInt(u) => places = u as i32, + Field::U128(u) => places = u as i32, + Field::Int(i) => places = i as i32, + Field::I128(i) => places = i as i32, + Field::Float(f) => places = f.round().0 as i32, + Field::Decimal(d) => { + places = d + .to_i32() + .ok_or(Error::InvalidCast { + from: field, + to: FieldType::Decimal, + }) + .unwrap() + } + Field::Boolean(_) + | Field::String(_) + | Field::Text(_) + | Field::Date(_) + | Field::Timestamp(_) + | Field::Binary(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) + | Field::Null => {} // Truncate value to 0 decimals + } + } + let order = OrderedFloat(10.0_f64.powi(places)); + + match value { + Field::UInt(u) => Ok(Field::UInt(u)), + Field::U128(u) => Ok(Field::U128(u)), + Field::Int(i) => Ok(Field::Int(i)), + Field::I128(i) => Ok(Field::I128(i)), + Field::Float(f) => Ok(Field::Float((f * order).round() / order)), + Field::Decimal(d) => Ok(Field::Decimal(d.round_dp(places as u32))), + Field::Null => Ok(Field::Null), + Field::Boolean(_) + | Field::String(_) + | Field::Text(_) + | Field::Date(_) + | Field::Timestamp(_) + | Field::Binary(_) + | Field::Json(_) + | Field::Point(_) + | Field::Duration(_) => Err(Error::InvalidFunctionArgument { + function_name: ScalarFunctionType::Round.to_string(), + argument_index: 0, + argument: value, + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use dozer_types::ordered_float::OrderedFloat; + use dozer_types::types::Record; + use dozer_types::types::{Field, Schema}; + use proptest::prelude::*; + use std::ops::Neg; + use Expression::Literal; + + #[test] + fn test_abs() { + proptest!(ProptestConfig::with_cases(1000), |(i_num in 0i64..100000000i64, f_num in 0f64..100000000f64)| { + let row = Record::new(vec![]); + + let v = Box::new(Literal(Field::Int(i_num.neg()))); + assert_eq!( + evaluate_abs(&Schema::default(), &v, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Int(i_num) + ); + + let row = Record::new(vec![]); + + let v = Box::new(Literal(Field::Float(OrderedFloat(f_num.neg())))); + assert_eq!( + evaluate_abs(&Schema::default(), &v, &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Float(OrderedFloat(f_num)) + ); + }); + } + + #[test] + fn test_round() { + proptest!(ProptestConfig::with_cases(1000), |(i_num: i64, f_num: f64, i_pow: i32, f_pow: f32)| { + let row = Record::new(vec![]); + + let v = Box::new(Literal(Field::Int(i_num))); + let d = &Box::new(Literal(Field::Int(0))); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Int(i_num) + ); + + let v = Box::new(Literal(Field::Float(OrderedFloat(f_num)))); + let d = &Box::new(Literal(Field::Int(0))); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Float(OrderedFloat(f_num.round())) + ); + + let v = Box::new(Literal(Field::Float(OrderedFloat(f_num)))); + let d = &Box::new(Literal(Field::Int(i_pow as i64))); + let order = 10.0_f64.powi(i_pow); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| 
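// evaluate_round above scales by a power of ten, rounds, and unscales: with
// places = 2, order = 10^2, so 3.14159 becomes (314.159).round() / 100 =
// 3.14. A negative count rounds left of the decimal point, and a non-numeric
// second argument falls through with places = 0, which is why the test below
// expects plain f_num.round() when a string is passed. The float arm in
// miniature:
fn round_to(value: f64, places: i32) -> f64 {
    let order = 10.0_f64.powi(places);
    (value * order).round() / order
}
// round_to(3.14159, 2) == 3.14 and round_to(1250.0, -2) == 1300.0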
panic!("{}", e.to_string())), + Field::Float(OrderedFloat((f_num * order).round() / order)) + ); + + let v = Box::new(Literal(Field::Float(OrderedFloat(f_num)))); + let d = &Box::new(Literal(Field::Float(OrderedFloat(f_pow as f64)))); + let order = 10.0_f64.powi(f_pow.round() as i32); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Float(OrderedFloat((f_num * order).round() / order)) + ); + + let v = Box::new(Literal(Field::Float(OrderedFloat(f_num)))); + let d = &Box::new(Literal(Field::String(f_pow.to_string()))); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Float(OrderedFloat(f_num.round())) + ); + + let v = Box::new(Literal(Field::Null)); + let d = &Box::new(Literal(Field::String(i_pow.to_string()))); + assert_eq!( + evaluate_round(&Schema::default(), &v, Some(d), &row) + .unwrap_or_else(|e| panic!("{}", e.to_string())), + Field::Null + ); + }); + } +} diff --git a/dozer-sql/expression/src/scalar/string.rs b/dozer-sql/expression/src/scalar/string.rs new file mode 100644 index 0000000000..06cca69f51 --- /dev/null +++ b/dozer-sql/expression/src/scalar/string.rs @@ -0,0 +1,588 @@ +use crate::error::Error; +use std::fmt::Write; +use std::fmt::{Display, Formatter}; + +use crate::execution::{Expression, ExpressionType}; + +use crate::arg_utils::validate_arg_type; +use crate::scalar::common::ScalarFunctionType; + +use dozer_types::types::Record; +use dozer_types::types::{Field, FieldType, Schema}; +use like::{Escape, Like}; + +pub(crate) fn validate_ucase(arg: &Expression, schema: &Schema) -> Result { + validate_arg_type( + arg, + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Ucase, + 0, + ) +} + +pub fn evaluate_ucase(schema: &Schema, arg: &Expression, record: &Record) -> Result { + let f = arg.evaluate(record, schema)?; + let v = f.to_string(); + let ret = v.to_uppercase(); + + Ok(match arg.get_type(schema)?.return_type { + FieldType::String => Field::String(ret), + FieldType::UInt + | FieldType::U128 + | FieldType::Int + | FieldType::I128 + | FieldType::Float + | FieldType::Decimal + | FieldType::Boolean + | FieldType::Text + | FieldType::Date + | FieldType::Timestamp + | FieldType::Binary + | FieldType::Json + | FieldType::Point + | FieldType::Duration => Field::Text(ret), + }) +} + +pub fn validate_concat(args: &[Expression], schema: &Schema) -> Result { + let mut ret_type = FieldType::String; + for exp in args { + let r = validate_arg_type( + exp, + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Concat, + 0, + )?; + if matches!(r.return_type, FieldType::Text) { + ret_type = FieldType::Text; + } + } + Ok(ExpressionType::new( + ret_type, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) +} + +pub fn evaluate_concat( + schema: &Schema, + args: &[Expression], + record: &Record, +) -> Result { + let mut res_type = FieldType::String; + let mut res_vec: Vec = Vec::with_capacity(args.len()); + + for e in args { + if matches!(e.get_type(schema)?.return_type, FieldType::Text) { + res_type = FieldType::Text; + } + let f = e.evaluate(record, schema)?; + let val = f.to_string(); + res_vec.push(val); + } + + let res_str = res_vec.iter().fold(String::new(), |a, b| a + b.as_str()); + Ok(match res_type { + FieldType::Text => Field::Text(res_str), + FieldType::UInt + | FieldType::U128 + | FieldType::Int + | FieldType::I128 + | FieldType::Float + | 
FieldType::Decimal + | FieldType::Boolean + | FieldType::String + | FieldType::Date + | FieldType::Timestamp + | FieldType::Binary + | FieldType::Json + | FieldType::Point + | FieldType::Duration => Field::String(res_str), + }) +} + +pub(crate) fn evaluate_length( + schema: &Schema, + arg0: &Expression, + record: &Record, +) -> Result { + let f0 = arg0.evaluate(record, schema)?; + let v0 = f0.to_string(); + Ok(Field::UInt(v0.len() as u64)) +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TrimType { + Trailing, + Leading, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Trailing => f.write_str("TRAILING "), + TrimType::Leading => f.write_str("LEADING "), + TrimType::Both => f.write_str("BOTH "), + } + } +} + +pub fn validate_trim(arg: &Expression, schema: &Schema) -> Result { + validate_arg_type( + arg, + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Concat, + 0, + ) +} + +pub fn evaluate_trim( + schema: &Schema, + arg: &Expression, + what: &Option>, + typ: &Option, + record: &Record, +) -> Result { + let arg_field = arg.evaluate(record, schema)?; + let arg_value = arg_field.to_string(); + + let v1: Vec<_> = match what { + Some(e) => { + let f = e.evaluate(record, schema)?; + f.to_string().chars().collect() + } + _ => vec![' '], + }; + + let retval = match typ { + Some(TrimType::Both) => arg_value.trim_matches::<&[char]>(&v1).to_string(), + Some(TrimType::Leading) => arg_value.trim_start_matches::<&[char]>(&v1).to_string(), + Some(TrimType::Trailing) => arg_value.trim_end_matches::<&[char]>(&v1).to_string(), + None => arg_value.trim_matches::<&[char]>(&v1).to_string(), + }; + + Ok(match arg.get_type(schema)?.return_type { + FieldType::String => Field::String(retval), + FieldType::UInt + | FieldType::U128 + | FieldType::Int + | FieldType::I128 + | FieldType::Float + | FieldType::Decimal + | FieldType::Boolean + | FieldType::Text + | FieldType::Date + | FieldType::Timestamp + | FieldType::Binary + | FieldType::Json + | FieldType::Point + | FieldType::Duration => Field::Text(retval), + }) +} + +pub(crate) fn get_like_operator_type( + arg: &Expression, + pattern: &Expression, + schema: &Schema, +) -> Result { + validate_arg_type( + pattern, + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Concat, + 0, + )?; + + validate_arg_type( + arg, + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Concat, + 0, + ) +} + +pub fn evaluate_like( + schema: &Schema, + arg: &Expression, + pattern: &Expression, + escape: Option, + record: &Record, +) -> Result { + let arg_field = arg.evaluate(record, schema)?; + let arg_value = arg_field.to_string(); + let arg_string = arg_value.as_str(); + + let pattern_field = pattern.evaluate(record, schema)?; + let pattern_value = pattern_field.to_string(); + let pattern_string = pattern_value.as_str(); + + if let Some(escape_char) = escape { + let arg_escape = &arg_string.escape(&escape_char.to_string())?; + let result = + Like::::like(arg_escape.as_str(), pattern_string).map(Field::Boolean)?; + return Ok(result); + } + + let result = Like::::like(arg_string, pattern_string).map(Field::Boolean)?; + Ok(result) +} + +pub(crate) fn evaluate_to_char( + schema: &Schema, + arg: &Expression, + pattern: &Expression, + record: &Record, +) -> Result { + let arg_field = arg.evaluate(record, schema)?; + + let pattern_field = pattern.evaluate(record, schema)?; + let pattern_value = pattern_field.to_string(); + + let output 
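// evaluate_trim above turns the optional WHAT expression into a character
// set and picks a std trim function by TrimType, defaulting the set to a
// single space, so TRIM(BOTH 'xy' FROM col) strips any of those characters
// from both ends. The BOTH case on plain std, as a sketch:
fn trim_both(value: &str, what: &str) -> String {
    let set: Vec<char> = what.chars().collect();
    value.trim_matches::<&[char]>(&set).to_string()
}
// trim_both("xyhixy", "xy") == "hi"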
= match arg_field { + Field::Timestamp(value) => value.format(pattern_value.as_str()).to_string(), + Field::Date(value) => { + let mut formatted = String::new(); + let format_result = write!(formatted, "{}", value.format(pattern_value.as_str())); + if format_result.is_ok() { + formatted + } else { + pattern_value + } + } + Field::Null => return Ok(Field::Null), + _ => { + return Err(Error::InvalidFunctionArgument { + function_name: "TO_CHAR".to_string(), + argument_index: 0, + argument: arg_field, + }); + } + }; + + Ok(Field::String(output)) +} + +#[cfg(test)] +mod tests { + use super::*; + use Expression::Literal; + + use proptest::prelude::*; + + #[test] + fn test_string() { + proptest!( + ProptestConfig::with_cases(1000), + move |(s_val in ".+", s_val1 in ".*", s_val2 in ".*", c_val: char)| { + test_like(&s_val, c_val); + test_ucase(&s_val, c_val); + test_concat(&s_val1, &s_val2, c_val); + test_trim(&s_val, c_val); + }); + } + + fn test_like(s_val: &str, c_val: char) { + let row = Record::new(vec![]); + + // Field::String + let value = Box::new(Literal(Field::String(format!("Hello{}", s_val)))); + let pattern = Box::new(Literal(Field::String("Hello%".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(true) + ); + + let value = Box::new(Literal(Field::String(format!("Hello, {}orld!", c_val)))); + let pattern = Box::new(Literal(Field::String("Hello, _orld!".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(true) + ); + + let value = Box::new(Literal(Field::String(s_val.to_string()))); + let pattern = Box::new(Literal(Field::String("Hello%".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(false) + ); + + let c_value = &s_val[0..0]; + let value = Box::new(Literal(Field::String(format!("Hello, {}!", c_value)))); + let pattern = Box::new(Literal(Field::String("Hello, _!".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(false) + ); + + // todo: should find the way to generate escape character using proptest + // let value = Box::new(Literal(Field::String(format!("Hello, {}%", c_val)))); + // let pattern = Box::new(Literal(Field::String("Hello, %".to_owned()))); + // let escape = Some(c_val); + // + // assert_eq!( + // evaluate_like(&Schema::default(), &value, &pattern, escape, &row).unwrap(), + // Field::Boolean(true) + // ); + + // Field::Text + let value = Box::new(Literal(Field::Text(format!("Hello{}", s_val)))); + let pattern = Box::new(Literal(Field::Text("Hello%".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(true) + ); + + let value = Box::new(Literal(Field::Text(format!("Hello, {}orld!", c_val)))); + let pattern = Box::new(Literal(Field::Text("Hello, _orld!".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(true) + ); + + let value = Box::new(Literal(Field::Text(s_val.to_string()))); + let pattern = Box::new(Literal(Field::Text("Hello%".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(false) + ); + + let c_value = &s_val[0..0]; + let value = Box::new(Literal(Field::Text(format!("Hello, {}!", c_value)))); + let pattern = Box::new(Literal(Field::Text("Hello, 
_!".to_owned()))); + + assert_eq!( + evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), + Field::Boolean(false) + ); + + // todo: should find the way to generate escape character using proptest + // let value = Box::new(Literal(Field::Text(format!("Hello, {}%", c_val)))); + // let pattern = Box::new(Literal(Field::Text("Hello, %".to_owned()))); + // let escape = Some(c_val); + // + // assert_eq!( + // evaluate_like(&Schema::default(), &value, &pattern, escape, &row).unwrap(), + // Field::Boolean(true) + // ); + } + + fn test_ucase(s_val: &str, c_val: char) { + let row = Record::new(vec![]); + + // Field::String + let value = Box::new(Literal(Field::String(s_val.to_string()))); + assert_eq!( + evaluate_ucase(&Schema::default(), &value, &row).unwrap(), + Field::String(s_val.to_uppercase()) + ); + + let value = Box::new(Literal(Field::String(c_val.to_string()))); + assert_eq!( + evaluate_ucase(&Schema::default(), &value, &row).unwrap(), + Field::String(c_val.to_uppercase().to_string()) + ); + + // Field::Text + let value = Box::new(Literal(Field::Text(s_val.to_string()))); + assert_eq!( + evaluate_ucase(&Schema::default(), &value, &row).unwrap(), + Field::Text(s_val.to_uppercase()) + ); + + let value = Box::new(Literal(Field::Text(c_val.to_string()))); + assert_eq!( + evaluate_ucase(&Schema::default(), &value, &row).unwrap(), + Field::Text(c_val.to_uppercase().to_string()) + ); + } + + fn test_concat(s_val1: &str, s_val2: &str, c_val: char) { + let row = Record::new(vec![]); + + // Field::String + let val1 = Literal(Field::String(s_val1.to_string())); + let val2 = Literal(Field::String(s_val2.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::String(s_val1.to_string() + s_val2) + ); + } + + let val1 = Literal(Field::String(s_val2.to_string())); + let val2 = Literal(Field::String(s_val1.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::String(s_val2.to_string() + s_val1) + ); + } + + let val1 = Literal(Field::String(s_val1.to_string())); + let val2 = Literal(Field::String(c_val.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::String(s_val1.to_string() + c_val.to_string().as_str()) + ); + } + + let val1 = Literal(Field::String(c_val.to_string())); + let val2 = Literal(Field::String(s_val1.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::String(c_val.to_string() + s_val1) + ); + } + + // Field::Text + let val1 = Literal(Field::Text(s_val1.to_string())); + let val2 = Literal(Field::Text(s_val2.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::Text(s_val1.to_string() + s_val2) + ); + } + + let val1 = Literal(Field::Text(s_val2.to_string())); + let val2 = Literal(Field::Text(s_val1.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + 
Field::Text(s_val2.to_string() + s_val1) + ); + } + + let val1 = Literal(Field::Text(s_val1.to_string())); + let val2 = Literal(Field::Text(c_val.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::Text(s_val1.to_string() + c_val.to_string().as_str()) + ); + } + + let val1 = Literal(Field::Text(c_val.to_string())); + let val2 = Literal(Field::Text(s_val1.to_string())); + + if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { + assert_eq!( + evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), + Field::Text(c_val.to_string() + s_val1) + ); + } + } + + fn test_trim(s_val1: &str, c_val: char) { + let row = Record::new(vec![]); + + // Field::String + let value = Literal(Field::String(s_val1.to_string())); + let what = ' '; + + if validate_trim(&value, &Schema::default()).is_ok() { + assert_eq!( + evaluate_trim(&Schema::default(), &value, &None, &None, &row).unwrap(), + Field::String(s_val1.trim_matches(what).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &None, + &Some(TrimType::Trailing), + &row + ) + .unwrap(), + Field::String(s_val1.trim_end_matches(what).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &None, + &Some(TrimType::Leading), + &row + ) + .unwrap(), + Field::String(s_val1.trim_start_matches(what).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &None, + &Some(TrimType::Both), + &row + ) + .unwrap(), + Field::String(s_val1.trim_matches(what).to_string()) + ); + } + + let value = Literal(Field::String(s_val1.to_string())); + let what = Some(Box::new(Literal(Field::String(c_val.to_string())))); + + if validate_trim(&value, &Schema::default()).is_ok() { + assert_eq!( + evaluate_trim(&Schema::default(), &value, &what, &None, &row).unwrap(), + Field::String(s_val1.trim_matches(c_val).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &what, + &Some(TrimType::Trailing), + &row + ) + .unwrap(), + Field::String(s_val1.trim_end_matches(c_val).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &what, + &Some(TrimType::Leading), + &row + ) + .unwrap(), + Field::String(s_val1.trim_start_matches(c_val).to_string()) + ); + assert_eq!( + evaluate_trim( + &Schema::default(), + &value, + &what, + &Some(TrimType::Both), + &row + ) + .unwrap(), + Field::String(s_val1.trim_matches(c_val).to_string()) + ); + } + } +} diff --git a/dozer-sql/src/pipeline/aggregation/aggregator.rs b/dozer-sql/src/aggregation/aggregator.rs similarity index 93% rename from dozer-sql/src/pipeline/aggregation/aggregator.rs rename to dozer-sql/src/aggregation/aggregator.rs index d4928195d9..14fef33f84 100644 --- a/dozer-sql/src/pipeline/aggregation/aggregator.rs +++ b/dozer-sql/src/aggregation/aggregator.rs @@ -1,23 +1,23 @@ #![allow(clippy::enum_variant_names)] -use crate::pipeline::aggregation::avg::AvgAggregator; -use crate::pipeline::aggregation::count::CountAggregator; -use crate::pipeline::aggregation::max::MaxAggregator; -use crate::pipeline::aggregation::min::MinAggregator; -use crate::pipeline::aggregation::sum::SumAggregator; -use crate::pipeline::errors::PipelineError; +use crate::aggregation::avg::AvgAggregator; +use crate::aggregation::count::CountAggregator; +use crate::aggregation::max::MaxAggregator; +use crate::aggregation::min::MinAggregator; +use 
crate::aggregation::sum::SumAggregator; +use crate::errors::PipelineError; use dozer_types::serde::de::DeserializeOwned; use dozer_types::serde::{Deserialize, Serialize}; use enum_dispatch::enum_dispatch; use std::collections::BTreeMap; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::execution::Expression; +use dozer_sql_expression::aggregate::AggregateFunctionType; +use dozer_sql_expression::execution::Expression; -use crate::pipeline::aggregation::max_value::MaxValueAggregator; -use crate::pipeline::aggregation::min_value::MinValueAggregator; -use crate::pipeline::errors::PipelineError::{InvalidFunctionArgument, InvalidValue}; -use crate::pipeline::expression::aggregate::AggregateFunctionType::MaxValue; +use crate::aggregation::max_value::MaxValueAggregator; +use crate::aggregation::min_value::MinValueAggregator; +use crate::errors::PipelineError::{InvalidFunctionArgument, InvalidValue}; +use dozer_sql_expression::aggregate::AggregateFunctionType::MaxValue; use dozer_types::types::{Field, FieldType, Schema}; use std::fmt::{Debug, Display, Formatter}; diff --git a/dozer-sql/src/pipeline/aggregation/avg.rs b/dozer-sql/src/aggregation/avg.rs similarity index 61% rename from dozer-sql/src/pipeline/aggregation/avg.rs rename to dozer-sql/src/aggregation/avg.rs index 12d6ba95bf..eaf5dc0627 100644 --- a/dozer-sql/src/pipeline/aggregation/avg.rs +++ b/dozer-sql/src/aggregation/avg.rs @@ -1,64 +1,17 @@ -use crate::argv; -use crate::pipeline::aggregation::aggregator::Aggregator; -use crate::pipeline::aggregation::sum::{get_sum, SumState}; -use crate::pipeline::errors::PipelineError::InvalidValue; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::aggregate::AggregateFunctionType::Avg; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; +use crate::aggregation::aggregator::Aggregator; +use crate::aggregation::sum::{get_sum, SumState}; +use crate::errors::PipelineError; +use crate::errors::PipelineError::InvalidValue; +use dozer_sql_expression::aggregate::AggregateFunctionType::Avg; +use dozer_sql_expression::num_traits::FromPrimitive; use dozer_types::arrow::datatypes::ArrowNativeTypeOp; use dozer_types::ordered_float::OrderedFloat; use dozer_types::rust_decimal::Decimal; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{DozerDuration, Field, FieldType, Schema, SourceDefinition, TimeUnit}; -use num_traits::FromPrimitive; +use dozer_types::types::{DozerDuration, Field, FieldType, TimeUnit}; use std::ops::Div; -pub fn validate_avg(args: &[Expression], schema: &Schema) -> Result { - let arg = &argv!(args, 0, AggregateFunctionType::Avg)?.get_type(schema)?; - - let ret_type = match arg.return_type { - FieldType::UInt => FieldType::Decimal, - FieldType::U128 => FieldType::Decimal, - FieldType::Int => FieldType::Decimal, - FieldType::I128 => FieldType::Decimal, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Date - | FieldType::Timestamp - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - Avg.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Decimal, - 
FieldType::Duration, - ]), - 0, - )); - } - }; - - Ok(ExpressionType::new( - ret_type, - true, - SourceDefinition::Dynamic, - false, - )) -} - #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] pub struct AvgAggregator { @@ -133,50 +86,35 @@ fn get_average( if *current_count == 0 { return Ok(Field::Null); } - let u_sum = sum - .to_uint() - .ok_or(InvalidValue(sum.to_string().unwrap())) - .unwrap(); + let u_sum = sum.to_uint().ok_or(InvalidValue(sum.to_string())).unwrap(); Ok(Field::UInt(u_sum.div_wrapping(*current_count))) } FieldType::U128 => { if *current_count == 0 { return Ok(Field::Null); } - let u_sum = sum - .to_u128() - .ok_or(InvalidValue(sum.to_string().unwrap())) - .unwrap(); + let u_sum = sum.to_u128().ok_or(InvalidValue(sum.to_string())).unwrap(); Ok(Field::U128(u_sum.wrapping_div(*current_count as u128))) } FieldType::Int => { if *current_count == 0 { return Ok(Field::Null); } - let i_sum = sum - .to_int() - .ok_or(InvalidValue(sum.to_string().unwrap())) - .unwrap(); + let i_sum = sum.to_int().ok_or(InvalidValue(sum.to_string())).unwrap(); Ok(Field::Int(i_sum.div_wrapping(*current_count as i64))) } FieldType::I128 => { if *current_count == 0 { return Ok(Field::Null); } - let i_sum = sum - .to_i128() - .ok_or(InvalidValue(sum.to_string().unwrap())) - .unwrap(); + let i_sum = sum.to_i128().ok_or(InvalidValue(sum.to_string())).unwrap(); Ok(Field::I128(i_sum.div_wrapping(*current_count as i128))) } FieldType::Float => { if *current_count == 0 { return Ok(Field::Null); } - let f_sum = sum - .to_float() - .ok_or(InvalidValue(sum.to_string().unwrap())) - .unwrap(); + let f_sum = sum.to_float().ok_or(InvalidValue(sum.to_string())).unwrap(); Ok(Field::Float(OrderedFloat( f_sum.div_wrapping(*current_count as f64), ))) @@ -187,7 +125,7 @@ fn get_average( } let d_sum = sum .to_decimal() - .ok_or(InvalidValue(sum.to_string().unwrap())) + .ok_or(InvalidValue(sum.to_string())) .unwrap(); Ok(Field::Decimal(d_sum.div(Decimal::from(*current_count)))) } @@ -195,9 +133,9 @@ fn get_average( if *current_count == 0 { return Ok(Field::Null); } - let str_dur = sum.to_duration()?.unwrap().to_string(); + let str_dur = format!("{:?}", sum.to_duration().unwrap().0); let d_sum = sum - .to_duration()? 
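// get_average above divides the running sum by the group's record count,
// short-circuiting an empty group to Field::Null and using wrapping division
// on the integer widths. The Int arm reduced to plain std, as a sketch:
fn int_average(sum: i64, count: u64) -> Option<i64> {
    if count == 0 {
        return None; // surfaced as Field::Null by the aggregator
    }
    Some(sum.wrapping_div(count as i64)) // e.g. a sum of 10 over 4 rows -> 2
}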
+ .to_duration() .ok_or(InvalidValue(str_dur.clone())) .unwrap(); diff --git a/dozer-sql/src/pipeline/aggregation/count.rs b/dozer-sql/src/aggregation/count.rs similarity index 80% rename from dozer-sql/src/pipeline/aggregation/count.rs rename to dozer-sql/src/aggregation/count.rs index ab5b01de1e..022b1a3abf 100644 --- a/dozer-sql/src/pipeline/aggregation/count.rs +++ b/dozer-sql/src/aggregation/count.rs @@ -1,25 +1,12 @@ +use crate::aggregation::aggregator::Aggregator; use crate::calculate_err_type; -use crate::pipeline::aggregation::aggregator::Aggregator; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::aggregate::AggregateFunctionType::Count; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; +use crate::errors::PipelineError; +use dozer_sql_expression::aggregate::AggregateFunctionType::Count; +use dozer_sql_expression::num_traits::FromPrimitive; use dozer_types::ordered_float::OrderedFloat; use dozer_types::rust_decimal::Decimal; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; -use num_traits::FromPrimitive; - -pub fn validate_count( - _args: &[Expression], - _schema: &Schema, -) -> Result { - Ok(ExpressionType::new( - FieldType::Int, - false, - SourceDefinition::Dynamic, - false, - )) -} +use dozer_types::types::{Field, FieldType}; #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] diff --git a/dozer-sql/src/pipeline/aggregation/factory.rs b/dozer-sql/src/aggregation/factory.rs similarity index 94% rename from dozer-sql/src/pipeline/aggregation/factory.rs rename to dozer-sql/src/aggregation/factory.rs index c489ecc03a..1f09bf804e 100644 --- a/dozer-sql/src/pipeline/aggregation/factory.rs +++ b/dozer-sql/src/aggregation/factory.rs @@ -1,16 +1,16 @@ -use crate::pipeline::planner::projection::CommonPlanner; -use crate::pipeline::projection::processor::ProjectionProcessor; -use crate::pipeline::{aggregation::processor::AggregationProcessor, errors::PipelineError}; +use crate::planner::projection::CommonPlanner; +use crate::projection::processor::ProjectionProcessor; +use crate::{aggregation::processor::AggregationProcessor, errors::PipelineError}; use dozer_core::processor_record::ProcessorRecordStore; use dozer_core::{ node::{OutputPortDef, OutputPortType, PortHandle, Processor, ProcessorFactory}, DEFAULT_PORT_HANDLE, }; +use dozer_sql_expression::sqlparser::ast::Select; use dozer_types::errors::internal::BoxedError; use dozer_types::models::udf_config::UdfConfig; use dozer_types::parking_lot::Mutex; use dozer_types::types::Schema; -use sqlparser::ast::Select; use std::collections::HashMap; #[derive(Debug)] diff --git a/dozer-sql/src/pipeline/aggregation/max.rs b/dozer-sql/src/aggregation/max.rs similarity index 59% rename from dozer-sql/src/pipeline/aggregation/max.rs rename to dozer-sql/src/aggregation/max.rs index 07cde3b9da..a3aade5dd4 100644 --- a/dozer-sql/src/pipeline/aggregation/max.rs +++ b/dozer-sql/src/aggregation/max.rs @@ -1,59 +1,12 @@ -use crate::pipeline::aggregation::aggregator::{update_map, Aggregator}; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::aggregate::AggregateFunctionType::Max; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::{argv, calculate_err, calculate_err_field}; +use crate::aggregation::aggregator::{update_map, Aggregator}; +use 
crate::errors::PipelineError; +use crate::{calculate_err, calculate_err_field}; +use dozer_sql_expression::aggregate::AggregateFunctionType::Max; use dozer_types::ordered_float::OrderedFloat; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; +use dozer_types::types::{Field, FieldType}; use std::collections::BTreeMap; -pub fn validate_max(args: &[Expression], schema: &Schema) -> Result { - let arg = &argv!(args, 0, AggregateFunctionType::Max)?.get_type(schema)?; - - let ret_type = match arg.return_type { - FieldType::UInt => FieldType::UInt, - FieldType::U128 => FieldType::U128, - FieldType::Int => FieldType::Int, - FieldType::I128 => FieldType::I128, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Timestamp => FieldType::Timestamp, - FieldType::Date => FieldType::Date, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - Max.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::Decimal, - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Timestamp, - FieldType::Date, - FieldType::Duration, - ]), - 0, - )); - } - }; - Ok(ExpressionType::new( - ret_type, - true, - SourceDefinition::Dynamic, - false, - )) -} - #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] pub struct MaxAggregator { @@ -116,13 +69,13 @@ fn get_max( val ))), FieldType::Timestamp => Ok(Field::Timestamp(calculate_err_field!( - val.to_timestamp()?, + val.to_timestamp(), Max, val ))), - FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date()?, Max, val))), + FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date(), Max, val))), FieldType::Duration => Ok(Field::Duration(calculate_err_field!( - val.to_duration()?, + val.to_duration(), Max, val ))), diff --git a/dozer-sql/src/pipeline/aggregation/max_value.rs b/dozer-sql/src/aggregation/max_value.rs similarity index 51% rename from dozer-sql/src/pipeline/aggregation/max_value.rs rename to dozer-sql/src/aggregation/max_value.rs index d9350599e8..05b7ce75f8 100644 --- a/dozer-sql/src/pipeline/aggregation/max_value.rs +++ b/dozer-sql/src/aggregation/max_value.rs @@ -1,63 +1,12 @@ -use crate::pipeline::aggregation::aggregator::{update_val_map, Aggregator}; -use crate::pipeline::errors::PipelineError::InvalidReturnType; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType::MaxValue; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::{argv, calculate_err}; +use crate::aggregation::aggregator::{update_val_map, Aggregator}; +use crate::calculate_err; +use crate::errors::PipelineError; +use crate::errors::PipelineError::InvalidReturnType; +use dozer_sql_expression::aggregate::AggregateFunctionType::MaxValue; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; +use dozer_types::types::{Field, FieldType}; use std::collections::BTreeMap; -pub fn validate_max_value( - args: &[Expression], - schema: &Schema, -) -> Result { - let base_arg = &argv!(args, 0, MaxValue)?.get_type(schema)?; - let arg = &argv!(args, 1, MaxValue)?.get_type(schema)?; - - match base_arg.return_type { - FieldType::UInt => 
FieldType::UInt, - FieldType::U128 => FieldType::U128, - FieldType::Int => FieldType::Int, - FieldType::I128 => FieldType::I128, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Timestamp => FieldType::Timestamp, - FieldType::Date => FieldType::Date, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - MaxValue.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::Decimal, - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Timestamp, - FieldType::Date, - FieldType::Duration, - ]), - 0, - )); - } - }; - - Ok(ExpressionType::new( - arg.return_type, - true, - SourceDefinition::Dynamic, - false, - )) -} - #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] pub struct MaxValueAggregator { diff --git a/dozer-sql/src/pipeline/aggregation/min.rs b/dozer-sql/src/aggregation/min.rs similarity index 59% rename from dozer-sql/src/pipeline/aggregation/min.rs rename to dozer-sql/src/aggregation/min.rs index 9039626c1f..294486614e 100644 --- a/dozer-sql/src/pipeline/aggregation/min.rs +++ b/dozer-sql/src/aggregation/min.rs @@ -1,59 +1,12 @@ -use crate::pipeline::aggregation::aggregator::{update_map, Aggregator}; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::aggregate::AggregateFunctionType::Min; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::{argv, calculate_err, calculate_err_field}; +use crate::aggregation::aggregator::{update_map, Aggregator}; +use crate::errors::PipelineError; +use crate::{calculate_err, calculate_err_field}; +use dozer_sql_expression::aggregate::AggregateFunctionType::Min; use dozer_types::ordered_float::OrderedFloat; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; +use dozer_types::types::{Field, FieldType}; use std::collections::BTreeMap; -pub fn validate_min(args: &[Expression], schema: &Schema) -> Result<ExpressionType, PipelineError> { - let arg = &argv!(args, 0, AggregateFunctionType::Min)?.get_type(schema)?; - - let ret_type = match arg.return_type { - FieldType::UInt => FieldType::UInt, - FieldType::U128 => FieldType::U128, - FieldType::Int => FieldType::Int, - FieldType::I128 => FieldType::I128, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Timestamp => FieldType::Timestamp, - FieldType::Date => FieldType::Date, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - Min.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::Decimal, - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Timestamp, - FieldType::Date, - FieldType::Duration, - ]), - 0, - )); - } - }; - Ok(ExpressionType::new( - ret_type, - true, - SourceDefinition::Dynamic, - false, - )) -} - #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] pub struct MinAggregator { @@ -116,13 +69,13 @@ fn get_min( val ))), FieldType::Timestamp => 
Ok(Field::Timestamp(calculate_err_field!( - val.to_timestamp()?, + val.to_timestamp(), Min, val ))), - FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date()?, Min, val))), + FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date(), Min, val))), FieldType::Duration => Ok(Field::Duration(calculate_err_field!( - val.to_duration()?, + val.to_duration(), Min, val ))), diff --git a/dozer-sql/src/pipeline/aggregation/min_value.rs b/dozer-sql/src/aggregation/min_value.rs similarity index 51% rename from dozer-sql/src/pipeline/aggregation/min_value.rs rename to dozer-sql/src/aggregation/min_value.rs index 59e26f8d70..88bd850a1a 100644 --- a/dozer-sql/src/pipeline/aggregation/min_value.rs +++ b/dozer-sql/src/aggregation/min_value.rs @@ -1,63 +1,12 @@ -use crate::pipeline::aggregation::aggregator::{update_val_map, Aggregator}; -use crate::pipeline::errors::PipelineError::InvalidReturnType; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType::MinValue; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::{argv, calculate_err}; +use crate::aggregation::aggregator::{update_val_map, Aggregator}; +use crate::calculate_err; +use crate::errors::PipelineError; +use crate::errors::PipelineError::InvalidReturnType; +use dozer_sql_expression::aggregate::AggregateFunctionType::MinValue; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{Field, FieldType, Schema, SourceDefinition}; +use dozer_types::types::{Field, FieldType}; use std::collections::BTreeMap; -pub fn validate_min_value( - args: &[Expression], - schema: &Schema, -) -> Result<ExpressionType, PipelineError> { - let base_arg = &argv!(args, 0, MinValue)?.get_type(schema)?; - let arg = &argv!(args, 1, MinValue)?.get_type(schema)?; - - match base_arg.return_type { - FieldType::UInt => FieldType::UInt, - FieldType::U128 => FieldType::U128, - FieldType::Int => FieldType::Int, - FieldType::I128 => FieldType::I128, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Timestamp => FieldType::Timestamp, - FieldType::Date => FieldType::Date, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - MinValue.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::Decimal, - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Timestamp, - FieldType::Date, - FieldType::Duration, - ]), - 0, - )); - } - }; - - Ok(ExpressionType::new( - arg.return_type, - true, - SourceDefinition::Dynamic, - false, - )) -} - #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] pub struct MinValueAggregator { diff --git a/dozer-sql/src/pipeline/aggregation/mod.rs b/dozer-sql/src/aggregation/mod.rs similarity index 100% rename from dozer-sql/src/pipeline/aggregation/mod.rs rename to dozer-sql/src/aggregation/mod.rs diff --git a/dozer-sql/src/pipeline/aggregation/processor.rs b/dozer-sql/src/aggregation/processor.rs similarity index 98% rename from dozer-sql/src/pipeline/aggregation/processor.rs rename to dozer-sql/src/aggregation/processor.rs index f222151239..4bd8772e0b 100644 --- a/dozer-sql/src/pipeline/aggregation/processor.rs +++ b/dozer-sql/src/aggregation/processor.rs @@ -1,21 +1,22 @@ #![allow(clippy::too_many_arguments)] -use 
crate::pipeline::errors::PipelineError; -use crate::pipeline::utils::record_hashtable_key::{get_record_hash, RecordKey}; -use crate::pipeline::{aggregation::aggregator::Aggregator, expression::execution::Expression}; +use crate::aggregation::aggregator::Aggregator; +use crate::errors::PipelineError; +use crate::utils::record_hashtable_key::{get_record_hash, RecordKey}; use dozer_core::channels::ProcessorChannelForwarder; use dozer_core::dozer_log::storage::Object; use dozer_core::executor_operation::ProcessorOperation; use dozer_core::node::{PortHandle, Processor}; use dozer_core::processor_record::ProcessorRecordStore; use dozer_core::DEFAULT_PORT_HANDLE; +use dozer_sql_expression::execution::Expression; use dozer_types::bincode; use dozer_types::errors::internal::BoxedError; use dozer_types::serde::{Deserialize, Serialize}; use dozer_types::types::{Field, FieldType, Operation, Record, Schema}; use std::collections::HashMap; -use crate::pipeline::aggregation::aggregator::{ +use crate::aggregation::aggregator::{ get_aggregator_from_aggregator_type, get_aggregator_type_from_aggregation_expression, AggregatorEnum, AggregatorType, }; diff --git a/dozer-sql/src/pipeline/aggregation/sum.rs b/dozer-sql/src/aggregation/sum.rs similarity index 77% rename from dozer-sql/src/pipeline/aggregation/sum.rs rename to dozer-sql/src/aggregation/sum.rs index a7797d2893..641c528af4 100644 --- a/dozer-sql/src/pipeline/aggregation/sum.rs +++ b/dozer-sql/src/aggregation/sum.rs @@ -1,57 +1,12 @@ -use crate::pipeline::aggregation::aggregator::Aggregator; -use crate::pipeline::errors::{FieldTypes, PipelineError}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; -use crate::pipeline::expression::aggregate::AggregateFunctionType::Sum; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::{argv, calculate_err_field}; +use crate::aggregation::aggregator::Aggregator; +use crate::calculate_err_field; +use crate::errors::PipelineError; +use dozer_sql_expression::aggregate::AggregateFunctionType::Sum; +use dozer_sql_expression::num_traits::FromPrimitive; use dozer_types::ordered_float::OrderedFloat; use dozer_types::rust_decimal::Decimal; use dozer_types::serde::{Deserialize, Serialize}; -use dozer_types::types::{DozerDuration, Field, FieldType, Schema, SourceDefinition, TimeUnit}; -use num_traits::FromPrimitive; - -pub fn validate_sum(args: &[Expression], schema: &Schema) -> Result<ExpressionType, PipelineError> { - let arg = &argv!(args, 0, AggregateFunctionType::Sum)?.get_type(schema)?; - - let ret_type = match arg.return_type { - FieldType::UInt => FieldType::UInt, - FieldType::U128 => FieldType::U128, - FieldType::Int => FieldType::Int, - FieldType::I128 => FieldType::I128, - FieldType::Float => FieldType::Float, - FieldType::Decimal => FieldType::Decimal, - FieldType::Duration => FieldType::Duration, - FieldType::Boolean - | FieldType::String - | FieldType::Text - | FieldType::Date - | FieldType::Timestamp - | FieldType::Binary - | FieldType::Json - | FieldType::Point => { - return Err(PipelineError::InvalidFunctionArgumentType( - Sum.to_string(), - arg.return_type, - FieldTypes::new(vec![ - FieldType::UInt, - FieldType::U128, - FieldType::Int, - FieldType::I128, - FieldType::Float, - FieldType::Decimal, - FieldType::Duration, - ]), - 0, - )); - } - }; - Ok(ExpressionType::new( - ret_type, - true, - SourceDefinition::Dynamic, - false, - )) -} +use dozer_types::types::{DozerDuration, Field, FieldType, TimeUnit}; #[derive(Debug, Serialize, Deserialize)] #[serde(crate = "dozer_types::serde")] @@ 
-203,12 +158,12 @@ pub fn get_sum( FieldType::Duration => { if decr { for field in fields { - let val = calculate_err_field!(field.to_duration()?, Sum, field); + let val = calculate_err_field!(field.to_duration(), Sum, field); current_state.duration_state -= val.0; } } else { for field in fields { - let val = calculate_err_field!(field.to_duration()?, Sum, field); + let val = calculate_err_field!(field.to_duration(), Sum, field); current_state.duration_state += val.0; } } diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_avg_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_avg_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_avg_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_avg_tests.rs index a000bffc45..2bfad81230 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_avg_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_avg_tests.rs @@ -1,11 +1,11 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, get_decimal_div_field, get_decimal_field, get_duration_div_field, get_duration_field, init_input_schema, init_processor, insert_exp, insert_field, update_exp, update_field, FIELD_0_FLOAT, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_200_FLOAT, FIELD_200_INT, FIELD_200_UINT, FIELD_250_DIV_3_FLOAT, FIELD_350_DIV_3_FLOAT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_50_UINT, FIELD_75_FLOAT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::{Decimal, Duration, Float, Int, UInt}; use std::collections::HashMap; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_count_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_count_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_count_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_count_tests.rs index ff18f0c334..41a57f50e9 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_count_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_count_tests.rs @@ -1,10 +1,10 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, get_date_field, get_decimal_field, get_duration_field, get_ts_field, init_input_schema, init_processor, insert_exp, insert_field, update_exp, update_field, DATE8, FIELD_100_FLOAT, FIELD_100_INT, FIELD_1_INT, FIELD_200_FLOAT, FIELD_200_INT, FIELD_2_INT, FIELD_3_INT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::{Date, Decimal, Duration, Float, Int, Timestamp}; use dozer_types::types::{Operation, Record}; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_having_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_having_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_having_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_having_tests.rs index 868790fe5f..20d5ea801c 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_having_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_having_tests.rs @@ -1,9 +1,9 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use 
crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, init_input_schema, init_processor, insert_exp, insert_field, update_exp, update_field, FIELD_100_INT, FIELD_150_INT, FIELD_200_INT, FIELD_300_INT, FIELD_400_INT, FIELD_500_INT, FIELD_50_INT, FIELD_600_INT, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::Int; use std::collections::HashMap; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_max_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_max_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_max_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_max_tests.rs index 31831c3d77..daccb788e3 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_max_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_max_tests.rs @@ -1,10 +1,10 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, get_date_field, get_decimal_field, get_duration_field, get_ts_field, init_input_schema, init_processor, insert_exp, insert_field, update_exp, update_field, DATE16, DATE4, DATE8, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_200_FLOAT, FIELD_200_INT, FIELD_200_UINT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_50_UINT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::{Date, Decimal, Duration, Float, Int, Timestamp, UInt}; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_max_value_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_max_value_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_max_value_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_max_value_tests.rs index 3b596e7c9a..ea6f900865 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_max_value_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_max_value_tests.rs @@ -1,11 +1,11 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_field, delete_val_exp, get_date_field, get_decimal_field, get_duration_field, get_ts_field, init_input_schema, init_processor, init_val_input_schema, insert_field, insert_val_exp, update_field, update_val_exp, DATE16, DATE4, DATE8, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_150_FLOAT, FIELD_150_INT, FIELD_150_UINT, FIELD_200_FLOAT, FIELD_200_INT, FIELD_200_UINT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::Field; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_min_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_min_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_min_tests.rs index 8bd932db25..25121ae07f 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_min_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs @@ -1,10 +1,10 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, get_date_field, get_decimal_field, get_duration_field, get_ts_field, init_input_schema, init_processor, 
insert_exp, insert_field, update_exp, update_field, DATE16, DATE4, DATE8, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_200_FLOAT, FIELD_200_INT, FIELD_200_UINT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_50_UINT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::{Date, Decimal, Duration, Float, Int, Timestamp, UInt}; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_min_value_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_min_value_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_min_value_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_min_value_tests.rs index 6c7e9213e1..0ec34b08c0 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_min_value_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_min_value_tests.rs @@ -1,11 +1,11 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_field, delete_val_exp, get_date_field, get_decimal_field, get_duration_field, get_ts_field, init_input_schema, init_processor, init_val_input_schema, insert_field, insert_val_exp, update_field, update_val_exp, DATE16, DATE4, DATE8, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_50_UINT, FIELD_75_FLOAT, FIELD_75_INT, FIELD_75_UINT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::Field; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_null.rs b/dozer-sql/src/aggregation/tests/aggregation_null.rs similarity index 97% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_null.rs rename to dozer-sql/src/aggregation/tests/aggregation_null.rs index 8924a4da57..4965cdc125 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_null.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_null.rs @@ -1,6 +1,6 @@ use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, init_input_schema, init_processor, insert_exp, insert_field, FIELD_100_INT, FIELD_1_INT, ITALY, }; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_sum_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_sum_tests.rs similarity index 99% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_sum_tests.rs rename to dozer-sql/src/aggregation/tests/aggregation_sum_tests.rs index 73344ef374..7bc07b01a4 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_sum_tests.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_sum_tests.rs @@ -1,5 +1,4 @@ -use crate::output; -use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ +use crate::aggregation::tests::aggregation_tests_utils::{ delete_exp, delete_field, get_decimal_field, get_duration_field, init_input_schema, init_processor, insert_exp, insert_field, update_exp, update_field, FIELD_0_FLOAT, FIELD_0_INT, FIELD_100_FLOAT, FIELD_100_INT, FIELD_100_UINT, FIELD_150_FLOAT, FIELD_150_INT, FIELD_150_UINT, @@ -7,6 +6,7 @@ use crate::pipeline::aggregation::tests::aggregation_tests_utils::{ FIELD_350_FLOAT, FIELD_350_INT, FIELD_350_UINT, FIELD_50_FLOAT, FIELD_50_INT, FIELD_50_UINT, FIELD_NULL, ITALY, SINGAPORE, }; +use crate::output; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_types::types::FieldType::{Decimal, Duration, Float, Int, 
UInt}; use std::collections::HashMap; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_test_planner.rs b/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs similarity index 95% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_test_planner.rs rename to dozer-sql/src/aggregation/tests/aggregation_test_planner.rs index 0540061624..0ad53a17e8 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_test_planner.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs @@ -1,6 +1,6 @@ -use crate::pipeline::aggregation::processor::AggregationProcessor; -use crate::pipeline::planner::projection::CommonPlanner; -use crate::pipeline::tests::utils::get_select; +use crate::aggregation::processor::AggregationProcessor; +use crate::planner::projection::CommonPlanner; +use crate::tests::utils::get_select; use dozer_types::types::{ Field, FieldDefinition, FieldType, Operation, Record, Schema, SourceDefinition, }; diff --git a/dozer-sql/src/pipeline/aggregation/tests/aggregation_tests_utils.rs b/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs similarity index 97% rename from dozer-sql/src/pipeline/aggregation/tests/aggregation_tests_utils.rs rename to dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs index 871a8871ee..b2f368ca13 100644 --- a/dozer-sql/src/pipeline/aggregation/tests/aggregation_tests_utils.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs @@ -5,10 +5,10 @@ use dozer_types::types::{ }; use std::collections::HashMap; -use crate::pipeline::aggregation::processor::AggregationProcessor; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::planner::projection::CommonPlanner; -use crate::pipeline::tests::utils::get_select; +use crate::aggregation::processor::AggregationProcessor; +use crate::errors::PipelineError; +use crate::planner::projection::CommonPlanner; +use crate::tests::utils::get_select; use dozer_types::arrow::datatypes::ArrowNativeTypeOp; use dozer_types::chrono::{DateTime, NaiveDate, TimeZone, Utc}; use dozer_types::ordered_float::OrderedFloat; diff --git a/dozer-sql/src/pipeline/aggregation/tests/mod.rs b/dozer-sql/src/aggregation/tests/mod.rs similarity index 100% rename from dozer-sql/src/pipeline/aggregation/tests/mod.rs rename to dozer-sql/src/aggregation/tests/mod.rs diff --git a/dozer-sql/src/pipeline/builder.rs b/dozer-sql/src/builder.rs similarity index 98% rename from dozer-sql/src/pipeline/builder.rs rename to dozer-sql/src/builder.rs index a452486bd3..f6fe18806a 100644 --- a/dozer-sql/src/pipeline/builder.rs +++ b/dozer-sql/src/builder.rs @@ -1,16 +1,18 @@ -use crate::pipeline::aggregation::factory::AggregationProcessorFactory; -use crate::pipeline::builder::PipelineError::InvalidQuery; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::builder::{ExpressionBuilder, NameOrAlias}; -use crate::pipeline::selection::factory::SelectionProcessorFactory; +use crate::aggregation::factory::AggregationProcessorFactory; +use crate::builder::PipelineError::InvalidQuery; +use crate::errors::PipelineError; +use crate::selection::factory::SelectionProcessorFactory; use dozer_core::app::AppPipeline; use dozer_core::app::PipelineEntryPoint; use dozer_core::node::PortHandle; use dozer_core::DEFAULT_PORT_HANDLE; +use dozer_sql_expression::builder::{ExpressionBuilder, NameOrAlias}; +use dozer_sql_expression::sqlparser::ast::{ + Join, SetOperator, SetQuantifier, TableFactor, TableWithJoins, +}; use dozer_types::models::udf_config::UdfConfig; -use 
sqlparser::ast::{Join, SetOperator, SetQuantifier, TableFactor, TableWithJoins}; -use sqlparser::{ +use dozer_sql_expression::sqlparser::{ ast::{Query, Select, SetExpr, Statement}, dialect::DozerDialect, parser::Parser, @@ -652,7 +654,7 @@ pub fn get_from_source( #[cfg(test)] mod tests { use super::statement_to_pipeline; - use crate::pipeline::errors::PipelineError; + use crate::errors::PipelineError; use dozer_core::app::AppPipeline; #[test] #[should_panic] diff --git a/dozer-sql/src/pipeline/errors.rs b/dozer-sql/src/errors.rs similarity index 75% rename from dozer-sql/src/pipeline/errors.rs rename to dozer-sql/src/errors.rs index 69b253f996..2f56c676b7 100644 --- a/dozer-sql/src/pipeline/errors.rs +++ b/dozer-sql/src/errors.rs @@ -11,9 +11,6 @@ use dozer_types::thiserror::Error; use dozer_types::types::{Field, FieldType}; use std::fmt::{Display, Formatter}; -#[cfg(feature = "onnx")] -use crate::pipeline::onnx::OnnxError; - use super::utils::serialize::DeserializationError; #[derive(Debug, Clone)] @@ -38,72 +35,24 @@ impl Display for FieldTypes { pub enum PipelineError { #[error("Invalid operand type for function: {0}()")] InvalidOperandType(String), - #[error("Invalid input type. Reason: {0}")] - InvalidInputType(String), #[error("Invalid return type: {0}")] InvalidReturnType(String), #[error("Invalid function: {0}")] InvalidFunction(String), #[error("Invalid operator: {0}")] InvalidOperator(String), - #[error("Invalid expression: {0}")] - InvalidExpression(String), - #[error("Invalid argument: {0}")] - InvalidArgument(String), - #[error("Invalid types on {0} and {1} for {2} operand")] - InvalidTypeComparison(Field, Field, String), - #[error("Invalid types on {0} for {1} operand")] - InvalidType(Field, String), #[error("Invalid value: {0}")] InvalidValue(String), #[error("Invalid query: {0}")] InvalidQuery(String), - #[error("Invalid relation")] - InvalidRelation, - #[error("Invalid relation")] - DataTypeMismatch, #[error("Invalid argument for function {0}(): argument: {1}, index: {2}")] InvalidFunctionArgument(String, Field, usize), - #[error("Too many arguments for function {0}()")] - TooManyArguments(String), #[error("Not enough arguments for function {0}()")] NotEnoughArguments(String), - #[error( - "Invalid argument type for function {0}(): type: {1}, expected types: {2}, index: {3}" - )] - InvalidFunctionArgumentType(String, FieldType, FieldTypes, usize), - #[error("Mismatching argument types for {0}(): {1}, consider using CAST function")] - InvalidConditionalExpression(String, FieldTypes), - #[error("Invalid cast: from: {from}, to: {to}")] - InvalidCast { from: Field, to: FieldType }, - #[error("{0}() cannot be called from here. Aggregations can only be used in SELECT and HAVING and cannot be nested within other aggregations.")] - InvalidNestedAggregationFunction(String), - #[error("Field {0} is not present in the source schema")] - UnknownFieldIdentifier(String), - #[error( - "Field {0} is ambiguous. Specify a fully qualified name such as [connection.]source.field" - )] - AmbiguousFieldIdentifier(String), - #[error("The field identifier {0} is invalid. 
Correct format is: [[connection.]source.]field")] - IllegalFieldIdentifier(String), - #[error("Unable to cast {0} to {1}")] - UnableToCast(String, String), #[error("Missing INTO clause for top-level SELECT statement")] MissingIntoClause, #[error("Duplicate INTO table name found: {0:?}")] DuplicateIntoClause(String), - #[cfg(feature = "python")] - #[error("Python Error: {0}")] - PythonErr(dozer_types::pyo3::PyErr), - #[cfg(feature = "onnx")] - #[error("Onnx Error: {0}")] - OnnxError(OnnxError), - #[cfg(not(feature = "onnx"))] - #[error("Onnx feature is not enabled")] - OnnxNotEnabled, - - #[error("Udf is defined but missing with config: {0}")] - UdfConfigMissing(String), // Error forwarding #[error("Internal type error: {0}")] @@ -111,6 +60,9 @@ pub enum PipelineError { #[error("Internal error: {0}")] InternalError(#[from] BoxedError), + #[error("Expression error: {0}")] + Expression(#[from] dozer_sql_expression::error::Error), + #[error("Unsupported sql: {0}")] UnsupportedSqlError(#[from] UnsupportedSqlError), @@ -123,9 +75,6 @@ pub enum PipelineError { #[error("Set: {0}")] SetError(#[from] SetError), - #[error("Sql: {0}")] - SqlError(#[from] SqlError), - #[error("Window: {0}")] WindowError(#[from] WindowError), @@ -172,26 +121,11 @@ pub enum PipelineError { #[error("Currently JOIN supports two level of namespacing. For example, `source.field_name` is valid, but `connection.source.field_name` is not.")] NameSpaceTooLong(String), - #[error("Error building the JOIN on the {0} source of the Processor")] - JoinBuild(String), - #[error("Window: {0}")] TableOperatorError(#[from] TableOperatorError), #[error("Invalid port handle: {0}")] InvalidPortHandle(PortHandle), - #[error("JOIN processor received a Record from a wrong input: {0}")] - InvalidPort(u16), - - #[error("Unknown function: {0}")] - UnknownFunction(String), -} - -#[cfg(feature = "python")] -impl From<dozer_types::pyo3::PyErr> for PipelineError { - fn from(py_err: dozer_types::pyo3::PyErr) -> Self { - PipelineError::PythonErr(py_err) - } } #[derive(Error, Debug)] @@ -218,30 +152,6 @@ pub enum UnsupportedSqlError { GenericError(String), } -#[derive(Error, Debug)] -pub enum SqlError { - #[error("SQL Error: The first argument of the {0} function must be a source name.")] - WindowError(String), - #[error("SQL Error: Invalid column name {0}.")] - InvalidColumn(String), - #[error(transparent)] - Operation(#[from] OperationError), -} - -#[derive(Error, Debug)] -pub enum OperationError { - #[error("SQL Error: Addition operation cannot be done due to overflow.")] - AdditionOverflow, - #[error("SQL Error: Subtraction operation cannot be done due to overflow.")] - SubtractionOverflow, - #[error("SQL Error: Multiplication operation cannot be done due to overflow.")] - MultiplicationOverflow, - #[error("SQL Error: Division operation cannot be done.")] - DivisionByZeroOrOverflow, - #[error("SQL Error: Modulo operation cannot be done.")] - ModuloByZeroOrOverflow, -} - #[derive(Error, Debug)] pub enum SetError { #[error("Invalid input schemas have been populated")] diff --git a/dozer-sql/src/expression/mod.rs b/dozer-sql/src/expression/mod.rs new file mode 100644 index 0000000000..87c2771955 --- /dev/null +++ b/dozer-sql/src/expression/mod.rs @@ -0,0 +1,2 @@ +#[cfg(test)] +mod tests; diff --git a/dozer-sql/src/pipeline/expression/tests/case.rs b/dozer-sql/src/expression/tests/case.rs similarity index 98% rename from dozer-sql/src/pipeline/expression/tests/case.rs rename to dozer-sql/src/expression/tests/case.rs index 07f2498d8d..e5e07c788b 100644 --- 
a/dozer-sql/src/pipeline/expression/tests/case.rs +++ b/dozer-sql/src/expression/tests/case.rs @@ -1,4 +1,4 @@ -use crate::pipeline::expression::tests::test_common::run_fct; +use crate::expression::tests::test_common::run_fct; use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; #[test] diff --git a/dozer-sql/src/pipeline/expression/tests/cast.rs b/dozer-sql/src/expression/tests/cast.rs similarity index 83% rename from dozer-sql/src/pipeline/expression/tests/cast.rs rename to dozer-sql/src/expression/tests/cast.rs index f45119c34e..0b2f62efb0 100644 --- a/dozer-sql/src/pipeline/expression/tests/cast.rs +++ b/dozer-sql/src/expression/tests/cast.rs @@ -1,9 +1,4 @@ -use crate::pipeline::expression::execution::Expression::Literal; -use crate::pipeline::expression::mathematical::{ - evaluate_add, evaluate_div, evaluate_mod, evaluate_mul, evaluate_sub, -}; -use crate::pipeline::expression::tests::test_common::*; -use dozer_types::types::Record; +use crate::expression::tests::test_common::*; use dozer_types::types::SourceDefinition; use dozer_types::{ chrono::{DateTime, NaiveDate, TimeZone, Utc}, @@ -11,7 +6,6 @@ use dozer_types::{ rust_decimal::Decimal, types::{Field, FieldDefinition, FieldType, Schema}, }; -use num_traits::FromPrimitive; #[test] fn test_uint() { @@ -764,102 +758,3 @@ fn test_text() { ); assert_eq!(f, Field::Text("42".to_string())); } - -#[test] -fn test_decimal() { - let dec1 = Box::new(Literal(Field::Decimal(Decimal::from_i64(1_i64).unwrap()))); - let dec2 = Box::new(Literal(Field::Decimal(Decimal::from_i64(2_i64).unwrap()))); - let float1 = Box::new(Literal(Field::Float( - OrderedFloat::<f64>::from_i64(1_i64).unwrap(), - ))); - let float2 = Box::new(Literal(Field::Float( - OrderedFloat::<f64>::from_i64(2_i64).unwrap(), - ))); - let int1 = Box::new(Literal(Field::Int(1_i64))); - let int2 = Box::new(Literal(Field::Int(2_i64))); - let uint1 = Box::new(Literal(Field::UInt(1_u64))); - let uint2 = Box::new(Literal(Field::UInt(2_u64))); - - let row = Record::new(vec![]); - - // left: Int, right: Decimal - assert_eq!( - evaluate_add(&Schema::default(), &int1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_sub(&Schema::default(), &int1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); - assert_eq!( - evaluate_mul(&Schema::default(), &int2, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_div(&Schema::default(), &int1, dec2.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_f64(0.5).unwrap()) - ); - assert_eq!( - evaluate_mod(&Schema::default(), &int1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); - - // left: UInt, right: Decimal - assert_eq!( - evaluate_add(&Schema::default(), &uint1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_sub(&Schema::default(), &uint1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); - assert_eq!( - evaluate_mul(&Schema::default(), &uint2, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - 
Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_div(&Schema::default(), &uint1, dec2.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_f64(0.5).unwrap()) - ); - assert_eq!( - evaluate_mod(&Schema::default(), &uint1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); - - // left: Float, right: Decimal - assert_eq!( - evaluate_add(&Schema::default(), &float1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_sub(&Schema::default(), &float1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); - assert_eq!( - evaluate_mul(&Schema::default(), &float2, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(2_i64).unwrap()) - ); - assert_eq!( - evaluate_div(&Schema::default(), &float1, dec2.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_f64(0.5).unwrap()) - ); - assert_eq!( - evaluate_mod(&Schema::default(), &float1, dec1.as_ref(), &row) - .unwrap_or_else(|e| panic!("{}", e.to_string())), - Field::Decimal(Decimal::from_i64(0_i64).unwrap()) - ); -} diff --git a/dozer-sql/src/expression/tests/comparison.rs b/dozer-sql/src/expression/tests/comparison.rs new file mode 100644 index 0000000000..235fef0a3e --- /dev/null +++ b/dozer-sql/src/expression/tests/comparison.rs @@ -0,0 +1,175 @@ +use dozer_types::chrono::{DateTime, NaiveDate}; +use dozer_types::types::DATE_FORMAT; +use dozer_types::types::{Field, Schema}; +use dozer_types::types::{FieldDefinition, FieldType, SourceDefinition}; + +use crate::expression::tests::test_common::run_fct; + +#[test] +fn test_comparison_logical_int() { + let record = vec![Field::Int(124)]; + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("id"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let f = run_fct( + "SELECT id FROM users WHERE id = '124'", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Int(124)); + + let f = run_fct( + "SELECT id FROM users WHERE id <= '124'", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Int(124)); + + let f = run_fct( + "SELECT id FROM users WHERE id >= '124'", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Int(124)); + + let f = run_fct( + "SELECT id = '124' FROM users", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Boolean(true)); + + let f = run_fct( + "SELECT id < '124' FROM users", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Boolean(false)); + + let f = run_fct( + "SELECT id > '124' FROM users", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Boolean(false)); + + let f = run_fct( + "SELECT id <= '124' FROM users", + schema.clone(), + record.clone(), + ); + assert_eq!(f, Field::Boolean(true)); + + let f = run_fct("SELECT id >= '124' FROM users", schema, record); + assert_eq!(f, Field::Boolean(true)); +} + +#[test] +fn test_comparison_logical_timestamp() { + let f = run_fct( + "SELECT time = '2020-01-01T00:00:00Z' FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("time"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Timestamp( + 
DateTime::parse_from_rfc3339("2020-01-01T00:00:00Z").unwrap(), + )], + ); + assert_eq!(f, Field::Boolean(true)); + + let f = run_fct( + "SELECT time < '2020-01-01T00:00:01Z' FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("time"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Timestamp( + DateTime::parse_from_rfc3339("2020-01-01T00:00:00Z").unwrap(), + )], + ); + assert_eq!(f, Field::Boolean(true)); +} + +#[test] +fn test_comparison_logical_date() { + let f = run_fct( + "SELECT date = '2020-01-01' FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("date"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Date( + NaiveDate::parse_from_str("2020-01-01", DATE_FORMAT).unwrap(), + )], + ); + assert_eq!(f, Field::Boolean(true)); + + let f = run_fct( + "SELECT date != '2020-01-01' FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("date"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Date( + NaiveDate::parse_from_str("2020-01-01", DATE_FORMAT).unwrap(), + )], + ); + assert_eq!(f, Field::Boolean(false)); + + let f = run_fct( + "SELECT date > '2020-01-01' FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("date"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Date( + NaiveDate::parse_from_str("2020-01-02", DATE_FORMAT).unwrap(), + )], + ); + assert_eq!(f, Field::Boolean(true)); +} diff --git a/dozer-sql/src/expression/tests/conditional.rs b/dozer-sql/src/expression/tests/conditional.rs new file mode 100644 index 0000000000..8148e70735 --- /dev/null +++ b/dozer-sql/src/expression/tests/conditional.rs @@ -0,0 +1,96 @@ +use crate::expression::tests::test_common::*; +use dozer_types::{ + ordered_float::OrderedFloat, + types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}, +}; + +#[test] +fn test_coalesce_logic() { + let f = run_fct( + "SELECT COALESCE(field, 2) FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("field"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Null], + ); + assert_eq!(f, Field::Int(2)); + + let f = run_fct( + "SELECT COALESCE(field, CAST(2 AS FLOAT)) FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("field"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Null], + ); + assert_eq!(f, Field::Float(OrderedFloat(2.0))); + + let f = run_fct( + "SELECT COALESCE(field, 'X') FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("field"), + FieldType::String, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Null], + ); + assert_eq!(f, Field::String("X".to_string())); + + let f = run_fct( + "SELECT COALESCE(field, 'X') FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("field"), + FieldType::String, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Null], + ); + assert_eq!(f, Field::String("X".to_string())); +} + +#[test] +fn test_coalesce_logic_null() { + let f = run_fct( + "SELECT COALESCE(field) FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("field"), + FieldType::Int, + false, + SourceDefinition::Dynamic, 
+ ), + false, + ) + .clone(), + vec![Field::Null], + ); + assert_eq!(f, Field::Null); +} diff --git a/dozer-sql/src/expression/tests/datetime.rs b/dozer-sql/src/expression/tests/datetime.rs new file mode 100644 index 0000000000..24d04d856b --- /dev/null +++ b/dozer-sql/src/expression/tests/datetime.rs @@ -0,0 +1,182 @@ +use crate::expression::tests::test_common::*; +use dozer_types::chrono::{DateTime, NaiveDate}; +use dozer_types::types::{ + DozerDuration, Field, FieldDefinition, FieldType, Schema, SourceDefinition, TimeUnit, +}; + +#[test] +fn test_extract_date() { + let date_fns: Vec<(&str, i64, i64)> = vec![ + ("dow", 6, 0), + ("day", 1, 2), + ("month", 1, 1), + ("year", 2023, 2023), + ("hour", 0, 0), + ("minute", 0, 12), + ("second", 0, 10), + ("millisecond", 1672531200000, 1672618330000), + ("microsecond", 1672531200000000, 1672618330000000), + ("nanoseconds", 1672531200000000000, 1672618330000000000), + ("quarter", 1, 1), + ("epoch", 1672531200, 1672618330), + ("week", 52, 1), + ("century", 21, 21), + ("decade", 203, 203), + ("doy", 1, 2), + ]; + let inputs = vec![ + Field::Date(NaiveDate::from_ymd_opt(2023, 1, 1).unwrap()), + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap()), + ]; + + for (part, val1, val2) in date_fns { + let mut results = vec![]; + for i in inputs.clone() { + let f = run_fct( + &format!("select extract({part} from date) from users"), + Schema::default() + .field( + FieldDefinition::new( + String::from("date"), + FieldType::Date, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![i.clone()], + ); + results.push(f.to_int().unwrap()); + } + assert_eq!(val1, results[0]); + assert_eq!(val2, results[1]); + } +} + +#[test] +fn test_timestamp_diff() { + let f = run_fct( + "SELECT ts1 - ts2 FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("ts1"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("ts2"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![ + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap()), + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap()), + ], + ); + assert_eq!( + f, + Field::Duration(DozerDuration( + std::time::Duration::from_secs(1), + TimeUnit::Nanoseconds + )) + ); +} + +#[test] +fn test_interval() { + let f = run_fct( + "SELECT ts1 - INTERVAL '1' SECOND FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("ts1"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Timestamp( + DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(), + )], + ); + assert_eq!( + f, + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap()) + ); + + let f = run_fct( + "SELECT ts1 + INTERVAL '1' SECOND FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("ts1"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Timestamp( + DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(), + )], + ); + assert_eq!( + f, + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:12Z").unwrap()) + ); + + let f = run_fct( + "SELECT INTERVAL '1' SECOND + ts1 FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("ts1"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + 
), + false, + ) + .clone(), + vec![Field::Timestamp( + DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(), + )], + ); + assert_eq!( + f, + Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:12Z").unwrap()) + ); +} + +#[test] +fn test_now() { + let f = run_fct( + "SELECT NOW() FROM users", + Schema::default() + .field( + FieldDefinition::new( + String::from("ts1"), + FieldType::Timestamp, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![], + ); + assert!(f.to_timestamp().is_some()) +} diff --git a/dozer-sql/src/expression/tests/distance.rs b/dozer-sql/src/expression/tests/distance.rs new file mode 100644 index 0000000000..fa1a04ff49 --- /dev/null +++ b/dozer-sql/src/expression/tests/distance.rs @@ -0,0 +1,82 @@ +use crate::expression::tests::test_common::*; +use dozer_types::ordered_float::OrderedFloat; +use dozer_types::types::{DozerPoint, Field, FieldDefinition, FieldType, Schema, SourceDefinition}; + +#[test] +fn test_distance_logical() { + let tests = vec![ + ("", 1113.0264976969), + ("GEODESIC", 1113.0264976969), + ("HAVERSINE", 1111.7814468418496), + ("VINCENTY", 1113.0264975564357), + ]; + + let schema = Schema::default() + .field( + FieldDefinition::new( + String::from("from"), + FieldType::Point, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("to"), + FieldType::Point, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(); + + let input = vec![ + Field::Point(DozerPoint::from((1.0, 1.0))), + Field::Point(DozerPoint::from((1.01, 1.0))), + ]; + + for (calculation_type, expected_result) in tests { + let sql = if calculation_type.is_empty() { + "SELECT DISTANCE(from, to) FROM LOCATIONS".to_string() + } else { + format!("SELECT DISTANCE(from, to, '{calculation_type}') FROM LOCATIONS") + }; + if let Field::Float(OrderedFloat(result)) = run_fct(&sql, schema.clone(), input.clone()) { + assert!((result - expected_result) < 0.000000001); + } else { + panic!("Expected float"); + } + } +} + +#[test] +fn test_distance_with_nullable_parameter() { + let f = run_fct( + "SELECT DISTANCE(from, to) FROM LOCATION", + Schema::default() + .field( + FieldDefinition::new( + String::from("from"), + FieldType::Point, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("to"), + FieldType::Point, + true, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Point(DozerPoint::from((0.0, 1.0))), Field::Null], + ); + + assert_eq!(f, Field::Null); +} diff --git a/dozer-sql/src/pipeline/expression/tests/execution.rs b/dozer-sql/src/expression/tests/execution.rs similarity index 76% rename from dozer-sql/src/pipeline/expression/tests/execution.rs rename to dozer-sql/src/expression/tests/execution.rs index 548348f58d..badd877ab8 100644 --- a/dozer-sql/src/pipeline/expression/tests/execution.rs +++ b/dozer-sql/src/expression/tests/execution.rs @@ -1,16 +1,12 @@ -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::mathematical::evaluate_sub; -use crate::pipeline::expression::operator::{BinaryOperatorType, UnaryOperatorType}; -use crate::pipeline::expression::scalar::common::ScalarFunctionType; -use crate::pipeline::projection::factory::ProjectionProcessorFactory; -use crate::pipeline::tests::utils::get_select; +use crate::projection::factory::ProjectionProcessorFactory; +use crate::tests::utils::get_select; use dozer_core::node::ProcessorFactory; use dozer_core::DEFAULT_PORT_HANDLE; -use 
dozer_types::chrono::DateTime; +use dozer_sql_expression::execution::Expression; +use dozer_sql_expression::operator::{BinaryOperatorType, UnaryOperatorType}; +use dozer_sql_expression::scalar::common::ScalarFunctionType; use dozer_types::types::Record; -use dozer_types::types::{ - DozerDuration, Field, FieldDefinition, FieldType, Schema, SourceDefinition, TimeUnit, -}; +use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; #[test] fn test_column_execution() { @@ -232,55 +228,3 @@ fn test_wildcard() { .clone() ); } - -#[test] -fn test_timestamp_difference() { - let schema = Schema::default() - .field( - FieldDefinition::new( - String::from("a"), - FieldType::Timestamp, - false, - SourceDefinition::Dynamic, - ), - true, - ) - .field( - FieldDefinition::new( - String::from("b"), - FieldType::Timestamp, - false, - SourceDefinition::Dynamic, - ), - false, - ) - .clone(); - - let record = Record::new(vec![ - Field::Timestamp(DateTime::parse_from_rfc3339("2020-01-01T00:13:00Z").unwrap()), - Field::Timestamp(DateTime::parse_from_rfc3339("2020-01-01T00:12:10Z").unwrap()), - ]); - - let result = evaluate_sub( - &schema, - &Expression::Column { index: 0 }, - &Expression::Column { index: 1 }, - &record, - ) - .unwrap(); - assert_eq!( - result, - Field::Duration(DozerDuration( - std::time::Duration::from_nanos(50000 * 1000 * 1000), - TimeUnit::Nanoseconds - )) - ); - - let result = evaluate_sub( - &schema, - &Expression::Column { index: 1 }, - &Expression::Column { index: 0 }, - &record, - ); - assert!(result.is_err()); -} diff --git a/dozer-sql/src/pipeline/expression/tests/expression_builder_test.rs b/dozer-sql/src/expression/tests/expression_builder_test.rs similarity index 97% rename from dozer-sql/src/pipeline/expression/tests/expression_builder_test.rs rename to dozer-sql/src/expression/tests/expression_builder_test.rs index 59ed2ebc8c..5c44840762 100644 --- a/dozer-sql/src/pipeline/expression/tests/expression_builder_test.rs +++ b/dozer-sql/src/expression/tests/expression_builder_test.rs @@ -1,12 +1,11 @@ -use crate::pipeline::expression::builder::ExpressionBuilder; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::operator::BinaryOperatorType; -use crate::pipeline::expression::scalar::common::ScalarFunctionType; -use crate::pipeline::tests::utils::get_select; +use crate::tests::utils::get_select; +use dozer_sql_expression::execution::Expression; +use dozer_sql_expression::operator::BinaryOperatorType; +use dozer_sql_expression::scalar::common::ScalarFunctionType; +use dozer_sql_expression::{builder::ExpressionBuilder, sqlparser::ast::SelectItem}; -use crate::pipeline::expression::aggregate::AggregateFunctionType; +use dozer_sql_expression::aggregate::AggregateFunctionType; use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; -use sqlparser::ast::SelectItem; #[test] fn test_simple_function() { diff --git a/dozer-sql/src/pipeline/expression/tests/in_list.rs b/dozer-sql/src/expression/tests/in_list.rs similarity index 97% rename from dozer-sql/src/pipeline/expression/tests/in_list.rs rename to dozer-sql/src/expression/tests/in_list.rs index 520c64b777..0296f2210c 100644 --- a/dozer-sql/src/pipeline/expression/tests/in_list.rs +++ b/dozer-sql/src/expression/tests/in_list.rs @@ -1,4 +1,4 @@ -use crate::pipeline::expression::tests::test_common::run_fct; +use crate::expression::tests::test_common::run_fct; use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; 
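// A hedged sketch, not part of the original patch: the expression tests in these
// moved modules all go through the shared `run_fct` helper, whose definition begins
// in test_common.rs at the end of this diff. Assuming the signature visible at the
// call sites, `run_fct(sql: &str, schema: Schema, input: Vec<Field>) -> Field`
// parses the SELECT, plans a projection over `schema`, runs it on a single record
// built from `input`, and returns the first output field. A hypothetical IN-list
// check in this style would look roughly like:
//
// let schema = Schema::default()
//     .field(
//         FieldDefinition::new(
//             String::from("id"),
//             FieldType::Int,
//             false,
//             SourceDefinition::Dynamic,
//         ),
//         false,
//     )
//     .clone();
// let f = run_fct(
//     "SELECT id IN (100, 124) FROM users",
//     schema,
//     vec![Field::Int(124)],
// );
// assert_eq!(f, Field::Boolean(true));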
#[test] diff --git a/dozer-sql/src/pipeline/expression/tests/json_functions.rs b/dozer-sql/src/expression/tests/json_functions.rs similarity index 99% rename from dozer-sql/src/pipeline/expression/tests/json_functions.rs rename to dozer-sql/src/expression/tests/json_functions.rs index 1a166e81b9..e4ead464b8 100644 --- a/dozer-sql/src/pipeline/expression/tests/json_functions.rs +++ b/dozer-sql/src/expression/tests/json_functions.rs @@ -1,4 +1,4 @@ -use crate::pipeline::expression::tests::test_common::run_fct; +use crate::expression::tests::test_common::run_fct; use dozer_types::json_types::{serde_json_to_json_value, JsonValue}; use dozer_types::ordered_float::OrderedFloat; use dozer_types::serde_json::json; diff --git a/dozer-sql/src/expression/tests/mod.rs b/dozer-sql/src/expression/tests/mod.rs new file mode 100644 index 0000000000..3fdb76ff53 --- /dev/null +++ b/dozer-sql/src/expression/tests/mod.rs @@ -0,0 +1,14 @@ +mod case; +mod cast; +mod comparison; +mod conditional; +mod datetime; +mod distance; +mod execution; +mod expression_builder_test; +mod in_list; +mod json_functions; +mod number; +mod point; +mod string; +mod test_common; diff --git a/dozer-sql/src/pipeline/expression/tests/models/onnx_modeling.py b/dozer-sql/src/expression/tests/models/onnx_modeling.py similarity index 100% rename from dozer-sql/src/pipeline/expression/tests/models/onnx_modeling.py rename to dozer-sql/src/expression/tests/models/onnx_modeling.py diff --git a/dozer-sql/src/expression/tests/number.rs b/dozer-sql/src/expression/tests/number.rs new file mode 100644 index 0000000000..abfeca822b --- /dev/null +++ b/dozer-sql/src/expression/tests/number.rs @@ -0,0 +1,26 @@ +use crate::expression::tests::test_common::*; +use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; +use proptest::prelude::*; +use std::ops::Neg; + +#[test] +fn test_abs_logic() { + proptest!(ProptestConfig::with_cases(1000), |(i_num in 0i64..100000000i64)| { + let f = run_fct( + "SELECT ABS(c) FROM USERS", + Schema::default() + .field( + FieldDefinition::new( + String::from("c"), + FieldType::Int, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Int(i_num.neg())], + ); + assert_eq!(f, Field::Int(i_num)); + }); +} diff --git a/dozer-sql/src/expression/tests/point.rs b/dozer-sql/src/expression/tests/point.rs new file mode 100644 index 0000000000..407b3900b8 --- /dev/null +++ b/dozer-sql/src/expression/tests/point.rs @@ -0,0 +1,64 @@ +use crate::expression::tests::test_common::*; +use dozer_types::ordered_float::OrderedFloat; +use dozer_types::types::{DozerPoint, Field, FieldDefinition, FieldType, Schema, SourceDefinition}; + +#[test] +fn test_point_logical() { + let f = run_fct( + "SELECT POINT(x, y) FROM LOCATION", + Schema::default() + .field( + FieldDefinition::new( + String::from("x"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("y"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![ + Field::Float(OrderedFloat(1.0)), + Field::Float(OrderedFloat(2.0)), + ], + ); + assert_eq!(f, Field::Point(DozerPoint::from((1.0, 2.0)))); +} + +#[test] +fn test_point_with_nullable_parameter() { + let f = run_fct( + "SELECT POINT(x, y) FROM LOCATION", + Schema::default() + .field( + FieldDefinition::new( + String::from("x"), + FieldType::Float, + false, + SourceDefinition::Dynamic, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("y"), + 
FieldType::Float, + true, + SourceDefinition::Dynamic, + ), + false, + ) + .clone(), + vec![Field::Float(OrderedFloat(1.0)), Field::Null], + ); + assert_eq!(f, Field::Null); +} diff --git a/dozer-sql/src/pipeline/expression/tests/string.rs b/dozer-sql/src/expression/tests/string.rs similarity index 55% rename from dozer-sql/src/pipeline/expression/tests/string.rs rename to dozer-sql/src/expression/tests/string.rs index 78dc36e3d8..80f374a20b 100644 --- a/dozer-sql/src/pipeline/expression/tests/string.rs +++ b/dozer-sql/src/expression/tests/string.rs @@ -1,324 +1,6 @@ -use crate::pipeline::expression::execution::Expression::Literal; -use crate::pipeline::expression::scalar::string::{ - evaluate_concat, evaluate_like, evaluate_trim, evaluate_ucase, validate_concat, validate_trim, - TrimType, -}; -use crate::pipeline::expression::tests::test_common::*; +use crate::expression::tests::test_common::*; use dozer_types::chrono::{DateTime, NaiveDate, TimeZone, Utc}; -use dozer_types::types::Record; use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition}; -use proptest::prelude::*; - -#[test] -fn test_string() { - proptest!( - ProptestConfig::with_cases(1000), - move |(s_val in ".+", s_val1 in ".*", s_val2 in ".*", c_val: char)| { - test_like(&s_val, c_val); - test_ucase(&s_val, c_val); - test_concat(&s_val1, &s_val2, c_val); - test_trim(&s_val, c_val); - }); -} - -fn test_like(s_val: &str, c_val: char) { - let row = Record::new(vec![]); - - // Field::String - let value = Box::new(Literal(Field::String(format!("Hello{}", s_val)))); - let pattern = Box::new(Literal(Field::String("Hello%".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(true) - ); - - let value = Box::new(Literal(Field::String(format!("Hello, {}orld!", c_val)))); - let pattern = Box::new(Literal(Field::String("Hello, _orld!".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(true) - ); - - let value = Box::new(Literal(Field::String(s_val.to_string()))); - let pattern = Box::new(Literal(Field::String("Hello%".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(false) - ); - - let c_value = &s_val[0..0]; - let value = Box::new(Literal(Field::String(format!("Hello, {}!", c_value)))); - let pattern = Box::new(Literal(Field::String("Hello, _!".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(false) - ); - - // todo: should find the way to generate escape character using proptest - // let value = Box::new(Literal(Field::String(format!("Hello, {}%", c_val)))); - // let pattern = Box::new(Literal(Field::String("Hello, %".to_owned()))); - // let escape = Some(c_val); - // - // assert_eq!( - // evaluate_like(&Schema::default(), &value, &pattern, escape, &row).unwrap(), - // Field::Boolean(true) - // ); - - // Field::Text - let value = Box::new(Literal(Field::Text(format!("Hello{}", s_val)))); - let pattern = Box::new(Literal(Field::Text("Hello%".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(true) - ); - - let value = Box::new(Literal(Field::Text(format!("Hello, {}orld!", c_val)))); - let pattern = Box::new(Literal(Field::Text("Hello, _orld!".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - 
Field::Boolean(true) - ); - - let value = Box::new(Literal(Field::Text(s_val.to_string()))); - let pattern = Box::new(Literal(Field::Text("Hello%".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(false) - ); - - let c_value = &s_val[0..0]; - let value = Box::new(Literal(Field::Text(format!("Hello, {}!", c_value)))); - let pattern = Box::new(Literal(Field::Text("Hello, _!".to_owned()))); - - assert_eq!( - evaluate_like(&Schema::default(), &value, &pattern, None, &row).unwrap(), - Field::Boolean(false) - ); - - // todo: should find the way to generate escape character using proptest - // let value = Box::new(Literal(Field::Text(format!("Hello, {}%", c_val)))); - // let pattern = Box::new(Literal(Field::Text("Hello, %".to_owned()))); - // let escape = Some(c_val); - // - // assert_eq!( - // evaluate_like(&Schema::default(), &value, &pattern, escape, &row).unwrap(), - // Field::Boolean(true) - // ); -} - -fn test_ucase(s_val: &str, c_val: char) { - let row = Record::new(vec![]); - - // Field::String - let value = Box::new(Literal(Field::String(s_val.to_string()))); - assert_eq!( - evaluate_ucase(&Schema::default(), &value, &row).unwrap(), - Field::String(s_val.to_uppercase()) - ); - - let value = Box::new(Literal(Field::String(c_val.to_string()))); - assert_eq!( - evaluate_ucase(&Schema::default(), &value, &row).unwrap(), - Field::String(c_val.to_uppercase().to_string()) - ); - - // Field::Text - let value = Box::new(Literal(Field::Text(s_val.to_string()))); - assert_eq!( - evaluate_ucase(&Schema::default(), &value, &row).unwrap(), - Field::Text(s_val.to_uppercase()) - ); - - let value = Box::new(Literal(Field::Text(c_val.to_string()))); - assert_eq!( - evaluate_ucase(&Schema::default(), &value, &row).unwrap(), - Field::Text(c_val.to_uppercase().to_string()) - ); -} - -fn test_concat(s_val1: &str, s_val2: &str, c_val: char) { - let row = Record::new(vec![]); - - // Field::String - let val1 = Literal(Field::String(s_val1.to_string())); - let val2 = Literal(Field::String(s_val2.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::String(s_val1.to_string() + s_val2) - ); - } - - let val1 = Literal(Field::String(s_val2.to_string())); - let val2 = Literal(Field::String(s_val1.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::String(s_val2.to_string() + s_val1) - ); - } - - let val1 = Literal(Field::String(s_val1.to_string())); - let val2 = Literal(Field::String(c_val.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::String(s_val1.to_string() + c_val.to_string().as_str()) - ); - } - - let val1 = Literal(Field::String(c_val.to_string())); - let val2 = Literal(Field::String(s_val1.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::String(c_val.to_string() + s_val1) - ); - } - - // Field::Text - let val1 = Literal(Field::Text(s_val1.to_string())); - let val2 = Literal(Field::Text(s_val2.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], 
&Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::Text(s_val1.to_string() + s_val2) - ); - } - - let val1 = Literal(Field::Text(s_val2.to_string())); - let val2 = Literal(Field::Text(s_val1.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::Text(s_val2.to_string() + s_val1) - ); - } - - let val1 = Literal(Field::Text(s_val1.to_string())); - let val2 = Literal(Field::Text(c_val.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::Text(s_val1.to_string() + c_val.to_string().as_str()) - ); - } - - let val1 = Literal(Field::Text(c_val.to_string())); - let val2 = Literal(Field::Text(s_val1.to_string())); - - if validate_concat(&[val1.clone(), val2.clone()], &Schema::default()).is_ok() { - assert_eq!( - evaluate_concat(&Schema::default(), &[val1, val2], &row).unwrap(), - Field::Text(c_val.to_string() + s_val1) - ); - } -} - -fn test_trim(s_val1: &str, c_val: char) { - let row = Record::new(vec![]); - - // Field::String - let value = Literal(Field::String(s_val1.to_string())); - let what = ' '; - - if validate_trim(&value, &Schema::default()).is_ok() { - assert_eq!( - evaluate_trim(&Schema::default(), &value, &None, &None, &row).unwrap(), - Field::String(s_val1.trim_matches(what).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &None, - &Some(TrimType::Trailing), - &row - ) - .unwrap(), - Field::String(s_val1.trim_end_matches(what).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &None, - &Some(TrimType::Leading), - &row - ) - .unwrap(), - Field::String(s_val1.trim_start_matches(what).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &None, - &Some(TrimType::Both), - &row - ) - .unwrap(), - Field::String(s_val1.trim_matches(what).to_string()) - ); - } - - let value = Literal(Field::String(s_val1.to_string())); - let what = Some(Box::new(Literal(Field::String(c_val.to_string())))); - - if validate_trim(&value, &Schema::default()).is_ok() { - assert_eq!( - evaluate_trim(&Schema::default(), &value, &what, &None, &row).unwrap(), - Field::String(s_val1.trim_matches(c_val).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &what, - &Some(TrimType::Trailing), - &row - ) - .unwrap(), - Field::String(s_val1.trim_end_matches(c_val).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &what, - &Some(TrimType::Leading), - &row - ) - .unwrap(), - Field::String(s_val1.trim_start_matches(c_val).to_string()) - ); - assert_eq!( - evaluate_trim( - &Schema::default(), - &value, - &what, - &Some(TrimType::Both), - &row - ) - .unwrap(), - Field::String(s_val1.trim_matches(c_val).to_string()) - ); - } -} #[test] fn test_concat_string() { diff --git a/dozer-sql/src/expression/tests/test_common.rs b/dozer-sql/src/expression/tests/test_common.rs new file mode 100644 index 0000000000..c456c7949b --- /dev/null +++ b/dozer-sql/src/expression/tests/test_common.rs @@ -0,0 +1,62 @@ +use crate::{projection::factory::ProjectionProcessorFactory, tests::utils::get_select}; +use dozer_core::channels::ProcessorChannelForwarder; +use dozer_core::executor_operation::ProcessorOperation; +use dozer_core::node::ProcessorFactory; +use 
dozer_core::processor_record::ProcessorRecordStore;
+use dozer_core::DEFAULT_PORT_HANDLE;
+use dozer_types::types::Record;
+use dozer_types::types::{Field, Schema};
+use std::collections::HashMap;
+
+struct TestChannelForwarder {
+    operations: Vec<ProcessorOperation>,
+}
+
+impl ProcessorChannelForwarder for TestChannelForwarder {
+    fn send(&mut self, op: ProcessorOperation, _port: dozer_core::node::PortHandle) {
+        self.operations.push(op);
+    }
+}
+
+pub(crate) fn run_fct(sql: &str, schema: Schema, input: Vec<Field>) -> Field {
+    let record_store = ProcessorRecordStore::new().unwrap();
+
+    let select = get_select(sql).unwrap();
+    let processor_factory =
+        ProjectionProcessorFactory::_new("projection_id".to_owned(), select.projection, vec![]);
+    processor_factory
+        .get_output_schema(
+            &DEFAULT_PORT_HANDLE,
+            &[(DEFAULT_PORT_HANDLE, schema.clone())]
+                .into_iter()
+                .collect(),
+        )
+        .unwrap();
+
+    let mut processor = processor_factory
+        .build(
+            HashMap::from([(DEFAULT_PORT_HANDLE, schema)]),
+            HashMap::new(),
+            &record_store,
+            None,
+        )
+        .unwrap();
+
+    let mut fw = TestChannelForwarder { operations: vec![] };
+    let rec = Record::new(input);
+    let rec = record_store.create_record(&rec).unwrap();
+
+    let op = ProcessorOperation::Insert { new: rec };
+
+    processor
+        .process(DEFAULT_PORT_HANDLE, &record_store, op, &mut fw)
+        .unwrap();
+
+    match &fw.operations[0] {
+        ProcessorOperation::Insert { new } => {
+            let mut new = record_store.load_record(new).unwrap();
+            new.values.remove(0)
+        }
+        _ => panic!("Unable to find result value"),
+    }
+}
diff --git a/dozer-sql/src/lib.rs b/dozer-sql/src/lib.rs
index a4f10a8581..8f2e2087c7 100644
--- a/dozer-sql/src/lib.rs
+++ b/dozer-sql/src/lib.rs
@@ -1,4 +1,17 @@
-// Re-export sqlparser
-pub use sqlparser;
+mod aggregation;
+pub mod builder;
+pub mod errors;
+mod expression;
+mod pipeline_builder;
+mod planner;
+mod product;
+mod projection;
+mod selection;
+mod table_operator;
+mod utils;
+mod window;
-pub mod pipeline;
+pub use dozer_sql_expression::sqlparser;
+
+#[cfg(test)]
+mod tests;
diff --git a/dozer-sql/src/pipeline/expression/arg_utils.rs b/dozer-sql/src/pipeline/expression/arg_utils.rs
deleted file mode 100644
index 4b16f2f763..0000000000
--- a/dozer-sql/src/pipeline/expression/arg_utils.rs
+++ /dev/null
@@ -1,161 +0,0 @@
-use crate::pipeline::errors::PipelineError::InvalidFunctionArgumentType;
-use crate::pipeline::errors::{FieldTypes, PipelineError};
-use crate::pipeline::expression::execution::{Expression, ExpressionType};
-use crate::pipeline::expression::scalar::common::ScalarFunctionType;
-use dozer_types::types::{FieldType, Schema};
-
-pub(crate) fn validate_arg_type(
-    arg: &Expression,
-    expected: Vec<FieldType>,
-    schema: &Schema,
-    fct: ScalarFunctionType,
-    idx: usize,
-) -> Result<ExpressionType, PipelineError> {
-    let arg_t = arg.get_type(schema)?;
-    if !expected.contains(&arg_t.return_type) {
-        Err(InvalidFunctionArgumentType(
-            fct.to_string(),
-            arg_t.return_type,
-            FieldTypes::new(expected),
-            idx,
-        ))
-    } else {
-        Ok(arg_t)
-    }
-}
-
-#[macro_export]
-macro_rules! argv {
-    ($arr: expr, $idx: expr, $fct: expr) => {
-        match $arr.get($idx) {
-            Some(v) => Ok(v),
-            _ => Err(PipelineError::NotEnoughArguments($fct.to_string())),
-        }
-    };
-}
-
-#[macro_export]
-macro_rules! arg_str {
-    ($field: expr, $fct: expr, $idx: expr) => {
-        match $field.to_string() {
-            Some(e) => Ok(e),
-            _ => Err(PipelineError::InvalidFunctionArgument(
-                $fct.to_string(),
-                $field,
-                $idx,
-            )),
-        }
-    };
-}
-
-#[macro_export]
-macro_rules!
arg_uint { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_uint() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_int { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_int() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_float { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_float() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_binary { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_binary() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_decimal { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_decimal() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_timestamp { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_timestamp() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_date { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_date() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} - -#[macro_export] -macro_rules! arg_point { - ($field: expr, $fct: expr, $idx: expr) => { - match $field.to_point() { - Some(e) => Ok(e), - _ => Err(PipelineError::InvalidFunctionArgument( - $fct.to_string(), - $field, - $idx, - )), - } - }; -} diff --git a/dozer-sql/src/pipeline/expression/cast.rs b/dozer-sql/src/pipeline/expression/cast.rs deleted file mode 100644 index e03203f0c1..0000000000 --- a/dozer-sql/src/pipeline/expression/cast.rs +++ /dev/null @@ -1,367 +0,0 @@ -use std::fmt::{Display, Formatter}; - -use dozer_types::types::Record; -use dozer_types::{ - ordered_float::OrderedFloat, - types::{Field, FieldType, Schema}, -}; - -use crate::pipeline::errors::{FieldTypes, PipelineError}; - -use super::execution::{Expression, ExpressionType}; - -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum CastOperatorType { - UInt, - U128, - Int, - I128, - Float, - Boolean, - String, - Text, - Binary, - Decimal, - Timestamp, - Date, - Json, -} - -impl Display for CastOperatorType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - CastOperatorType::UInt => f.write_str("CAST AS UINT"), - CastOperatorType::U128 => f.write_str("CAST AS U128"), - CastOperatorType::Int => f.write_str("CAST AS INT"), - CastOperatorType::I128 => f.write_str("CAST AS I128"), - CastOperatorType::Float => f.write_str("CAST AS FLOAT"), - CastOperatorType::Boolean => f.write_str("CAST AS BOOLEAN"), - CastOperatorType::String => f.write_str("CAST AS STRING"), - CastOperatorType::Text => f.write_str("CAST AS TEXT"), - CastOperatorType::Binary => f.write_str("CAST AS BINARY"), - CastOperatorType::Decimal => f.write_str("CAST AS DECIMAL"), - CastOperatorType::Timestamp => f.write_str("CAST AS TIMESTAMP"), - CastOperatorType::Date => f.write_str("CAST AS DATE"), - CastOperatorType::Json => 
f.write_str("CAST AS JSON"), - } - } -} - -impl CastOperatorType { - pub(crate) fn evaluate( - &self, - schema: &Schema, - arg: &Expression, - record: &Record, - ) -> Result { - let field = arg.evaluate(record, schema)?; - match self { - CastOperatorType::UInt => { - if let Some(value) = field.to_uint() { - Ok(Field::UInt(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::UInt, - }) - } - } - CastOperatorType::U128 => { - if let Some(value) = field.to_u128() { - Ok(Field::U128(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::U128, - }) - } - } - CastOperatorType::Int => { - if let Some(value) = field.to_int() { - Ok(Field::Int(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Int, - }) - } - } - CastOperatorType::I128 => { - if let Some(value) = field.to_i128() { - Ok(Field::I128(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::I128, - }) - } - } - CastOperatorType::Float => { - if let Some(value) = field.to_float() { - Ok(Field::Float(OrderedFloat(value))) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Float, - }) - } - } - CastOperatorType::Boolean => { - if let Some(value) = field.to_boolean() { - Ok(Field::Boolean(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Boolean, - }) - } - } - CastOperatorType::String => { - if let Some(value) = field.to_string() { - Ok(Field::String(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::String, - }) - } - } - CastOperatorType::Text => { - if let Some(value) = field.to_text() { - Ok(Field::Text(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Text, - }) - } - } - CastOperatorType::Binary => { - if let Some(value) = field.to_binary() { - Ok(Field::Binary(value.to_vec())) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Binary, - }) - } - } - CastOperatorType::Decimal => { - if let Some(value) = field.to_decimal() { - Ok(Field::Decimal(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Decimal, - }) - } - } - CastOperatorType::Timestamp => { - if let Some(value) = field.to_timestamp()? { - Ok(Field::Timestamp(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Timestamp, - }) - } - } - CastOperatorType::Date => { - if let Some(value) = field.to_date()? 
{ - Ok(Field::Date(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Date, - }) - } - } - CastOperatorType::Json => { - if let Some(value) = field.to_json() { - Ok(Field::Json(value)) - } else { - Err(PipelineError::InvalidCast { - from: field, - to: FieldType::Json, - }) - } - } - } - } - - pub(crate) fn get_return_type( - &self, - schema: &Schema, - arg: &Expression, - ) -> Result { - let (expected_input_type, return_type) = match self { - CastOperatorType::UInt => ( - vec![ - FieldType::Int, - FieldType::String, - FieldType::UInt, - FieldType::I128, - FieldType::U128, - FieldType::Json, - ], - FieldType::UInt, - ), - CastOperatorType::U128 => ( - vec![ - FieldType::Int, - FieldType::String, - FieldType::UInt, - FieldType::I128, - FieldType::U128, - FieldType::Json, - ], - FieldType::U128, - ), - CastOperatorType::Int => ( - vec![ - FieldType::Int, - FieldType::String, - FieldType::UInt, - FieldType::I128, - FieldType::U128, - FieldType::Json, - ], - FieldType::Int, - ), - CastOperatorType::I128 => ( - vec![ - FieldType::Int, - FieldType::String, - FieldType::UInt, - FieldType::I128, - FieldType::U128, - FieldType::Json, - ], - FieldType::I128, - ), - CastOperatorType::Float => ( - vec![ - FieldType::Decimal, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::String, - FieldType::UInt, - FieldType::U128, - FieldType::Json, - ], - FieldType::Float, - ), - CastOperatorType::Boolean => ( - vec![ - FieldType::Boolean, - FieldType::Decimal, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::UInt, - FieldType::U128, - FieldType::Json, - ], - FieldType::Boolean, - ), - CastOperatorType::String => ( - vec![ - FieldType::Binary, - FieldType::Boolean, - FieldType::Date, - FieldType::Decimal, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::String, - FieldType::Text, - FieldType::Timestamp, - FieldType::UInt, - FieldType::U128, - FieldType::Json, - ], - FieldType::String, - ), - CastOperatorType::Text => ( - vec![ - FieldType::Binary, - FieldType::Boolean, - FieldType::Date, - FieldType::Decimal, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::String, - FieldType::Text, - FieldType::Timestamp, - FieldType::UInt, - FieldType::U128, - FieldType::Json, - ], - FieldType::Text, - ), - CastOperatorType::Binary => (vec![FieldType::Binary], FieldType::Binary), - CastOperatorType::Decimal => ( - vec![ - FieldType::Decimal, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::String, - FieldType::UInt, - FieldType::U128, - ], - FieldType::Decimal, - ), - CastOperatorType::Timestamp => ( - vec![FieldType::String, FieldType::Timestamp], - FieldType::Timestamp, - ), - CastOperatorType::Date => (vec![FieldType::Date, FieldType::String], FieldType::Date), - CastOperatorType::Json => ( - vec![ - FieldType::Boolean, - FieldType::Float, - FieldType::Int, - FieldType::I128, - FieldType::String, - FieldType::Text, - FieldType::UInt, - FieldType::U128, - FieldType::Json, - ], - FieldType::Json, - ), - }; - - let expression_type = validate_arg_type(arg, expected_input_type, schema, self, 0)?; - Ok(ExpressionType { - return_type, - nullable: expression_type.nullable, - source: expression_type.source, - is_primary_key: expression_type.is_primary_key, - }) - } -} - -pub(crate) fn validate_arg_type( - arg: &Expression, - expected: Vec, - schema: &Schema, - fct: &CastOperatorType, - idx: usize, -) -> Result { - let arg_t = arg.get_type(schema)?; - if !expected.contains(&arg_t.return_type) { - 
Err(PipelineError::InvalidFunctionArgumentType(
-            fct.to_string(),
-            arg_t.return_type,
-            FieldTypes::new(expected),
-            idx,
-        ))
-    } else {
-        Ok(arg_t)
-    }
-}
diff --git a/dozer-sql/src/pipeline/expression/conditional.rs b/dozer-sql/src/pipeline/expression/conditional.rs
deleted file mode 100644
index f435fb6866..0000000000
--- a/dozer-sql/src/pipeline/expression/conditional.rs
+++ /dev/null
@@ -1,94 +0,0 @@
-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::errors::PipelineError::{InvalidFunction, NotEnoughArguments};
-use crate::pipeline::expression::execution::{Expression, ExpressionType};
-use dozer_types::types::Record;
-use dozer_types::types::{Field, FieldType, Schema};
-use std::fmt::{Display, Formatter};
-
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
-pub enum ConditionalExpressionType {
-    Coalesce,
-    NullIf,
-}
-
-pub(crate) fn get_conditional_expr_type(
-    function: &ConditionalExpressionType,
-    args: &[Expression],
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    match function {
-        ConditionalExpressionType::Coalesce => validate_coalesce(args, schema),
-        ConditionalExpressionType::NullIf => todo!(),
-    }
-}
-
-impl ConditionalExpressionType {
-    pub(crate) fn new(name: &str) -> Result<ConditionalExpressionType, PipelineError> {
-        match name {
-            "coalesce" => Ok(ConditionalExpressionType::Coalesce),
-            "nullif" => Ok(ConditionalExpressionType::NullIf),
-            _ => Err(InvalidFunction(name.to_string())),
-        }
-    }
-
-    pub(crate) fn evaluate(
-        &self,
-        schema: &Schema,
-        args: &[Expression],
-        record: &Record,
-    ) -> Result<Field, PipelineError> {
-        match self {
-            ConditionalExpressionType::Coalesce => evaluate_coalesce(schema, args, record),
-            ConditionalExpressionType::NullIf => todo!(),
-        }
-    }
-}
-
-pub(crate) fn validate_coalesce(
-    args: &[Expression],
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    if args.is_empty() {
-        return Err(NotEnoughArguments(
-            ConditionalExpressionType::Coalesce.to_string(),
-        ));
-    }
-
-    let return_types = args
-        .iter()
-        .map(|expr| expr.get_type(schema).unwrap().return_type)
-        .collect::<Vec<FieldType>>();
-    let return_type = return_types[0];
-
-    Ok(ExpressionType::new(
-        return_type,
-        false,
-        dozer_types::types::SourceDefinition::Dynamic,
-        false,
-    ))
-}
-
-pub(crate) fn evaluate_coalesce(
-    schema: &Schema,
-    args: &[Expression],
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    // The COALESCE function returns the first of its arguments that is not null.
-    for expr in args {
-        let field = expr.evaluate(record, schema)?;
-        if field != Field::Null {
-            return Ok(field);
-        }
-    }
-    // Null is returned only if all arguments are null.
-    Ok(Field::Null)
-}
-
-impl Display for ConditionalExpressionType {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            ConditionalExpressionType::Coalesce => f.write_str("COALESCE"),
-            ConditionalExpressionType::NullIf => f.write_str("NULLIF"),
-        }
-    }
-}
diff --git a/dozer-sql/src/pipeline/expression/geo/distance.rs b/dozer-sql/src/pipeline/expression/geo/distance.rs
deleted file mode 100644
index b5c7e2b1a9..0000000000
--- a/dozer-sql/src/pipeline/expression/geo/distance.rs
+++ /dev/null
@@ -1,119 +0,0 @@
-use dozer_types::errors::types::TypeError::DistanceCalculationError;
-use std::str::FromStr;
-
-use crate::pipeline::errors::PipelineError::{
-    InvalidFunctionArgumentType, InvalidValue, NotEnoughArguments, TooManyArguments,
-};
-use crate::pipeline::errors::{FieldTypes, PipelineError};
-use crate::{arg_point, arg_str};
-use dozer_types::types::Record;
-use dozer_types::types::{Field, FieldType, Schema};
-
-use crate::pipeline::expression::execution::{Expression, ExpressionType};
-use crate::pipeline::expression::geo::common::GeoFunctionType;
-use dozer_types::geo::GeodesicDistance;
-use dozer_types::geo::HaversineDistance;
-use dozer_types::geo::VincentyDistance;
-
-use dozer_types::ordered_float::OrderedFloat;
-
-const EXPECTED_ARGS_TYPES: &[FieldType] = &[FieldType::Point, FieldType::Point, FieldType::String];
-
-pub enum Algorithm {
-    Geodesic,
-    Haversine,
-    Vincenty,
-}
-
-impl FromStr for Algorithm {
-    type Err = PipelineError;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "GEODESIC" => Ok(Algorithm::Geodesic),
-            "HAVERSINE" => Ok(Algorithm::Haversine),
-            "VINCENTY" => Ok(Algorithm::Vincenty),
-            &_ => Err(InvalidValue(s.to_string())),
-        }
-    }
-}
-
-const DEFAULT_ALGORITHM: Algorithm = Algorithm::Geodesic;
-
-pub(crate) fn validate_distance(
-    args: &[Expression],
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    let ret_type = FieldType::Float;
-    if args.len() < 2 {
-        return Err(NotEnoughArguments(GeoFunctionType::Distance.to_string()));
-    }
-
-    if args.len() > 3 {
-        return Err(TooManyArguments(GeoFunctionType::Distance.to_string()));
-    }
-
-    for (idx, exp) in args.iter().enumerate() {
-        let return_type = exp.get_type(schema)?.return_type;
-        let expected_arg_type_option = EXPECTED_ARGS_TYPES.get(idx);
-        if let Some(expected_arg_type) = expected_arg_type_option {
-            if &return_type != expected_arg_type {
-                return Err(InvalidFunctionArgumentType(
-                    GeoFunctionType::Distance.to_string(),
-                    return_type,
-                    FieldTypes::new(vec![*expected_arg_type]),
-                    idx,
-                ));
-            }
-        }
-    }
-
-    Ok(ExpressionType::new(
-        ret_type,
-        false,
-        dozer_types::types::SourceDefinition::Dynamic,
-        false,
-    ))
-}
-
-pub(crate) fn evaluate_distance(
-    schema: &Schema,
-    args: &[Expression],
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let f_from = args
-        .get(0)
-        .ok_or(InvalidValue(String::from("from")))?
-        .evaluate(record, schema)?;
-
-    let f_to = args
-        .get(1)
-        .ok_or(InvalidValue(String::from("to")))?
-        .evaluate(record, schema)?;
-
-    if f_from == Field::Null || f_to == Field::Null {
-        Ok(Field::Null)
-    } else {
-        let from = arg_point!(f_from, GeoFunctionType::Distance, 0)?;
-        let to = arg_point!(f_to, GeoFunctionType::Distance, 0)?;
-        let calculation_type = args.get(2).map_or_else(
-            || Ok(DEFAULT_ALGORITHM),
-            |arg| {
-                let f = arg.evaluate(record, schema)?;
-                let t = arg_str!(f, GeoFunctionType::Distance, 0)?;
-                Algorithm::from_str(&t)
-            },
-        )?;
-
-        let distance: OrderedFloat<f64> = match calculation_type {
-            Algorithm::Geodesic => Ok(from.geodesic_distance(to)),
-            Algorithm::Haversine => Ok(from.0.haversine_distance(&to.0)),
-            Algorithm::Vincenty => from
-                .0
-                .vincenty_distance(&to.0)
-                .map_err(DistanceCalculationError),
-        }?;
-
-        Ok(Field::Float(distance))
-    }
-}
diff --git a/dozer-sql/src/pipeline/expression/geo/point.rs b/dozer-sql/src/pipeline/expression/geo/point.rs
deleted file mode 100644
index 474af86c29..0000000000
--- a/dozer-sql/src/pipeline/expression/geo/point.rs
+++ /dev/null
@@ -1,70 +0,0 @@
-use crate::arg_float;
-use crate::pipeline::errors::PipelineError::{
-    InvalidArgument, InvalidFunctionArgumentType, NotEnoughArguments, TooManyArguments,
-};
-use crate::pipeline::errors::{FieldTypes, PipelineError};
-use dozer_types::types::Record;
-use dozer_types::types::{DozerPoint, Field, FieldType, Schema};
-
-use crate::pipeline::expression::execution::{Expression, ExpressionType};
-use crate::pipeline::expression::geo::common::GeoFunctionType;
-
-pub(crate) fn validate_point(
-    args: &[Expression],
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    let ret_type = FieldType::Point;
-    let expected_arg_type = FieldType::Float;
-
-    if args.len() < 2 {
-        return Err(NotEnoughArguments(GeoFunctionType::Point.to_string()));
-    }
-
-    if args.len() > 2 {
-        return Err(TooManyArguments(GeoFunctionType::Point.to_string()));
-    }
-
-    for (idx, exp) in args.iter().enumerate() {
-        let return_type = exp.get_type(schema)?.return_type;
-        if return_type != expected_arg_type {
-            return Err(InvalidFunctionArgumentType(
-                GeoFunctionType::Point.to_string(),
-                return_type,
-                FieldTypes::new(vec![expected_arg_type]),
-                idx,
-            ));
-        }
-    }
-
-    Ok(ExpressionType::new(
-        ret_type,
-        false,
-        dozer_types::types::SourceDefinition::Dynamic,
-        false,
-    ))
-}
-
-pub(crate) fn evaluate_point(
-    schema: &Schema,
-    args: &[Expression],
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let _res_type = FieldType::Point;
-    let f_x = args
-        .get(0)
-        .ok_or(InvalidArgument("x".to_string()))?
-        .evaluate(record, schema)?;
-    let f_y = args
-        .get(1)
-        .ok_or(InvalidArgument("y".to_string()))?
- .evaluate(record, schema)?; - - if f_x == Field::Null || f_y == Field::Null { - Ok(Field::Null) - } else { - let x = arg_float!(f_x, GeoFunctionType::Point, 0)?; - let y = arg_float!(f_y, GeoFunctionType::Point, 0)?; - - Ok(Field::Point(DozerPoint::from((x, y)))) - } -} diff --git a/dozer-sql/src/pipeline/expression/logical.rs b/dozer-sql/src/pipeline/expression/logical.rs deleted file mode 100644 index 8db409397a..0000000000 --- a/dozer-sql/src/pipeline/expression/logical.rs +++ /dev/null @@ -1,160 +0,0 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::execution::Expression; -use dozer_types::types::Record; -use dozer_types::types::{Field, Schema}; - -pub fn evaluate_and( - schema: &Schema, - left: &Expression, - right: &Expression, - record: &Record, -) -> Result { - let l_field = left.evaluate(record, schema)?; - let r_field = right.evaluate(record, schema)?; - match l_field { - Field::Boolean(true) => match r_field { - Field::Boolean(true) => Ok(Field::Boolean(true)), - Field::Boolean(false) => Ok(Field::Boolean(false)), - Field::Null => Ok(Field::Boolean(false)), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(r_field, "AND".to_string())), - }, - Field::Boolean(false) => match r_field { - Field::Boolean(true) => Ok(Field::Boolean(false)), - Field::Boolean(false) => Ok(Field::Boolean(false)), - Field::Null => Ok(Field::Boolean(false)), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(r_field, "AND".to_string())), - }, - Field::Null => Ok(Field::Boolean(false)), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(l_field, "AND".to_string())), - } -} - -pub fn evaluate_or( - schema: &Schema, - left: &Expression, - right: &Expression, - record: &Record, -) -> Result { - let l_field = left.evaluate(record, schema)?; - let r_field = right.evaluate(record, schema)?; - match l_field { - Field::Boolean(true) => match r_field { - Field::Boolean(false) => Ok(Field::Boolean(true)), - Field::Boolean(true) => Ok(Field::Boolean(true)), - Field::Null => Ok(Field::Boolean(true)), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(r_field, "OR".to_string())), - }, - Field::Boolean(false) | Field::Null => match right.evaluate(record, schema)? 
{ - Field::Boolean(false) => Ok(Field::Boolean(false)), - Field::Boolean(true) => Ok(Field::Boolean(true)), - Field::Null => Ok(Field::Boolean(false)), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(r_field, "OR".to_string())), - }, - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(l_field, "OR".to_string())), - } -} - -pub fn evaluate_not( - schema: &Schema, - value: &Expression, - record: &Record, -) -> Result { - let value_p = value.evaluate(record, schema)?; - - match value_p { - Field::Boolean(value_v) => Ok(Field::Boolean(!value_v)), - Field::Null => Ok(Field::Null), - Field::UInt(_) - | Field::U128(_) - | Field::Int(_) - | Field::I128(_) - | Field::Float(_) - | Field::String(_) - | Field::Text(_) - | Field::Binary(_) - | Field::Decimal(_) - | Field::Timestamp(_) - | Field::Date(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(PipelineError::InvalidType(value_p, "NOT".to_string())), - } -} diff --git a/dozer-sql/src/pipeline/expression/mod.rs b/dozer-sql/src/pipeline/expression/mod.rs deleted file mode 100644 index 4ffb862642..0000000000 --- a/dozer-sql/src/pipeline/expression/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub mod aggregate; -mod arg_utils; -pub mod builder; -pub mod case; -pub mod cast; -pub mod comparison; -pub mod conditional; -mod datetime; -pub mod execution; -pub mod geo; -pub mod in_list; -mod json_functions; -pub mod logical; -pub mod mathematical; -pub mod operator; -pub mod scalar; - -#[cfg(feature = "onnx")] -pub mod onnx; -#[cfg(feature = "python")] -pub mod python_udf; -#[cfg(test)] -mod tests; diff --git a/dozer-sql/src/pipeline/expression/onnx/mod.rs b/dozer-sql/src/pipeline/expression/onnx/mod.rs deleted file mode 100644 index 282fbac486..0000000000 --- a/dozer-sql/src/pipeline/expression/onnx/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[cfg(feature = "onnx")] -pub mod onnx_udf; -#[cfg(feature = "onnx")] -pub mod onnx_utils; diff --git a/dozer-sql/src/pipeline/expression/scalar/common.rs b/dozer-sql/src/pipeline/expression/scalar/common.rs deleted file mode 100644 index 53457116bc..0000000000 --- a/dozer-sql/src/pipeline/expression/scalar/common.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::argv; -use crate::pipeline::errors::PipelineError; -use crate::pipeline::expression::execution::{Expression, ExpressionType}; -use crate::pipeline::expression::scalar::number::{evaluate_abs, evaluate_round}; -use crate::pipeline::expression::scalar::string::{ - evaluate_concat, evaluate_length, evaluate_to_char, evaluate_ucase, validate_concat, - validate_ucase, -}; -use dozer_types::types::Record; -use dozer_types::types::{Field, FieldType, Schema}; -use std::fmt::{Display, Formatter}; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum ScalarFunctionType { - Abs, - Round, - Ucase, - Concat, - Length, - ToChar, -} - -impl Display for ScalarFunctionType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ScalarFunctionType::Abs => f.write_str("ABS"), - ScalarFunctionType::Round => 
f.write_str("ROUND"), - ScalarFunctionType::Ucase => f.write_str("UCASE"), - ScalarFunctionType::Concat => f.write_str("CONCAT"), - ScalarFunctionType::Length => f.write_str("LENGTH"), - ScalarFunctionType::ToChar => f.write_str("TO_CHAR"), - } - } -} - -pub(crate) fn get_scalar_function_type( - function: &ScalarFunctionType, - args: &[Expression], - schema: &Schema, -) -> Result { - match function { - ScalarFunctionType::Abs => argv!(args, 0, ScalarFunctionType::Abs)?.get_type(schema), - ScalarFunctionType::Round => { - let return_type = argv!(args, 0, ScalarFunctionType::Round)? - .get_type(schema)? - .return_type; - Ok(ExpressionType::new( - return_type, - true, - dozer_types::types::SourceDefinition::Dynamic, - false, - )) - } - ScalarFunctionType::Ucase => { - validate_ucase(argv!(args, 0, ScalarFunctionType::Ucase)?, schema) - } - ScalarFunctionType::Concat => validate_concat(args, schema), - ScalarFunctionType::Length => Ok(ExpressionType::new( - FieldType::UInt, - false, - dozer_types::types::SourceDefinition::Dynamic, - false, - )), - ScalarFunctionType::ToChar => argv!(args, 0, ScalarFunctionType::ToChar)?.get_type(schema), - } -} - -impl ScalarFunctionType { - pub fn new(name: &str) -> Result { - match name { - "abs" => Ok(ScalarFunctionType::Abs), - "round" => Ok(ScalarFunctionType::Round), - "ucase" => Ok(ScalarFunctionType::Ucase), - "concat" => Ok(ScalarFunctionType::Concat), - "length" => Ok(ScalarFunctionType::Length), - "to_char" => Ok(ScalarFunctionType::ToChar), - _ => Err(PipelineError::InvalidFunction(name.to_string())), - } - } - - pub(crate) fn evaluate( - &self, - schema: &Schema, - args: &[Expression], - record: &Record, - ) -> Result { - match self { - ScalarFunctionType::Abs => { - evaluate_abs(schema, argv!(args, 0, ScalarFunctionType::Abs)?, record) - } - ScalarFunctionType::Round => evaluate_round( - schema, - argv!(args, 0, ScalarFunctionType::Round)?, - args.get(1), - record, - ), - ScalarFunctionType::Ucase => { - evaluate_ucase(schema, argv!(args, 0, ScalarFunctionType::Ucase)?, record) - } - ScalarFunctionType::Concat => evaluate_concat(schema, args, record), - ScalarFunctionType::Length => { - evaluate_length(schema, argv!(args, 0, ScalarFunctionType::Length)?, record) - } - ScalarFunctionType::ToChar => evaluate_to_char( - schema, - argv!(args, 0, ScalarFunctionType::ToChar)?, - argv!(args, 1, ScalarFunctionType::ToChar)?, - record, - ), - } - } -} diff --git a/dozer-sql/src/pipeline/expression/scalar/number.rs b/dozer-sql/src/pipeline/expression/scalar/number.rs deleted file mode 100644 index 4fc511395b..0000000000 --- a/dozer-sql/src/pipeline/expression/scalar/number.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::pipeline::errors::PipelineError; -use crate::pipeline::errors::PipelineError::InvalidFunctionArgument; -use crate::pipeline::expression::execution::Expression; -use crate::pipeline::expression::scalar::common::ScalarFunctionType; -use dozer_types::ordered_float::OrderedFloat; -use dozer_types::types::Record; -use dozer_types::types::{Field, FieldType, Schema}; -use num_traits::{Float, ToPrimitive}; - -pub(crate) fn evaluate_abs( - schema: &Schema, - arg: &Expression, - record: &Record, -) -> Result { - let value = arg.evaluate(record, schema)?; - match value { - Field::UInt(u) => Ok(Field::UInt(u)), - Field::U128(u) => Ok(Field::U128(u)), - Field::Int(i) => Ok(Field::Int(i.abs())), - Field::I128(i) => Ok(Field::I128(i.abs())), - Field::Float(f) => Ok(Field::Float(f.abs())), - Field::Decimal(d) => Ok(Field::Decimal(d.abs())), - 
Field::Boolean(_) - | Field::String(_) - | Field::Text(_) - | Field::Date(_) - | Field::Timestamp(_) - | Field::Binary(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) - | Field::Null => Err(InvalidFunctionArgument( - ScalarFunctionType::Abs.to_string(), - value, - 0, - )), - } -} - -pub(crate) fn evaluate_round( - schema: &Schema, - arg: &Expression, - decimals: Option<&Expression>, - record: &Record, -) -> Result { - let value = arg.evaluate(record, schema)?; - let mut places = 0; - if let Some(expression) = decimals { - let field = expression.evaluate(record, schema)?; - match field { - Field::UInt(u) => places = u as i32, - Field::U128(u) => places = u as i32, - Field::Int(i) => places = i as i32, - Field::I128(i) => places = i as i32, - Field::Float(f) => places = f.round().0 as i32, - Field::Decimal(d) => { - places = d - .to_i32() - .ok_or(PipelineError::InvalidCast { - from: field, - to: FieldType::Decimal, - }) - .unwrap() - } - Field::Boolean(_) - | Field::String(_) - | Field::Text(_) - | Field::Date(_) - | Field::Timestamp(_) - | Field::Binary(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) - | Field::Null => {} // Truncate value to 0 decimals - } - } - let order = OrderedFloat(10.0_f64.powi(places)); - - match value { - Field::UInt(u) => Ok(Field::UInt(u)), - Field::U128(u) => Ok(Field::U128(u)), - Field::Int(i) => Ok(Field::Int(i)), - Field::I128(i) => Ok(Field::I128(i)), - Field::Float(f) => Ok(Field::Float((f * order).round() / order)), - Field::Decimal(d) => Ok(Field::Decimal(d.round_dp(places as u32))), - Field::Null => Ok(Field::Null), - Field::Boolean(_) - | Field::String(_) - | Field::Text(_) - | Field::Date(_) - | Field::Timestamp(_) - | Field::Binary(_) - | Field::Json(_) - | Field::Point(_) - | Field::Duration(_) => Err(InvalidFunctionArgument( - ScalarFunctionType::Round.to_string(), - value, - 0, - )), - } -} diff --git a/dozer-sql/src/pipeline/expression/scalar/string.rs b/dozer-sql/src/pipeline/expression/scalar/string.rs deleted file mode 100644 index abf700480e..0000000000 --- a/dozer-sql/src/pipeline/expression/scalar/string.rs +++ /dev/null @@ -1,288 +0,0 @@ -use crate::arg_str; -use std::fmt::Write; -use std::fmt::{Display, Formatter}; - -use crate::pipeline::errors::PipelineError; - -use crate::pipeline::expression::execution::{Expression, ExpressionType}; - -use crate::pipeline::expression::arg_utils::validate_arg_type; -use crate::pipeline::expression::scalar::common::ScalarFunctionType; - -use dozer_types::types::Record; -use dozer_types::types::{Field, FieldType, Schema}; -use like::{Escape, Like}; - -pub(crate) fn validate_ucase( - arg: &Expression, - schema: &Schema, -) -> Result { - validate_arg_type( - arg, - vec![FieldType::String, FieldType::Text], - schema, - ScalarFunctionType::Ucase, - 0, - ) -} - -pub(crate) fn evaluate_ucase( - schema: &Schema, - arg: &Expression, - record: &Record, -) -> Result { - let f = arg.evaluate(record, schema)?; - let v = arg_str!(f, ScalarFunctionType::Ucase, 0)?; - let ret = v.to_uppercase(); - - Ok(match arg.get_type(schema)?.return_type { - FieldType::String => Field::String(ret), - FieldType::UInt - | FieldType::U128 - | FieldType::Int - | FieldType::I128 - | FieldType::Float - | FieldType::Decimal - | FieldType::Boolean - | FieldType::Text - | FieldType::Date - | FieldType::Timestamp - | FieldType::Binary - | FieldType::Json - | FieldType::Point - | FieldType::Duration => Field::Text(ret), - }) -} - -pub(crate) fn validate_concat( - args: &[Expression], - schema: &Schema, -) -> 
Result<ExpressionType, PipelineError> {
-    let mut ret_type = FieldType::String;
-    for exp in args {
-        let r = validate_arg_type(
-            exp,
-            vec![FieldType::String, FieldType::Text],
-            schema,
-            ScalarFunctionType::Concat,
-            0,
-        )?;
-        if matches!(r.return_type, FieldType::Text) {
-            ret_type = FieldType::Text;
-        }
-    }
-    Ok(ExpressionType::new(
-        ret_type,
-        false,
-        dozer_types::types::SourceDefinition::Dynamic,
-        false,
-    ))
-}
-
-pub(crate) fn evaluate_concat(
-    schema: &Schema,
-    args: &[Expression],
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let mut res_type = FieldType::String;
-    let mut res_vec: Vec<String> = Vec::with_capacity(args.len());
-
-    for e in args {
-        if matches!(e.get_type(schema)?.return_type, FieldType::Text) {
-            res_type = FieldType::Text;
-        }
-        let f = e.evaluate(record, schema)?;
-        let val = arg_str!(f, ScalarFunctionType::Concat, 0)?;
-        res_vec.push(val);
-    }
-
-    let res_str = res_vec.iter().fold(String::new(), |a, b| a + b.as_str());
-    Ok(match res_type {
-        FieldType::Text => Field::Text(res_str),
-        FieldType::UInt
-        | FieldType::U128
-        | FieldType::Int
-        | FieldType::I128
-        | FieldType::Float
-        | FieldType::Decimal
-        | FieldType::Boolean
-        | FieldType::String
-        | FieldType::Date
-        | FieldType::Timestamp
-        | FieldType::Binary
-        | FieldType::Json
-        | FieldType::Point
-        | FieldType::Duration => Field::String(res_str),
-    })
-}
-
-pub(crate) fn evaluate_length(
-    schema: &Schema,
-    arg0: &Expression,
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let f0 = arg0.evaluate(record, schema)?;
-    let v0 = arg_str!(f0, ScalarFunctionType::Concat, 0)?;
-    Ok(Field::UInt(v0.len() as u64))
-}
-
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum TrimType {
-    Trailing,
-    Leading,
-    Both,
-}
-
-impl Display for TrimType {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            TrimType::Trailing => f.write_str("TRAILING "),
-            TrimType::Leading => f.write_str("LEADING "),
-            TrimType::Both => f.write_str("BOTH "),
-        }
-    }
-}
-
-pub(crate) fn validate_trim(
-    arg: &Expression,
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    validate_arg_type(
-        arg,
-        vec![FieldType::String, FieldType::Text],
-        schema,
-        ScalarFunctionType::Concat,
-        0,
-    )
-}
-
-pub(crate) fn evaluate_trim(
-    schema: &Schema,
-    arg: &Expression,
-    what: &Option<Box<Expression>>,
-    typ: &Option<TrimType>,
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let arg_field = arg.evaluate(record, schema)?;
-    let arg_value = arg_str!(arg_field, "TRIM", 0)?;
-
-    let v1: Vec<_> = match what {
-        Some(e) => {
-            let f = e.evaluate(record, schema)?;
-            arg_str!(f, "TRIM", 1)?.chars().collect()
-        }
-        _ => vec![' '],
-    };
-
-    let retval = match typ {
-        Some(TrimType::Both) => arg_value.trim_matches::<&[char]>(&v1).to_string(),
-        Some(TrimType::Leading) => arg_value.trim_start_matches::<&[char]>(&v1).to_string(),
-        Some(TrimType::Trailing) => arg_value.trim_end_matches::<&[char]>(&v1).to_string(),
-        None => arg_value.trim_matches::<&[char]>(&v1).to_string(),
-    };
-
-    Ok(match arg.get_type(schema)?.return_type {
-        FieldType::String => Field::String(retval),
-        FieldType::UInt
-        | FieldType::U128
-        | FieldType::Int
-        | FieldType::I128
-        | FieldType::Float
-        | FieldType::Decimal
-        | FieldType::Boolean
-        | FieldType::Text
-        | FieldType::Date
-        | FieldType::Timestamp
-        | FieldType::Binary
-        | FieldType::Json
-        | FieldType::Point
-        | FieldType::Duration => Field::Text(retval),
-    })
-}
-
-pub(crate) fn get_like_operator_type(
-    arg: &Expression,
-    pattern: &Expression,
-    schema: &Schema,
-) -> Result<ExpressionType, PipelineError> {
-    validate_arg_type(
-        pattern,
-        vec![FieldType::String, FieldType::Text],
-        schema,
-        ScalarFunctionType::Concat,
-        0,
-    )?;
-
-    validate_arg_type(
-        arg,
-        vec![FieldType::String, FieldType::Text],
-        schema,
-        ScalarFunctionType::Concat,
-        0,
-    )
-}
-
-pub(crate) fn evaluate_like(
-    schema: &Schema,
-    arg: &Expression,
-    pattern: &Expression,
-    escape: Option<char>,
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let arg_field = arg.evaluate(record, schema)?;
-    let arg_value = arg_str!(arg_field, "LIKE", 0)?;
-    let arg_string = arg_value.as_str();
-
-    let pattern_field = pattern.evaluate(record, schema)?;
-    let pattern_value = arg_str!(pattern_field, "LIKE", 1)?;
-    let pattern_string = pattern_value.as_str();
-
-    if let Some(escape_char) = escape {
-        let arg_escape = &arg_string
-            .escape(&escape_char.to_string())
-            .map_err(|e| PipelineError::InvalidArgument(e.to_string()))?;
-        let result = Like::<false>::like(arg_escape.as_str(), pattern_string)
-            .map(Field::Boolean)
-            .map_err(|e| PipelineError::InvalidArgument(e.to_string()))?;
-        return Ok(result);
-    }
-
-    let result = Like::<false>::like(arg_string, pattern_string)
-        .map(Field::Boolean)
-        .map_err(|e| PipelineError::InvalidArgument(e.to_string()))?;
-    Ok(result)
-}
-
-pub(crate) fn evaluate_to_char(
-    schema: &Schema,
-    arg: &Expression,
-    pattern: &Expression,
-    record: &Record,
-) -> Result<Field, PipelineError> {
-    let arg_field = arg.evaluate(record, schema)?;
-
-    let pattern_field = pattern.evaluate(record, schema)?;
-    let pattern_value = arg_str!(pattern_field, "TO_CHAR", 0)?;
-
-    let output = match arg_field {
-        Field::Timestamp(value) => value.format(pattern_value.as_str()).to_string(),
-        Field::Date(value) => {
-            let mut formatted = String::new();
-            let format_result = write!(formatted, "{}", value.format(pattern_value.as_str()));
-            if format_result.is_ok() {
-                formatted
-            } else {
-                pattern_value
-            }
-        }
-        Field::Null => return Ok(Field::Null),
-        _ => {
-            return Err(PipelineError::InvalidArgument(format!(
-                "TO_CHAR({}, ...)",
-                arg_field
-            )))
-        }
-    };
-
-    Ok(Field::String(output))
-}
diff --git a/dozer-sql/src/pipeline/expression/tests/conditional.rs b/dozer-sql/src/pipeline/expression/tests/conditional.rs
deleted file mode 100644
index a693f87873..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/conditional.rs
+++ /dev/null
@@ -1,328 +0,0 @@
-use crate::pipeline::expression::conditional::*;
-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::expression::tests::test_common::*;
-use dozer_types::ordered_float::OrderedFloat;
-use dozer_types::types::Record;
-use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition};
-use proptest::prelude::*;
-
-#[test]
-fn test_coalesce() {
-    proptest!(ProptestConfig::with_cases(1000), move |(
-        u_num1: u64, u_num2: u64, i_num1: i64, i_num2: i64, f_num1: f64, f_num2: f64,
-        d_num1: ArbitraryDecimal, d_num2: ArbitraryDecimal,
-        s_val1: String, s_val2: String,
-        dt_val1: ArbitraryDateTime, dt_val2: ArbitraryDateTime)| {
-        let uint1 = Expression::Literal(Field::UInt(u_num1));
-        let uint2 = Expression::Literal(Field::UInt(u_num2));
-        let int1 = Expression::Literal(Field::Int(i_num1));
-        let int2 = Expression::Literal(Field::Int(i_num2));
-        let float1 = Expression::Literal(Field::Float(OrderedFloat(f_num1)));
-        let float2 = Expression::Literal(Field::Float(OrderedFloat(f_num2)));
-        let dec1 = Expression::Literal(Field::Decimal(d_num1.0));
-        let dec2 = Expression::Literal(Field::Decimal(d_num2.0));
-        let str1 = Expression::Literal(Field::String(s_val1.clone()));
-        let str2 = Expression::Literal(Field::String(s_val2));
-        let t1 = Expression::Literal(Field::Timestamp(dt_val1.0));
-        let t2 =
Expression::Literal(Field::Timestamp(dt_val1.0)); - let dt1 = Expression::Literal(Field::Date(dt_val1.0.date_naive())); - let dt2 = Expression::Literal(Field::Date(dt_val2.0.date_naive())); - let null = Expression::Column{ index: 0usize }; - - // UInt - let typ = FieldType::UInt; - let f = Field::UInt(u_num1); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), uint1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), uint1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), uint1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), uint1, uint2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Int - let typ = FieldType::Int; - let f = Field::Int(i_num1); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), int1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), int1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), int1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), int1, int2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Float - let typ = FieldType::Float; - let f = Field::Float(OrderedFloat(f_num1)); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), float1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), float1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), float1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), float1, float2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Decimal - let typ = FieldType::Decimal; - let f = Field::Decimal(d_num1.0); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), dec1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dec1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dec1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dec1, dec2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // String - let typ = FieldType::String; - let f = Field::String(s_val1.clone()); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), str1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1.clone()]; - 
test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1.clone(), str2.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // String - let typ = FieldType::String; - let f = Field::String(s_val1); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), str1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), str1, str2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Timestamp - let typ = FieldType::Timestamp; - let f = Field::Timestamp(dt_val1.0); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), t1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), t1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), t1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), t1, t2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Date - let typ = FieldType::Date; - let f = Field::Date(dt_val1.0.date_naive()); - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone(), dt1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dt1.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dt1.clone(), null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null.clone(), dt1, dt2]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - - // Null - let typ = FieldType::Date; - let f = Field::Null; - let row = Record::new(vec![f.clone()]); - - let args = vec![null.clone()]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f.clone()); - - let args = vec![null.clone(), null]; - test_validate_coalesce(&args, typ); - test_evaluate_coalesce(&args, &row, typ, f); - }); -} - -fn test_validate_coalesce(args: &[Expression], typ: FieldType) { - let schema = Schema::default() - .field( - FieldDefinition::new(String::from("field"), typ, false, SourceDefinition::Dynamic), - false, - ) - .clone(); - - let result = validate_coalesce(args, &schema).unwrap().return_type; - assert_eq!(result, typ); -} - -fn test_evaluate_coalesce(args: &[Expression], row: &Record, typ: FieldType, _result: Field) { - let schema = Schema::default() - .field( - FieldDefinition::new(String::from("field"), typ, false, SourceDefinition::Dynamic), - 
diff --git a/dozer-sql/src/pipeline/expression/tests/datetime.rs b/dozer-sql/src/pipeline/expression/tests/datetime.rs
deleted file mode 100644
index a63c9d4961..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/datetime.rs
+++ /dev/null
@@ -1,325 +0,0 @@
-use crate::pipeline::expression::datetime::evaluate_date_part;
-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::expression::mathematical::{
-    evaluate_add, evaluate_div, evaluate_mod, evaluate_mul, evaluate_sub,
-};
-use crate::pipeline::expression::tests::test_common::*;
-use dozer_types::chrono;
-use dozer_types::chrono::{DateTime, Datelike, NaiveDate};
-use dozer_types::types::Record;
-use dozer_types::types::{
-    DozerDuration, Field, FieldDefinition, FieldType, Schema, SourceDefinition, TimeUnit,
-};
-use num_traits::ToPrimitive;
-use proptest::prelude::*;
-use sqlparser::ast::DateTimeField;
-
-#[test]
-fn test_time() {
-    proptest!(
-        ProptestConfig::with_cases(1000),
-        move |(datetime: ArbitraryDateTime)| {
-            test_date_parts(datetime)
-        });
-}
-
-fn test_date_parts(datetime: ArbitraryDateTime) {
-    let row = Record::new(vec![]);
-
-    let date_parts = vec![
-        (
-            DateTimeField::Dow,
-            datetime
-                .0
-                .weekday()
-                .num_days_from_monday()
-                .to_i64()
-                .unwrap(),
-        ),
-        (DateTimeField::Year, datetime.0.year().to_i64().unwrap()),
-        (DateTimeField::Month, datetime.0.month().to_i64().unwrap()),
-        (DateTimeField::Hour, 0),
-        (DateTimeField::Second, 0),
-        (
-            DateTimeField::Quarter,
-            datetime.0.month0().to_i64().map(|m| m / 3 + 1).unwrap(),
-        ),
-    ];
-
-    let v = Expression::Literal(Field::Date(datetime.0.date_naive()));
-
-    for (part, value) in date_parts {
-        let result = evaluate_date_part(&Schema::default(), &part, &v, &row).unwrap();
-        assert_eq!(result, Field::Int(value));
-    }
-}
-
-#[test]
-fn test_extract_date() {
-    let date_fns: Vec<(&str, i64, i64)> = vec![
-        ("dow", 6, 0),
-        ("day", 1, 2),
-        ("month", 1, 1),
-        ("year", 2023, 2023),
-        ("hour", 0, 0),
-        ("minute", 0, 12),
-        ("second", 0, 10),
-        ("millisecond", 1672531200000, 1672618330000),
-        ("microsecond", 1672531200000000, 1672618330000000),
-        ("nanoseconds", 1672531200000000000, 1672618330000000000),
-        ("quarter", 1, 1),
-        ("epoch", 1672531200, 1672618330),
-        ("week", 52, 1),
-        ("century", 21, 21),
-        ("decade", 203, 203),
-        ("doy", 1, 2),
-    ];
-    let inputs = vec![
-        Field::Date(NaiveDate::from_ymd_opt(2023, 1, 1).unwrap()),
-        Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap()),
-    ];
-
-    for (part, val1, val2) in date_fns {
-        let mut results = vec![];
-        for i in inputs.clone() {
-            let f = run_fct(
-                &format!("select extract({part} from date) from users"),
-                Schema::default()
-                    .field(
-                        FieldDefinition::new(
-                            String::from("date"),
-                            FieldType::Date,
-                            false,
-                            SourceDefinition::Dynamic,
-                        ),
-                        false,
-                    )
-                    .clone(),
-                vec![i.clone()],
-            );
-            results.push(f.to_int().unwrap());
-        }
-        assert_eq!(val1, results[0]);
-        assert_eq!(val2, results[1]);
-    }
-}
-
-#[test]
-fn test_timestamp_diff() {
-    let f = run_fct(
-        "SELECT ts1 - ts2 FROM users",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("ts1"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .field(
-                FieldDefinition::new(
-                    String::from("ts2"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![
-            Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap()),
-            Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap()),
-        ],
-    );
-    assert_eq!(
-        f,
-        Field::Duration(DozerDuration(
-            std::time::Duration::from_secs(1),
-            TimeUnit::Nanoseconds
-        ))
-    );
-}
-
-#[test]
-fn test_duration() {
-    proptest!(
-        ProptestConfig::with_cases(1000),
-        move |(d1: u64, d2: u64, dt1: ArbitraryDateTime)| {
-            test_duration_math(d1, d2, dt1)
-        });
-}
-
-fn test_duration_math(d1: u64, d2: u64, dt1: ArbitraryDateTime) {
-    let row = Record::new(vec![]);
-
-    let v = Expression::Literal(Field::Date(dt1.0.date_naive()));
-    let dur1 = Expression::Literal(Field::Duration(DozerDuration(
-        std::time::Duration::from_nanos(d1),
-        TimeUnit::Nanoseconds,
-    )));
-    let dur2 = Expression::Literal(Field::Duration(DozerDuration(
-        std::time::Duration::from_nanos(d2),
-        TimeUnit::Nanoseconds,
-    )));
-
-    // Duration + Duration = Duration
-    let result = evaluate_add(&Schema::default(), &dur1, &dur2, &row);
-    let sum = std::time::Duration::from_nanos(d1).checked_add(std::time::Duration::from_nanos(d2));
-    if result.is_ok() && sum.is_some() {
-        assert_eq!(
-            result.unwrap(),
-            Field::Duration(DozerDuration(sum.unwrap(), TimeUnit::Nanoseconds))
-        );
-    }
-    // Duration - Duration = Duration
-    let result = evaluate_sub(&Schema::default(), &dur1, &dur2, &row);
-    let diff = std::time::Duration::from_nanos(d1).checked_sub(std::time::Duration::from_nanos(d2));
-    if result.is_ok() && diff.is_some() {
-        assert_eq!(
-            result.unwrap(),
-            Field::Duration(DozerDuration(diff.unwrap(), TimeUnit::Nanoseconds))
-        );
-    }
-    // Duration * Duration = Error
-    let result = evaluate_mul(&Schema::default(), &dur1, &dur2, &row);
-    assert!(result.is_err());
-    // Duration / Duration = Error
-    let result = evaluate_div(&Schema::default(), &dur1, &dur2, &row);
-    assert!(result.is_err());
-    // Duration % Duration = Error
-    let result = evaluate_mod(&Schema::default(), &dur1, &dur2, &row);
-    assert!(result.is_err());
-
-    // Duration + Timestamp = Error
-    let result = evaluate_add(&Schema::default(), &dur1, &v, &row);
-    assert!(result.is_err());
-    // Duration - Timestamp = Error
-    let result = evaluate_sub(&Schema::default(), &dur1, &v, &row);
-    assert!(result.is_err());
-    // Duration * Timestamp = Error
-    let result = evaluate_mul(&Schema::default(), &dur1, &v, &row);
-    assert!(result.is_err());
-    // Duration / Timestamp = Error
-    let result = evaluate_div(&Schema::default(), &dur1, &v, &row);
-    assert!(result.is_err());
-    // Duration % Timestamp = Error
-    let result = evaluate_mod(&Schema::default(), &dur1, &v, &row);
-    assert!(result.is_err());
-
-    // Timestamp + Duration = Timestamp
-    let result = evaluate_add(&Schema::default(), &v, &dur1, &row);
-    let sum = dt1
-        .0
-        .checked_add_signed(chrono::Duration::nanoseconds(d1 as i64));
-    if result.is_ok() && sum.is_some() {
-        assert_eq!(result.unwrap(), Field::Timestamp(sum.unwrap()));
-    }
-    // Timestamp - Duration = Timestamp
-    let result = evaluate_sub(&Schema::default(), &v, &dur2, &row);
-    let diff = dt1
-        .0
-        .checked_sub_signed(chrono::Duration::nanoseconds(d2 as i64));
-    if result.is_ok() && diff.is_some() {
-        assert_eq!(result.unwrap(), Field::Timestamp(diff.unwrap()));
-    }
-    // Timestamp * Duration = Error
-    let result = evaluate_mul(&Schema::default(), &v, &dur1, &row);
-    assert!(result.is_err());
-    // Timestamp / Duration = Error
-    let result = evaluate_div(&Schema::default(), &v, &dur1, &row);
-    assert!(result.is_err());
-    // Timestamp % Duration = Error
-    let result = evaluate_mod(&Schema::default(), &v, &dur1, &row);
-    assert!(result.is_err());
-}
-
-#[test]
-fn test_interval() {
-    let f = run_fct(
-        "SELECT ts1 - INTERVAL '1' SECOND FROM users",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("ts1"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![Field::Timestamp(
-            DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(),
-        )],
-    );
-    assert_eq!(
-        f,
-        Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:10Z").unwrap())
-    );
-
-    let f = run_fct(
-        "SELECT ts1 + INTERVAL '1' SECOND FROM users",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("ts1"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![Field::Timestamp(
-            DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(),
-        )],
-    );
-    assert_eq!(
-        f,
-        Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:12Z").unwrap())
-    );
-
-    let f = run_fct(
-        "SELECT INTERVAL '1' SECOND + ts1 FROM users",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("ts1"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![Field::Timestamp(
-            DateTime::parse_from_rfc3339("2023-01-02T00:12:11Z").unwrap(),
-        )],
-    );
-    assert_eq!(
-        f,
-        Field::Timestamp(DateTime::parse_from_rfc3339("2023-01-02T00:12:12Z").unwrap())
-    );
-}
-
-#[test]
-fn test_now() {
-    let f = run_fct(
-        "SELECT NOW() FROM users",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("ts1"),
-                    FieldType::Timestamp,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![],
-    );
-    assert!(f.to_timestamp().is_ok())
-}
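The duration tests removed here encode an arithmetic table: Duration±Duration and Timestamp±Duration are defined, every other operator/operand pairing is an error, and overflow is surfaced through checked operations rather than panics. A sketch of the addition half of that table (hypothetical types; nanosecond timestamps are an assumption for compactness):

    use std::time::Duration;

    enum Operand {
        Ts(i64),       // timestamp as nanoseconds since the epoch
        Dur(Duration), // elapsed time
    }

    // Mirrors the deleted assertions: Dur + Dur and Ts + Dur succeed
    // (checked), while Dur + Ts is rejected, matching the asymmetry
    // the tests pin down.
    fn add(l: Operand, r: Operand) -> Result<Operand, &'static str> {
        match (l, r) {
            (Operand::Dur(a), Operand::Dur(b)) => a
                .checked_add(b)
                .map(Operand::Dur)
                .ok_or("duration overflow"),
            (Operand::Ts(t), Operand::Dur(d)) => t
                .checked_add(d.as_nanos() as i64)
                .map(Operand::Ts)
                .ok_or("timestamp overflow"),
            _ => Err("invalid operand types for +"),
        }
    }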
diff --git a/dozer-sql/src/pipeline/expression/tests/distance.rs b/dozer-sql/src/pipeline/expression/tests/distance.rs
deleted file mode 100644
index a226c3db8e..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/distance.rs
+++ /dev/null
@@ -1,246 +0,0 @@
-use crate::arg_point;
-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::errors::PipelineError::{
-    InvalidFunctionArgumentType, NotEnoughArguments, TooManyArguments,
-};
-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::expression::execution::Expression::Literal;
-use crate::pipeline::expression::geo::common::GeoFunctionType;
-use crate::pipeline::expression::geo::distance::{evaluate_distance, validate_distance, Algorithm};
-use crate::pipeline::expression::tests::test_common::*;
-use dozer_types::geo::{GeodesicDistance, HaversineDistance};
-use dozer_types::ordered_float::OrderedFloat;
-use dozer_types::types::Record;
-use dozer_types::types::{DozerPoint, Field, FieldDefinition, FieldType, Schema, SourceDefinition};
-use proptest::prelude::*;
-
-#[test]
-fn test_geo() {
-    proptest!(ProptestConfig::with_cases(1000), move |(x1: f64, x2: f64, y1: f64, y2: f64)| {
-        let row = Record::new(vec![]);
-        let from = Field::Point(DozerPoint::from((x1, y1)));
-        let to = Field::Point(DozerPoint::from((x2, y2)));
-        let null = Field::Null;
-
-        test_distance(&from, &to, None, &row, None);
-        test_distance(&from, &null, None, &row, Some(Ok(Field::Null)));
-        test_distance(&null, &to, None, &row, Some(Ok(Field::Null)));
-
-        test_distance(&from, &to, Some(Algorithm::Geodesic), &row, None);
-        test_distance(&from, &null, Some(Algorithm::Geodesic), &row, Some(Ok(Field::Null)));
-        test_distance(&null, &to, Some(Algorithm::Geodesic), &row, Some(Ok(Field::Null)));
-
-        test_distance(&from, &to, Some(Algorithm::Haversine), &row, None);
-        test_distance(&from, &null, Some(Algorithm::Haversine), &row, Some(Ok(Field::Null)));
-        test_distance(&null, &to, Some(Algorithm::Haversine), &row, Some(Ok(Field::Null)));
-
-        // test_distance(&from, &to, Some(Algorithm::Vincenty), &row, None);
-        // test_distance(&from, &null, Some(Algorithm::Vincenty), &row, Some(Ok(Field::Null)));
-        // test_distance(&null, &to, Some(Algorithm::Vincenty), &row, Some(Ok(Field::Null)));
-    });
-}
-
-fn test_distance(
-    from: &Field,
-    to: &Field,
-    typ: Option<Algorithm>,
-    row: &Record,
-    result: Option<Result<Field, PipelineError>>,
-) {
-    let args = &vec![Literal(from.clone()), Literal(to.clone())];
-    if validate_distance(args, &Schema::default()).is_ok() {
-        match result {
-            None => {
-                let from_f = from.to_owned();
-                let to_f = to.to_owned();
-                let f = arg_point!(from_f, GeoFunctionType::Distance, 0).unwrap();
-                let t = arg_point!(to_f, GeoFunctionType::Distance, 0).unwrap();
-                let _dist = match typ {
-                    None => f.geodesic_distance(t),
-                    Some(Algorithm::Geodesic) => f.geodesic_distance(t),
-                    Some(Algorithm::Haversine) => f.0.haversine_distance(&t.0),
-                    Some(Algorithm::Vincenty) => OrderedFloat(0.0),
-                    // Some(Algorithm::Vincenty) => f.0.vincenty_distance(&t.0).unwrap(),
-                };
-                assert!(matches!(
-                    evaluate_distance(&Schema::default(), args, row),
-                    Ok(Field::Float(_dist)),
-                ))
-            }
-            Some(_val) => {
-                assert!(matches!(
-                    evaluate_distance(&Schema::default(), args, row),
-                    _val,
-                ))
-            }
-        }
-    }
-}
-
-#[test]
-fn test_validate_distance() {
-    let schema = Schema::default()
-        .field(
-            FieldDefinition::new(
-                String::from("from"),
-                FieldType::Point,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .field(
-            FieldDefinition::new(
-                String::from("to"),
-                FieldType::Point,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .clone();
-    let _fn_type = String::from("DISTANCE");
-
-    let result = validate_distance(&[], &schema);
-    assert!(result.is_err());
-    assert!(matches!(result, Err(NotEnoughArguments(_fn_type))));
-
-    let result = validate_distance(&[Expression::Column { index: 0 }], &schema);
-
-    assert!(result.is_err());
-    assert!(matches!(result, Err(NotEnoughArguments(_fn_type))));
-
-    let result = validate_distance(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Column { index: 1 },
-        ],
-        &schema,
-    );
-
-    assert!(result.is_ok());
-
-    let result = validate_distance(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Column { index: 1 },
-            Expression::Literal(Field::String("GEODESIC".to_string())),
-        ],
-        &schema,
-    );
-
-    assert!(result.is_ok());
-
-    let result = validate_distance(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Column { index: 1 },
-            Expression::Literal(Field::String("GEODESIC".to_string())),
-            Expression::Column { index: 2 },
-        ],
-        &schema,
-    );
-
-    assert!(result.is_err());
-    assert!(matches!(result, Err(TooManyArguments(_fn_type))));
-
-    let result = validate_distance(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Literal(Field::String("GEODESIC".to_string())),
-            Expression::Column { index: 2 },
-        ],
-        &schema,
-    );
-
-    let _expected_types = [FieldType::Point];
-    assert!(result.is_err());
-    assert!(matches!(
-        result,
-        Err(InvalidFunctionArgumentType(
-            _fn_type,
-            FieldType::String,
-            _expected_types,
-            1
-        ))
-    ));
-}
-
-#[test]
-fn test_distance_logical() {
-    let tests = vec![
-        ("", 1113.0264976969),
-        ("GEODESIC", 1113.0264976969),
-        ("HAVERSINE", 1111.7814468418496),
-        ("VINCENTY", 1113.0264975564357),
-    ];
-
-    let schema = Schema::default()
-        .field(
-            FieldDefinition::new(
-                String::from("from"),
-                FieldType::Point,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .field(
-            FieldDefinition::new(
-                String::from("to"),
-                FieldType::Point,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .clone();
-
-    let input = vec![
-        Field::Point(DozerPoint::from((1.0, 1.0))),
-        Field::Point(DozerPoint::from((1.01, 1.0))),
-    ];
-
-    for (calculation_type, expected_result) in tests {
-        let sql = if calculation_type.is_empty() {
-            "SELECT DISTANCE(from, to) FROM LOCATIONS".to_string()
-        } else {
-            format!("SELECT DISTANCE(from, to, '{calculation_type}') FROM LOCATIONS")
-        };
-        if let Field::Float(OrderedFloat(result)) = run_fct(&sql, schema.clone(), input.clone()) {
-            assert!((result - expected_result) < 0.000000001);
-        } else {
-            panic!("Expected float");
-        }
-    }
-}
-
-#[test]
-fn test_distance_with_nullable_parameter() {
-    let f = run_fct(
-        "SELECT DISTANCE(from, to) FROM LOCATION",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("from"),
-                    FieldType::Point,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .field(
-                FieldDefinition::new(
-                    String::from("to"),
-                    FieldType::Point,
-                    true,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![Field::Point(DozerPoint::from((0.0, 1.0))), Field::Null],
-    );
-
-    assert_eq!(f, Field::Null);
-}
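The deleted distance tests compare three algorithms (geodesic, haversine, Vincenty); the haversine variant is simple enough to sketch from first principles, as a standalone function rather than the crate's implementation:

    /// Great-circle distance in metres between two (lat, lon) points in
    /// degrees, via the haversine formula with a mean Earth radius of
    /// 6371 km, which is why it differs from the geodesic result in the
    /// expectations above (1111.78 m vs 1113.03 m).
    fn haversine_m(lat1: f64, lon1: f64, lat2: f64, lon2: f64) -> f64 {
        const R: f64 = 6_371_000.0;
        let (phi1, phi2) = (lat1.to_radians(), lat2.to_radians());
        let dphi = (lat2 - lat1).to_radians();
        let dlambda = (lon2 - lon1).to_radians();
        let a = (dphi / 2.0).sin().powi(2)
            + phi1.cos() * phi2.cos() * (dlambda / 2.0).sin().powi(2);
        2.0 * R * a.sqrt().atan2((1.0 - a).sqrt())
    }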
diff --git a/dozer-sql/src/pipeline/expression/tests/logical.rs b/dozer-sql/src/pipeline/expression/tests/logical.rs
deleted file mode 100644
index a813ee5d2e..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/logical.rs
+++ /dev/null
@@ -1,130 +0,0 @@
-use crate::pipeline::expression::execution::Expression::Literal;
-use crate::pipeline::expression::logical::{evaluate_and, evaluate_not, evaluate_or};
-use dozer_types::types::Record;
-use dozer_types::types::{Field, Schema};
-use dozer_types::{ordered_float::OrderedFloat, rust_decimal::Decimal};
-#[cfg(test)]
-use proptest::prelude::*;
-
-#[test]
-fn test_logical() {
-    proptest!(
-        ProptestConfig::with_cases(1000),
-        move |(bool1: bool, bool2: bool, u_num: u64, i_num: i64, f_num: f64, str in ".*")| {
-            _test_bool_bool_and(bool1, bool2);
-            _test_bool_null_and(Field::Boolean(bool1), Field::Null);
-            _test_bool_null_and(Field::Null, Field::Boolean(bool1));
-
-            _test_bool_bool_or(bool1, bool2);
-            _test_bool_null_or(bool1);
-            _test_null_bool_or(bool2);
-
-            _test_bool_not(bool2);
-
-            _test_bool_non_bool_and(Field::UInt(u_num), Field::Boolean(bool1));
-            _test_bool_non_bool_and(Field::Int(i_num), Field::Boolean(bool1));
-            _test_bool_non_bool_and(Field::Float(OrderedFloat(f_num)), Field::Boolean(bool1));
-            _test_bool_non_bool_and(Field::Decimal(Decimal::from(u_num)), Field::Boolean(bool1));
-            _test_bool_non_bool_and(Field::String(str.clone()), Field::Boolean(bool1));
-            _test_bool_non_bool_and(Field::Text(str.clone()), Field::Boolean(bool1));
-
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::UInt(u_num));
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::Int(i_num));
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::Float(OrderedFloat(f_num)));
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::Decimal(Decimal::from(u_num)));
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::String(str.clone()));
-            _test_bool_non_bool_and(Field::Boolean(bool2), Field::Text(str.clone()));
-
-            _test_bool_non_bool_or(Field::UInt(u_num), Field::Boolean(bool1));
-            _test_bool_non_bool_or(Field::Int(i_num), Field::Boolean(bool1));
-            _test_bool_non_bool_or(Field::Float(OrderedFloat(f_num)), Field::Boolean(bool1));
-            _test_bool_non_bool_or(Field::Decimal(Decimal::from(u_num)), Field::Boolean(bool1));
-            _test_bool_non_bool_or(Field::String(str.clone()), Field::Boolean(bool1));
-            _test_bool_non_bool_or(Field::Text(str.clone()), Field::Boolean(bool1));
-
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::UInt(u_num));
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::Int(i_num));
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::Float(OrderedFloat(f_num)));
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::Decimal(Decimal::from(u_num)));
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::String(str.clone()));
-            _test_bool_non_bool_or(Field::Boolean(bool2), Field::Text(str));
-        });
-}
-
-fn _test_bool_bool_and(bool1: bool, bool2: bool) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(Field::Boolean(bool1)));
-    let r = Box::new(Literal(Field::Boolean(bool2)));
-    assert!(matches!(
-        evaluate_and(&Schema::default(), &l, &r, &row)
-            .unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(_ans)
-    ));
-}
-
-fn _test_bool_null_and(f1: Field, f2: Field) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(f1));
-    let r = Box::new(Literal(f2));
-    assert!(matches!(
-        evaluate_and(&Schema::default(), &l, &r, &row)
-            .unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(false)
-    ));
-}
-
-fn _test_bool_bool_or(bool1: bool, bool2: bool) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(Field::Boolean(bool1)));
-    let r = Box::new(Literal(Field::Boolean(bool2)));
-    assert!(matches!(
-        evaluate_or(&Schema::default(), &l, &r, &row)
-            .unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(_ans)
-    ));
-}
-
-fn _test_bool_null_or(_bool: bool) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(Field::Boolean(_bool)));
-    let r = Box::new(Literal(Field::Null));
-    assert!(matches!(
-        evaluate_or(&Schema::default(), &l, &r, &row)
-            .unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(_bool)
-    ));
-}
-
-fn _test_null_bool_or(_bool: bool) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(Field::Null));
-    let r = Box::new(Literal(Field::Boolean(_bool)));
-    assert!(matches!(
-        evaluate_or(&Schema::default(), &l, &r, &row)
-            .unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(_bool)
-    ));
-}
-
-fn _test_bool_not(bool: bool) {
-    let row = Record::new(vec![]);
-    let v = Box::new(Literal(Field::Boolean(bool)));
-    assert!(matches!(
-        evaluate_not(&Schema::default(), &v, &row).unwrap_or_else(|e| panic!("{}", e.to_string())),
-        Field::Boolean(_ans)
-    ));
-}
-
-fn _test_bool_non_bool_and(f1: Field, f2: Field) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(f1));
-    let r = Box::new(Literal(f2));
-    assert!(evaluate_and(&Schema::default(), &l, &r, &row).is_err());
-}
-
-fn _test_bool_non_bool_or(f1: Field, f2: Field) {
-    let row = Record::new(vec![]);
-    let l = Box::new(Literal(f1));
-    let r = Box::new(Literal(f2));
-    assert!(evaluate_or(&Schema::default(), &l, &r, &row).is_err());
-}
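These tests document strict boolean semantics: AND and OR are only defined over booleans (any numeric, decimal, or string operand is a type error), and a null operand in AND collapses to false rather than SQL's three-valued unknown. A compact restatement of the AND rule (sketch only; the error type and argument shapes are illustrative, not the crate's API):

    use dozer_types::types::Field;

    // Strict AND with null-as-false, per the deleted assertions; what
    // NULL AND NULL does is not pinned down by the tests, so this
    // sketch treats it as an error like any other non-boolean pairing.
    fn eval_and(l: &Field, r: &Field) -> Result<Field, String> {
        match (l, r) {
            (Field::Boolean(a), Field::Boolean(b)) => Ok(Field::Boolean(*a && *b)),
            (Field::Null, Field::Boolean(_)) | (Field::Boolean(_), Field::Null) => {
                Ok(Field::Boolean(false))
            }
            _ => Err("AND expects boolean operands".into()),
        }
    }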
diff --git a/dozer-sql/src/pipeline/expression/tests/mod.rs b/dozer-sql/src/pipeline/expression/tests/mod.rs
deleted file mode 100644
index fb3368fadf..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/mod.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-#[cfg(test)]
-mod execution;
-#[cfg(test)]
-mod expression_builder_test;
-
-#[cfg(test)]
-mod case;
-#[cfg(test)]
-mod cast;
-#[cfg(test)]
-mod comparison;
-#[cfg(test)]
-mod conditional;
-#[cfg(test)]
-mod datetime;
-#[cfg(test)]
-mod distance;
-mod in_list;
-#[cfg(test)]
-mod json_functions;
-#[cfg(test)]
-mod logical;
-#[cfg(test)]
-mod mathematical;
-#[cfg(test)]
-mod number;
-#[cfg(test)]
-mod point;
-#[cfg(test)]
-mod string;
-mod test_common;
diff --git a/dozer-sql/src/pipeline/expression/tests/number.rs b/dozer-sql/src/pipeline/expression/tests/number.rs
deleted file mode 100644
index d9ca167434..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/number.rs
+++ /dev/null
@@ -1,110 +0,0 @@
-use crate::pipeline::expression::execution::Expression::Literal;
-use crate::pipeline::expression::scalar::number::{evaluate_abs, evaluate_round};
-use crate::pipeline::expression::tests::test_common::*;
-use dozer_types::ordered_float::OrderedFloat;
-use dozer_types::types::Record;
-use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition};
-use proptest::prelude::*;
-use std::ops::Neg;
-
-#[test]
-fn test_abs() {
-    proptest!(ProptestConfig::with_cases(1000), |(i_num in 0i64..100000000i64, f_num in 0f64..100000000f64)| {
-        let row = Record::new(vec![]);
-
-        let v = Box::new(Literal(Field::Int(i_num.neg())));
-        assert_eq!(
-            evaluate_abs(&Schema::default(), &v, &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Int(i_num)
-        );
-
-        let row = Record::new(vec![]);
-
-        let v = Box::new(Literal(Field::Float(OrderedFloat(f_num.neg()))));
-        assert_eq!(
-            evaluate_abs(&Schema::default(), &v, &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Float(OrderedFloat(f_num))
-        );
-    });
-}
-
-#[test]
-fn test_round() {
-    proptest!(ProptestConfig::with_cases(1000), |(i_num: i64, f_num: f64, i_pow: i32, f_pow: f32)| {
-        let row = Record::new(vec![]);
-
-        let v = Box::new(Literal(Field::Int(i_num)));
-        let d = &Box::new(Literal(Field::Int(0)));
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Int(i_num)
-        );
-
-        let v = Box::new(Literal(Field::Float(OrderedFloat(f_num))));
-        let d = &Box::new(Literal(Field::Int(0)));
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Float(OrderedFloat(f_num.round()))
-        );
-
-        let v = Box::new(Literal(Field::Float(OrderedFloat(f_num))));
-        let d = &Box::new(Literal(Field::Int(i_pow as i64)));
-        let order = 10.0_f64.powi(i_pow);
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Float(OrderedFloat((f_num * order).round() / order))
-        );
-
-        let v = Box::new(Literal(Field::Float(OrderedFloat(f_num))));
-        let d = &Box::new(Literal(Field::Float(OrderedFloat(f_pow as f64))));
-        let order = 10.0_f64.powi(f_pow.round() as i32);
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Float(OrderedFloat((f_num * order).round() / order))
-        );
-
-        let v = Box::new(Literal(Field::Float(OrderedFloat(f_num))));
-        let d = &Box::new(Literal(Field::String(f_pow.to_string())));
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Float(OrderedFloat(f_num.round()))
-        );
-
-        let v = Box::new(Literal(Field::Null));
-        let d = &Box::new(Literal(Field::String(i_pow.to_string())));
-        assert_eq!(
-            evaluate_round(&Schema::default(), &v, Some(d), &row)
-                .unwrap_or_else(|e| panic!("{}", e.to_string())),
-            Field::Null
-        );
-    });
-}
-
-#[test]
-fn test_abs_logic() {
-    proptest!(ProptestConfig::with_cases(1000), |(i_num in 0i64..100000000i64)| {
-        let f = run_fct(
-            "SELECT ABS(c) FROM USERS",
-            Schema::default()
-                .field(
-                    FieldDefinition::new(
-                        String::from("c"),
-                        FieldType::Int,
-                        false,
-                        SourceDefinition::Dynamic,
-                    ),
-                    false,
-                )
-                .clone(),
-            vec![Field::Int(i_num.neg())],
-        );
-        assert_eq!(f, Field::Int(i_num));
-    });
-}
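All the ROUND cases above reduce to one identity: scale by 10^digits, round, then divide back, with a non-numeric digits argument degrading to rounding at zero decimal places. As a sketch:

    /// Round `x` to `digits` decimal places (negative digits round to
    /// tens, hundreds, ...), the identity the deleted proptest checks.
    fn round_to(x: f64, digits: i32) -> f64 {
        let order = 10f64.powi(digits);
        (x * order).round() / order
    }

For example, round_to(1.25, 1) is 1.3 and round_to(1234.0, -2) is 1200.0.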
diff --git a/dozer-sql/src/pipeline/expression/tests/point.rs b/dozer-sql/src/pipeline/expression/tests/point.rs
deleted file mode 100644
index b1177d3d18..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/point.rs
+++ /dev/null
@@ -1,246 +0,0 @@
-use crate::pipeline::expression::geo::point::{evaluate_point, validate_point};
-use crate::pipeline::expression::tests::test_common::*;
-use dozer_types::ordered_float::OrderedFloat;
-use dozer_types::types::Record;
-use dozer_types::types::{DozerPoint, Field, FieldDefinition, FieldType, Schema, SourceDefinition};
-
-use crate::pipeline::errors::PipelineError::{
-    InvalidArgument, InvalidFunctionArgumentType, NotEnoughArguments, TooManyArguments,
-};
-use crate::pipeline::expression::execution::Expression;
-use proptest::prelude::*;
-
-#[test]
-fn test_point() {
-    proptest!(
-        ProptestConfig::with_cases(1000), move |(x: i64, y: i64)| {
-            test_validate_point(x, y);
-            test_evaluate_point(x, y);
-        });
-}
-
-fn test_validate_point(x: i64, y: i64) {
-    let schema = Schema::default()
-        .field(
-            FieldDefinition::new(
-                String::from("x"),
-                FieldType::Float,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .field(
-            FieldDefinition::new(
-                String::from("y"),
-                FieldType::Float,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .clone();
-    let _fn_type = String::from("POINT");
-
-    let result = validate_point(&[], &schema);
-    assert!(result.is_err());
-    assert!(matches!(result, Err(NotEnoughArguments(_fn_type))));
-
-    let result = validate_point(&[Expression::Column { index: 0 }], &schema);
-
-    assert!(result.is_err());
-    assert!(matches!(result, Err(NotEnoughArguments(_fn_type))));
-
-    let result = validate_point(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Column { index: 1 },
-        ],
-        &schema,
-    );
-
-    assert!(result.is_ok());
-
-    let result = validate_point(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Column { index: 1 },
-            Expression::Column { index: 2 },
-        ],
-        &schema,
-    );
-
-    assert!(result.is_err());
-    assert!(matches!(result, Err(TooManyArguments(_fn_type))));
-
-    let result = validate_point(
-        &[
-            Expression::Column { index: 0 },
-            Expression::Literal(Field::Int(y)),
-        ],
-        &schema,
-    );
-
-    let _expected_types = [FieldType::Float];
-    assert!(result.is_err());
-    assert!(matches!(
-        result,
-        Err(InvalidFunctionArgumentType(
-            _fn_type,
-            FieldType::Int,
-            _expected_types,
-            1
-        ))
-    ));
-
-    let result = validate_point(
-        &[
-            Expression::Literal(Field::Int(x)),
-            Expression::Column { index: 0 },
-        ],
-        &schema,
-    );
-
-    assert!(result.is_err());
-    assert!(matches!(
-        result,
-        Err(InvalidFunctionArgumentType(
-            _fn_type,
-            FieldType::Int,
-            _expected_types,
-            0
-        ))
-    ));
-}
-
-fn test_evaluate_point(x: i64, y: i64) {
-    let row = Record::new(vec![]);
-
-    let schema = Schema::default()
-        .field(
-            FieldDefinition::new(
-                String::from("x"),
-                FieldType::Float,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .field(
-            FieldDefinition::new(
-                String::from("y"),
-                FieldType::Float,
-                false,
-                SourceDefinition::Dynamic,
-            ),
-            false,
-        )
-        .clone();
-    let _fn_type = String::from("x");
-
-    let result = evaluate_point(&schema, &[], &row);
-    assert!(result.is_err());
-    assert!(matches!(result, Err(InvalidArgument(_fn_type))));
-
-    let _fn_type = String::from("y");
-
-    let result = evaluate_point(&schema, &[Expression::Literal(Field::Int(x))], &row);
-    assert!(result.is_err());
-    assert!(matches!(result, Err(InvalidArgument(_fn_type))));
-
-    let result = evaluate_point(
-        &schema,
-        &[
-            Expression::Literal(Field::Int(x)),
-            Expression::Literal(Field::Int(y)),
-        ],
-        &row,
-    );
-
-    assert!(result.is_ok());
-
-    let result = evaluate_point(
-        &schema,
-        &[
-            Expression::Literal(Field::Int(x)),
-            Expression::Literal(Field::Null),
-        ],
-        &row,
-    );
-
-    assert!(result.is_ok());
-    assert!(matches!(result, Ok(Field::Null)));
-
-    let result = evaluate_point(
-        &schema,
-        &[
-            Expression::Literal(Field::Null),
-            Expression::Literal(Field::Int(y)),
-        ],
-        &row,
-    );
-
-    assert!(result.is_ok());
-    assert!(matches!(result, Ok(Field::Null)));
-}
-
-#[test]
-fn test_point_logical() {
-    let f = run_fct(
-        "SELECT POINT(x, y) FROM LOCATION",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("x"),
-                    FieldType::Float,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .field(
-                FieldDefinition::new(
-                    String::from("y"),
-                    FieldType::Float,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![
-            Field::Float(OrderedFloat(1.0)),
-            Field::Float(OrderedFloat(2.0)),
-        ],
-    );
-    assert_eq!(f, Field::Point(DozerPoint::from((1.0, 2.0))));
-}
-
-#[test]
-fn test_point_with_nullable_parameter() {
-    let f = run_fct(
-        "SELECT POINT(x, y) FROM LOCATION",
-        Schema::default()
-            .field(
-                FieldDefinition::new(
-                    String::from("x"),
-                    FieldType::Float,
-                    false,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .field(
-                FieldDefinition::new(
-                    String::from("y"),
-                    FieldType::Float,
-                    true,
-                    SourceDefinition::Dynamic,
-                ),
-                false,
-            )
-            .clone(),
-        vec![Field::Float(OrderedFloat(1.0)), Field::Null],
-    );
-    assert_eq!(f, Field::Null);
-}
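Both POINT and DISTANCE propagate nullability: if any argument evaluates to Null the whole expression is Null, which is why the schemas in these tests mark the second field nullable. The guard those functions share can be restated as:

    use dozer_types::types::Field;

    // Null-propagation check used conceptually by the geo functions
    // above (sketch; the crate inlines this rather than exposing it).
    fn all_non_null(args: &[Field]) -> bool {
        args.iter().all(|f| !matches!(f, Field::Null))
    }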
a/dozer-sql/src/pipeline/expression/tests/test_common.rs b/dozer-sql/src/pipeline/expression/tests/test_common.rs
deleted file mode 100644
index 5d90239ef5..0000000000
--- a/dozer-sql/src/pipeline/expression/tests/test_common.rs
+++ /dev/null
@@ -1,114 +0,0 @@
-use crate::pipeline::{projection::factory::ProjectionProcessorFactory, tests::utils::get_select};
-use dozer_core::channels::ProcessorChannelForwarder;
-use dozer_core::executor_operation::ProcessorOperation;
-use dozer_core::node::ProcessorFactory;
-use dozer_core::processor_record::ProcessorRecordStore;
-use dozer_core::DEFAULT_PORT_HANDLE;
-use dozer_types::chrono::{
-    DateTime, Datelike, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike,
-};
-use dozer_types::rust_decimal::Decimal;
-use dozer_types::types::Record;
-use dozer_types::types::{Field, Schema};
-use proptest::prelude::*;
-use std::collections::HashMap;
-
-struct TestChannelForwarder {
-    operations: Vec<ProcessorOperation>,
-}
-
-impl ProcessorChannelForwarder for TestChannelForwarder {
-    fn send(&mut self, op: ProcessorOperation, _port: dozer_core::node::PortHandle) {
-        self.operations.push(op);
-    }
-}
-
-pub(crate) fn run_fct(sql: &str, schema: Schema, input: Vec<Field>) -> Field {
-    let record_store = ProcessorRecordStore::new().unwrap();
-
-    let select = get_select(sql).unwrap();
-    let processor_factory =
-        ProjectionProcessorFactory::_new("projection_id".to_owned(), select.projection, vec![]);
-    processor_factory
-        .get_output_schema(
-            &DEFAULT_PORT_HANDLE,
-            &[(DEFAULT_PORT_HANDLE, schema.clone())]
-                .into_iter()
-                .collect(),
-        )
-        .unwrap();
-
-    let mut processor = processor_factory
-        .build(
-            HashMap::from([(DEFAULT_PORT_HANDLE, schema)]),
-            HashMap::new(),
-            &record_store,
-            None,
-        )
-        .unwrap();
-
-    let mut fw = TestChannelForwarder { operations: vec![] };
-    let rec = Record::new(input);
-    let rec = record_store.create_record(&rec).unwrap();
-
-    let op = ProcessorOperation::Insert { new: rec };
-
-    processor
-        .process(DEFAULT_PORT_HANDLE, &record_store, op, &mut fw)
-        .unwrap();
-
-    match &fw.operations[0] {
-        ProcessorOperation::Insert { new } => {
-            let mut new = record_store.load_record(new).unwrap();
-            new.values.remove(0)
-        }
-        _ => panic!("Unable to find result value"),
-    }
-}
-
-#[derive(Debug)]
-pub struct ArbitraryDecimal(pub Decimal);
-
-impl Arbitrary for ArbitraryDecimal {
-    type Parameters = ();
-    type Strategy = BoxedStrategy<Self>;
-
-    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
-        (i64::MIN..i64::MAX, u32::MIN..29u32)
-            .prop_map(|(num, scale)| ArbitraryDecimal(Decimal::new(num, scale)))
-            .boxed()
-    }
-}
-
-#[derive(Debug)]
-pub struct ArbitraryDateTime(pub DateTime<FixedOffset>);
-
-impl Arbitrary for ArbitraryDateTime {
-    type Parameters = ();
-    type Strategy = BoxedStrategy<Self>;
-
-    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
-        (
-            NaiveDateTime::MIN.year()..NaiveDateTime::MAX.year(),
-            1..13u32,
-            1..32u32,
-            0..NaiveDateTime::MAX.second(),
-            0..NaiveDateTime::MAX.nanosecond(),
-        )
-            .prop_map(|(year, month, day, secs, nano)| {
-                let timezone_east = FixedOffset::east_opt(8 * 60 * 60).unwrap();
-                let date = NaiveDate::from_ymd_opt(year, month, day);
-                // Some dates cannot be constructed: a day of 29 (in non-leap
-                // years), 30, or 31 does not exist in February, so fall back
-                // to the default datetime in that case.
-                if date.is_none() {
-                    return ArbitraryDateTime(DateTime::default());
-                }
-                let time = NaiveTime::from_num_seconds_from_midnight_opt(secs, nano).unwrap();
-                let datetime = DateTime::<FixedOffset>::from_local(
-                    NaiveDateTime::new(date.unwrap(), time),
-                    timezone_east,
-                );
-                ArbitraryDateTime(datetime)
-            })
-            .boxed()
-    }
-}
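test_common.rs carried the shared proptest scaffolding: a TestChannelForwarder that captures emitted operations, a run_fct helper that pushes one record through a projection processor, and Arbitrary impls for Decimal and DateTime. Note the DateTime strategy maps impossible calendar dates to DateTime::default() rather than rejecting them; an alternative (hypothetical) formulation filters them out instead:

    use proptest::prelude::*;

    // Hypothetical variant: reject impossible (year, month, day) triples
    // instead of collapsing them to a default datetime, at the cost of
    // proptest discarding some generated cases.
    fn ymd_strategy() -> impl Strategy<Value = (i32, u32, u32)> {
        (1970..2100i32, 1..13u32, 1..32u32)
            .prop_filter("invalid calendar date", |&(y, m, d)| {
                dozer_types::chrono::NaiveDate::from_ymd_opt(y, m, d).is_some()
            })
    }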
diff --git a/dozer-sql/src/pipeline/mod.rs b/dozer-sql/src/pipeline/mod.rs
deleted file mode 100644
index c4a44ca1b2..0000000000
--- a/dozer-sql/src/pipeline/mod.rs
+++ /dev/null
@@ -1,18 +0,0 @@
-mod aggregation;
-pub mod builder;
-pub mod errors;
-mod expression;
-mod pipeline_builder;
-mod planner;
-mod product;
-mod projection;
-mod selection;
-mod table_operator;
-mod utils;
-mod window;
-
-#[cfg(test)]
-mod tests;
-
-#[cfg(feature = "onnx")]
-pub mod onnx;
diff --git a/dozer-sql/src/pipeline/pipeline_builder/from_builder.rs b/dozer-sql/src/pipeline_builder/from_builder.rs
similarity index 98%
rename from dozer-sql/src/pipeline/pipeline_builder/from_builder.rs
rename to dozer-sql/src/pipeline_builder/from_builder.rs
index 5e4b6666d6..c62ce995f5 100644
--- a/dozer-sql/src/pipeline/pipeline_builder/from_builder.rs
+++ b/dozer-sql/src/pipeline_builder/from_builder.rs
@@ -5,12 +5,14 @@ use dozer_core::{
     node::PortHandle,
     DEFAULT_PORT_HANDLE,
 };
-use sqlparser::ast::{FunctionArg, ObjectName, TableFactor, TableWithJoins};
+use dozer_sql_expression::{
+    builder::ExpressionBuilder,
+    sqlparser::ast::{FunctionArg, ObjectName, TableFactor, TableWithJoins},
+};

-use crate::pipeline::{
+use crate::{
     builder::{get_from_source, OutputNodeInfo, QueryContext},
     errors::PipelineError,
-    expression::builder::ExpressionBuilder,
     product::table::factory::TableProcessorFactory,
     table_operator::factory::TableOperatorProcessorFactory,
     window::factory::WindowProcessorFactory,
diff --git a/dozer-sql/src/pipeline/pipeline_builder/join_builder.rs b/dozer-sql/src/pipeline_builder/join_builder.rs
similarity index 96%
rename from dozer-sql/src/pipeline/pipeline_builder/join_builder.rs
rename to dozer-sql/src/pipeline_builder/join_builder.rs
index 6d9a64fc8d..46d97868da 100644
--- a/dozer-sql/src/pipeline/pipeline_builder/join_builder.rs
+++ b/dozer-sql/src/pipeline_builder/join_builder.rs
@@ -2,9 +2,9 @@ use dozer_core::{
     app::{AppPipeline, PipelineEntryPoint},
     DEFAULT_PORT_HANDLE,
 };
-use sqlparser::ast::TableWithJoins;
+use dozer_sql_expression::sqlparser::ast::TableWithJoins;

-use crate::pipeline::{
+use crate::{
     builder::{get_from_source, QueryContext},
     errors::PipelineError,
     product::{
@@ -148,7 +148,7 @@ pub(crate) fn insert_join_to_pipeline(

 // TODO: refactor this
 fn insert_join_source_to_pipeline(
-    source: sqlparser::ast::TableFactor,
+    source: dozer_sql_expression::sqlparser::ast::TableFactor,
     pipeline: &mut AppPipeline,
     pipeline_idx: usize,
     query_context: &mut QueryContext,
@@ -258,6 +258,9 @@ fn insert_table_operator_to_pipeline(
     }
 }

-fn is_nested_join(left_table: &sqlparser::ast::TableFactor) -> bool {
-    matches!(left_table, sqlparser::ast::TableFactor::NestedJoin { .. })
+fn is_nested_join(left_table: &dozer_sql_expression::sqlparser::ast::TableFactor) -> bool {
+    matches!(
+        left_table,
+        dozer_sql_expression::sqlparser::ast::TableFactor::NestedJoin { .. }
+    )
 }
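These two builder files set the pattern for the rest of the diff: dozer-sql loses its pipeline root module, and everything expression-related now comes from the new dozer-sql-expression crate, which re-exports its sqlparser fork so downstream crates no longer need a direct parser dependency. In practice the migration is a mechanical import rewrite:

    // Before this change:
    // use dozer_sql::pipeline::builder::statement_to_pipeline;
    // use sqlparser::ast::TableFactor;

    // After:
    use dozer_sql::builder::statement_to_pipeline;
    use dozer_sql_expression::sqlparser::ast::TableFactor;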
diff --git a/dozer-sql/src/pipeline/pipeline_builder/mod.rs b/dozer-sql/src/pipeline_builder/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/pipeline_builder/mod.rs
rename to dozer-sql/src/pipeline_builder/mod.rs
diff --git a/dozer-sql/src/pipeline/planner/mod.rs b/dozer-sql/src/planner/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/planner/mod.rs
rename to dozer-sql/src/planner/mod.rs
diff --git a/dozer-sql/src/pipeline/planner/projection.rs b/dozer-sql/src/planner/projection.rs
similarity index 96%
rename from dozer-sql/src/pipeline/planner/projection.rs
rename to dozer-sql/src/planner/projection.rs
index d56cf55259..3d62001ba7 100644
--- a/dozer-sql/src/pipeline/planner/projection.rs
+++ b/dozer-sql/src/planner/projection.rs
@@ -1,11 +1,11 @@
 #![allow(dead_code)]
-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::expression::builder::ExpressionBuilder;
-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::pipeline_builder::from_builder::string_from_sql_object_name;
+use crate::errors::PipelineError;
+use crate::pipeline_builder::from_builder::string_from_sql_object_name;
+use dozer_sql_expression::builder::ExpressionBuilder;
+use dozer_sql_expression::execution::Expression;
+use dozer_sql_expression::sqlparser::ast::{Expr, Ident, Select, SelectItem};
 use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::types::{FieldDefinition, Schema};
-use sqlparser::ast::{Expr, Ident, Select, SelectItem};

 #[derive(Clone, Copy)]
 pub enum PrimaryKeyAction {
diff --git a/dozer-sql/src/pipeline/planner/tests/mod.rs b/dozer-sql/src/planner/tests/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/planner/tests/mod.rs
rename to dozer-sql/src/planner/tests/mod.rs
diff --git a/dozer-sql/src/pipeline/planner/tests/projection_tests.rs b/dozer-sql/src/planner/tests/projection_tests.rs
similarity index 90%
rename from dozer-sql/src/pipeline/planner/tests/projection_tests.rs
rename to dozer-sql/src/planner/tests/projection_tests.rs
index f288998c20..525832aea8 100644
--- a/dozer-sql/src/pipeline/planner/tests/projection_tests.rs
+++ b/dozer-sql/src/planner/tests/projection_tests.rs
@@ -1,11 +1,11 @@
-use crate::pipeline::expression::aggregate::AggregateFunctionType;
+use dozer_sql_expression::aggregate::AggregateFunctionType;

-use crate::pipeline::expression::execution::Expression;
-use crate::pipeline::expression::operator::BinaryOperatorType;
-use crate::pipeline::expression::scalar::common::ScalarFunctionType;
-use crate::pipeline::planner::projection::CommonPlanner;
+use crate::planner::projection::CommonPlanner;
+use dozer_sql_expression::execution::Expression;
+use dozer_sql_expression::operator::BinaryOperatorType;
+use dozer_sql_expression::scalar::common::ScalarFunctionType;

-use crate::pipeline::tests::utils::get_select;
+use crate::tests::utils::get_select;
 use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition};

 #[test]
diff --git a/dozer-sql/src/pipeline/planner/tests/schema_tests.rs b/dozer-sql/src/planner/tests/schema_tests.rs
similarity index 96%
rename from dozer-sql/src/pipeline/planner/tests/schema_tests.rs
rename to dozer-sql/src/planner/tests/schema_tests.rs
index faeac63a0e..fa5e20762d 100644
--- a/dozer-sql/src/pipeline/planner/tests/schema_tests.rs
+++ b/dozer-sql/src/planner/tests/schema_tests.rs
@@ -1,5 +1,5 @@
-use crate::pipeline::planner::projection::CommonPlanner;
-use crate::pipeline::tests::utils::get_select;
+use crate::planner::projection::CommonPlanner;
+use crate::tests::utils::get_select;
 use dozer_types::types::{FieldDefinition, FieldType, Schema, SourceDefinition};

 #[test]
diff --git a/dozer-sql/src/pipeline/product/join/factory.rs b/dozer-sql/src/product/join/factory.rs
similarity index 96%
rename from dozer-sql/src/pipeline/product/join/factory.rs
rename to dozer-sql/src/product/join/factory.rs
index c5d4900db1..b5b6f367dc 100644
--- a/dozer-sql/src/pipeline/product/join/factory.rs
+++ b/dozer-sql/src/product/join/factory.rs
@@ -5,18 +5,21 @@ use dozer_core::{
     processor_record::ProcessorRecordStore,
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::{
+    builder::{ExpressionBuilder, NameOrAlias},
+    sqlparser::ast::{
+        BinaryOperator, Expr as SqlExpr, Ident, JoinConstraint as SqlJoinConstraint,
+        JoinOperator as SqlJoinOperator,
+    },
+};
 use dozer_types::{
     errors::internal::BoxedError,
     types::{FieldDefinition, Schema},
 };
-use sqlparser::ast::{
-    BinaryOperator, Expr as SqlExpr, Ident, JoinConstraint as SqlJoinConstraint,
-    JoinOperator as SqlJoinOperator,
-};

-use crate::pipeline::expression::builder::extend_schema_source_def;
-use crate::pipeline::{errors::JoinError, expression::builder::NameOrAlias};
-use crate::pipeline::{errors::PipelineError, expression::builder::ExpressionBuilder};
+use crate::errors::JoinError;
+use crate::errors::PipelineError;
+use dozer_sql_expression::builder::extend_schema_source_def;

 use super::{
     operator::{JoinOperator, JoinType},
@@ -196,7 +199,7 @@ fn append_schema(left_schema: &Schema, right_schema: &Schema) -> Schema {
 }

 fn parse_join_constraint(
-    expression: &sqlparser::ast::Expr,
+    expression: &dozer_sql_expression::sqlparser::ast::Expr,
     left_join_table: &Schema,
     right_join_table: &Schema,
 ) -> Result<(Vec<usize>, Vec<usize>), JoinError> {
diff --git a/dozer-sql/src/pipeline/product/join/mod.rs b/dozer-sql/src/product/join/mod.rs
similarity index 72%
rename from dozer-sql/src/pipeline/product/join/mod.rs
rename to dozer-sql/src/product/join/mod.rs
index 6437c53947..e3e84169aa 100644
--- a/dozer-sql/src/pipeline/product/join/mod.rs
+++ b/dozer-sql/src/product/join/mod.rs
@@ -1,4 +1,4 @@
-use crate::pipeline::errors::JoinError;
+use crate::errors::JoinError;

 pub mod factory;
diff --git a/dozer-sql/src/pipeline/product/join/operator/mod.rs b/dozer-sql/src/product/join/operator/mod.rs
similarity index 99%
rename from dozer-sql/src/pipeline/product/join/operator/mod.rs
rename to dozer-sql/src/product/join/operator/mod.rs
index 9e1b9e712c..2bd0b1f6cd 100644
--- a/dozer-sql/src/pipeline/product/join/operator/mod.rs
+++ b/dozer-sql/src/product/join/operator/mod.rs
@@ -4,7 +4,7 @@ use dozer_core::{
 };
 use dozer_types::types::{Record, Schema, Timestamp};

-use crate::pipeline::{
+use crate::{
     errors::JoinError,
     utils::serialize::{Cursor, SerializationError},
 };
diff --git a/dozer-sql/src/pipeline/product/join/operator/table.rs b/dozer-sql/src/product/join/operator/table.rs
similarity index 99%
rename from dozer-sql/src/pipeline/product/join/operator/table.rs
rename to dozer-sql/src/product/join/operator/table.rs
index b9604786a8..cbd2863b19 100644
--- a/dozer-sql/src/pipeline/product/join/operator/table.rs
+++ b/dozer-sql/src/product/join/operator/table.rs
@@ -16,7 +16,7 @@ use dozer_types::{
 };
 use linked_hash_map::LinkedHashMap;

-use crate::pipeline::{
+use crate::{
     errors::JoinError,
     utils::{
         record_hashtable_key::{get_record_hash, RecordKey},
diff --git a/dozer-sql/src/pipeline/product/join/processor.rs b/dozer-sql/src/product/join/processor.rs
similarity index 97%
rename from dozer-sql/src/pipeline/product/join/processor.rs
rename to dozer-sql/src/product/join/processor.rs
index 219209ae02..23553f5213 100644
--- a/dozer-sql/src/pipeline/product/join/processor.rs
+++ b/dozer-sql/src/product/join/processor.rs
@@ -13,7 +13,7 @@ use metrics::{
     increment_counter,
 };

-use crate::pipeline::errors::PipelineError;
+use crate::errors::PipelineError;

 use super::operator::{JoinAction, JoinBranch, JoinOperator};
@@ -83,7 +83,7 @@ impl Processor for ProductProcessor {
         let from_branch = match from_port {
             0 => JoinBranch::Left,
             1 => JoinBranch::Right,
-            _ => return Err(PipelineError::InvalidPort(from_port).into()),
+            _ => return Err(PipelineError::InvalidPortHandle(from_port).into()),
         };

         let now = std::time::Instant::now();
diff --git a/dozer-sql/src/pipeline/product/mod.rs b/dozer-sql/src/product/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/mod.rs
rename to dozer-sql/src/product/mod.rs
diff --git a/dozer-sql/src/pipeline/product/set/mod.rs b/dozer-sql/src/product/set/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/set/mod.rs
rename to dozer-sql/src/product/set/mod.rs
diff --git a/dozer-sql/src/pipeline/product/set/operator.rs b/dozer-sql/src/product/set/operator.rs
similarity index 96%
rename from dozer-sql/src/pipeline/product/set/operator.rs
rename to dozer-sql/src/product/set/operator.rs
index 861ed47dde..0eef1111b4 100644
--- a/dozer-sql/src/pipeline/product/set/operator.rs
+++ b/dozer-sql/src/product/set/operator.rs
@@ -1,7 +1,7 @@
 use super::record_map::{CountingRecordMap, CountingRecordMapEnum};
-use crate::pipeline::errors::PipelineError;
+use crate::errors::PipelineError;
 use dozer_core::processor_record::ProcessorRecord;
-use sqlparser::ast::{SetOperator, SetQuantifier};
+use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier};

 #[derive(Clone, Debug, PartialEq, Eq, Copy)]
 pub enum SetAction {
diff --git a/dozer-sql/src/pipeline/product/set/record_map/bloom.rs b/dozer-sql/src/product/set/record_map/bloom.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/set/record_map/bloom.rs
rename to dozer-sql/src/product/set/record_map/bloom.rs
diff --git a/dozer-sql/src/pipeline/product/set/record_map/mod.rs b/dozer-sql/src/product/set/record_map/mod.rs
similarity index 99%
rename from dozer-sql/src/pipeline/product/set/record_map/mod.rs
rename to dozer-sql/src/product/set/record_map/mod.rs
index 0660787ecc..eeab3e9ddb 100644
--- a/dozer-sql/src/pipeline/product/set/record_map/mod.rs
+++ b/dozer-sql/src/product/set/record_map/mod.rs
@@ -6,7 +6,7 @@ use dozer_types::serde::{Deserialize, Serialize};
 use enum_dispatch::enum_dispatch;
 use std::collections::HashMap;

-use crate::pipeline::utils::serialize::{
+use crate::utils::serialize::{
     deserialize_bincode, deserialize_record, deserialize_u64, serialize_bincode, serialize_record,
     serialize_u64, Cursor, DeserializationError, SerializationError,
 };
diff --git a/dozer-sql/src/pipeline/product/set/set_factory.rs b/dozer-sql/src/product/set/set_factory.rs
similarity index 96%
rename from dozer-sql/src/pipeline/product/set/set_factory.rs
rename to dozer-sql/src/product/set/set_factory.rs
index 1de810ee4d..04910aa5b6 100644
--- a/dozer-sql/src/pipeline/product/set/set_factory.rs
+++ b/dozer-sql/src/product/set/set_factory.rs
@@ -1,16 +1,16 @@
 use std::collections::HashMap;

-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::errors::SetError;
+use crate::errors::PipelineError;
+use crate::errors::SetError;
 use dozer_core::processor_record::ProcessorRecordStore;
 use dozer_core::{
     node::{OutputPortDef, OutputPortType, PortHandle, Processor, ProcessorFactory},
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier};
 use dozer_types::errors::internal::BoxedError;
 use dozer_types::types::{FieldDefinition, Schema, SourceDefinition};
-use sqlparser::ast::{SetOperator, SetQuantifier};

 use super::operator::SetOperation;
 use super::set_processor::SetProcessor;
diff --git a/dozer-sql/src/pipeline/product/set/set_processor.rs b/dozer-sql/src/product/set/set_processor.rs
similarity index 98%
rename from dozer-sql/src/pipeline/product/set/set_processor.rs
rename to dozer-sql/src/product/set/set_processor.rs
index 5d213a7442..a9f74034fc 100644
--- a/dozer-sql/src/pipeline/product/set/set_processor.rs
+++ b/dozer-sql/src/product/set/set_processor.rs
@@ -3,8 +3,8 @@ use super::record_map::{
     AccurateCountingRecordMap, CountingRecordMap, CountingRecordMapEnum,
     ProbabilisticCountingRecordMap,
 };
-use crate::pipeline::errors::{PipelineError, ProductError, SetError};
-use crate::pipeline::utils::serialize::Cursor;
+use crate::errors::{PipelineError, ProductError, SetError};
+use crate::utils::serialize::Cursor;
 use dozer_core::channels::ProcessorChannelForwarder;
 use dozer_core::dozer_log::storage::Object;
 use dozer_core::epoch::Epoch;
diff --git a/dozer-sql/src/pipeline/product/table/factory.rs b/dozer-sql/src/product/table/factory.rs
similarity index 93%
rename from dozer-sql/src/pipeline/product/table/factory.rs
rename to dozer-sql/src/product/table/factory.rs
index cc7240288b..a016f368f9 100644
--- a/dozer-sql/src/pipeline/product/table/factory.rs
+++ b/dozer-sql/src/product/table/factory.rs
@@ -5,16 +5,14 @@ use dozer_core::{
     processor_record::ProcessorRecordStore,
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::{
+    builder::{extend_schema_source_def, NameOrAlias},
+    sqlparser::ast::TableFactor,
+};
 use dozer_types::{errors::internal::BoxedError, types::Schema};
-use sqlparser::ast::TableFactor;

-use crate::pipeline::{
-    errors::{PipelineError, ProductError},
-    expression::builder::extend_schema_source_def,
-};
-use crate::pipeline::{
-    expression::builder::NameOrAlias, window::builder::string_from_sql_object_name,
-};
+use crate::errors::{PipelineError, ProductError};
+use crate::window::builder::string_from_sql_object_name;

 use super::processor::TableProcessor;
diff --git a/dozer-sql/src/pipeline/product/table/mod.rs b/dozer-sql/src/product/table/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/table/mod.rs
rename to dozer-sql/src/product/table/mod.rs
diff --git a/dozer-sql/src/pipeline/product/table/processor.rs b/dozer-sql/src/product/table/processor.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/table/processor.rs
rename to dozer-sql/src/product/table/processor.rs
diff --git a/dozer-sql/src/pipeline/product/tests/left_join_test.rs b/dozer-sql/src/product/tests/left_join_test.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/tests/left_join_test.rs
rename to dozer-sql/src/product/tests/left_join_test.rs
diff --git a/dozer-sql/src/pipeline/product/tests/mod.rs b/dozer-sql/src/product/tests/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/tests/mod.rs
rename to dozer-sql/src/product/tests/mod.rs
diff --git a/dozer-sql/src/pipeline/product/tests/processor_tests_utils.rs b/dozer-sql/src/product/tests/processor_tests_utils.rs
similarity index 98%
rename from dozer-sql/src/pipeline/product/tests/processor_tests_utils.rs
rename to dozer-sql/src/product/tests/processor_tests_utils.rs
index e1a3450204..2bba45f95e 100644
--- a/dozer-sql/src/pipeline/product/tests/processor_tests_utils.rs
+++ b/dozer-sql/src/product/tests/processor_tests_utils.rs
@@ -6,7 +6,7 @@ use dozer_core::{
 };
 use dozer_types::{parking_lot::RwLock, types::Schema};

-use crate::pipeline::{
+use crate::{
     builder::get_select, errors::PipelineError, product::factory::ProductProcessorFactory,
 };
diff --git a/dozer-sql/src/pipeline/product/tests/set_operator_test.rs b/dozer-sql/src/product/tests/set_operator_test.rs
similarity index 100%
rename from dozer-sql/src/pipeline/product/tests/set_operator_test.rs
rename to dozer-sql/src/product/tests/set_operator_test.rs
diff --git a/dozer-sql/src/pipeline/projection/factory.rs b/dozer-sql/src/projection/factory.rs
similarity index 88%
rename from dozer-sql/src/pipeline/projection/factory.rs
rename to dozer-sql/src/projection/factory.rs
index e76feabc3f..54e7df80f0 100644
--- a/dozer-sql/src/pipeline/projection/factory.rs
+++ b/dozer-sql/src/projection/factory.rs
@@ -5,17 +5,18 @@ use dozer_core::{
     processor_record::ProcessorRecordStore,
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::{
+    builder::ExpressionBuilder,
+    execution::Expression,
+    sqlparser::ast::{Expr, Ident, SelectItem},
+};
 use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{
     errors::internal::BoxedError,
     types::{FieldDefinition, Schema},
 };
-use sqlparser::ast::{Expr, Ident, SelectItem};

-use crate::pipeline::{
-    errors::PipelineError,
-    expression::{builder::ExpressionBuilder, execution::Expression},
-};
+use crate::errors::PipelineError;

 use super::processor::ProjectionProcessor;
@@ -138,16 +139,13 @@ pub(crate) fn parse_sql_select_item(
 ) -> Result<(String, Expression), PipelineError> {
     match sql {
         SelectItem::UnnamedExpr(sql_expr) => {
-            match ExpressionBuilder::new(0).parse_sql_expression(true, sql_expr, schema, udfs) {
-                Ok(expr) => Ok((sql_expr.to_string(), expr)),
-                Err(error) => Err(error),
-            }
+            let expr =
+                ExpressionBuilder::new(0).parse_sql_expression(true, sql_expr, schema, udfs)?;
+            Ok((sql_expr.to_string(), expr))
         }
         SelectItem::ExprWithAlias { expr, alias } => {
-            match ExpressionBuilder::new(0).parse_sql_expression(true, expr, schema, udfs) {
-                Ok(expr) => Ok((alias.value.clone(), expr)),
-                Err(error) => Err(error),
-            }
+            let expr = ExpressionBuilder::new(0).parse_sql_expression(true, expr, schema, udfs)?;
+            Ok((alias.value.clone(), expr))
         }
         SelectItem::Wildcard(_) => Err(PipelineError::InvalidOperator("*".to_string())),
         SelectItem::QualifiedWildcard(ref object_name, ..) => {
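The parse_sql_select_item cleanup in that hunk is behavior-preserving: a match that forwarded Ok and Err by hand becomes the ? operator. The same shape, reduced to a generic helper with hypothetical types:

    // Before: match parse(input) { Ok(e) => Ok((name, e)), Err(e) => Err(e) }
    // After: let `?` propagate the error.
    fn parse_item(input: &str) -> Result<(String, i64), std::num::ParseIntError> {
        let expr = input.trim().parse::<i64>()?;
        Ok((input.to_string(), expr))
    }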
diff --git a/dozer-sql/src/pipeline/projection/mod.rs b/dozer-sql/src/projection/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/projection/mod.rs
rename to dozer-sql/src/projection/mod.rs
diff --git a/dozer-sql/src/pipeline/projection/processor.rs b/dozer-sql/src/projection/processor.rs
similarity index 97%
rename from dozer-sql/src/pipeline/projection/processor.rs
rename to dozer-sql/src/projection/processor.rs
index 79f99a75c8..ec8a03ec43 100644
--- a/dozer-sql/src/pipeline/projection/processor.rs
+++ b/dozer-sql/src/projection/processor.rs
@@ -1,5 +1,5 @@
-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::expression::execution::Expression;
+use crate::errors::PipelineError;
+use dozer_sql_expression::execution::Expression;

 use dozer_core::channels::ProcessorChannelForwarder;
 use dozer_core::dozer_log::storage::Object;
diff --git a/dozer-sql/src/pipeline/selection/factory.rs b/dozer-sql/src/selection/factory.rs
similarity index 94%
rename from dozer-sql/src/pipeline/selection/factory.rs
rename to dozer-sql/src/selection/factory.rs
index 7caeb1b809..6bcf6de27f 100644
--- a/dozer-sql/src/pipeline/selection/factory.rs
+++ b/dozer-sql/src/selection/factory.rs
@@ -1,15 +1,15 @@
 use std::collections::HashMap;

-use crate::pipeline::errors::PipelineError;
-use crate::pipeline::expression::builder::ExpressionBuilder;
+use crate::errors::PipelineError;
 use dozer_core::processor_record::ProcessorRecordStore;
 use dozer_core::{
     node::{OutputPortDef, OutputPortType, PortHandle, Processor, ProcessorFactory},
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::builder::ExpressionBuilder;
+use dozer_sql_expression::sqlparser::ast::Expr as SqlExpr;
 use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{errors::internal::BoxedError, types::Schema};
-use sqlparser::ast::Expr as SqlExpr;

 use super::processor::SelectionProcessor;
diff --git a/dozer-sql/src/pipeline/selection/mod.rs b/dozer-sql/src/selection/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/selection/mod.rs
rename to dozer-sql/src/selection/mod.rs
diff --git a/dozer-sql/src/pipeline/selection/processor.rs b/dozer-sql/src/selection/processor.rs
similarity index 98%
rename from dozer-sql/src/pipeline/selection/processor.rs
rename to dozer-sql/src/selection/processor.rs
index 772ffda3a3..e2677a0cc8 100644
--- a/dozer-sql/src/pipeline/selection/processor.rs
+++ b/dozer-sql/src/selection/processor.rs
@@ -1,4 +1,3 @@
-use crate::pipeline::expression::execution::Expression;
 use dozer_core::channels::ProcessorChannelForwarder;
 use dozer_core::dozer_log::storage::Object;
 use dozer_core::epoch::Epoch;
@@ -6,6 +5,7 @@ use dozer_core::executor_operation::ProcessorOperation;
 use dozer_core::node::{PortHandle, Processor};
 use dozer_core::processor_record::ProcessorRecordStore;
 use dozer_core::DEFAULT_PORT_HANDLE;
+use dozer_sql_expression::execution::Expression;
 use dozer_types::errors::internal::BoxedError;
 use dozer_types::types::{Field, Schema};
diff --git a/dozer-sql/src/pipeline/table_operator/factory.rs b/dozer-sql/src/table_operator/factory.rs
similarity index 98%
rename from dozer-sql/src/pipeline/table_operator/factory.rs
rename to dozer-sql/src/table_operator/factory.rs
index 2953536600..ca3bcd1189 100644
--- a/dozer-sql/src/pipeline/table_operator/factory.rs
+++ b/dozer-sql/src/table_operator/factory.rs
@@ -5,13 +5,16 @@ use dozer_core::{
     processor_record::ProcessorRecordStore,
     DEFAULT_PORT_HANDLE,
 };
+use dozer_sql_expression::{
+    builder::ExpressionBuilder,
+    execution::Expression,
+    sqlparser::ast::{Expr, FunctionArg, FunctionArgExpr, Value},
+};
 use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{errors::internal::BoxedError, types::Schema};
-use sqlparser::ast::{Expr, FunctionArg, FunctionArgExpr, Value};

-use crate::pipeline::{
+use crate::{
     errors::{PipelineError, TableOperatorError},
-    expression::{builder::ExpressionBuilder, execution::Expression},
     pipeline_builder::from_builder::TableOperatorDescriptor,
 };
diff --git a/dozer-sql/src/pipeline/table_operator/lifetime.rs b/dozer-sql/src/table_operator/lifetime.rs
similarity index 96%
rename from dozer-sql/src/pipeline/table_operator/lifetime.rs
rename to dozer-sql/src/table_operator/lifetime.rs
index ae93f92fdf..975f200b55 100644
--- a/dozer-sql/src/pipeline/table_operator/lifetime.rs
+++ b/dozer-sql/src/table_operator/lifetime.rs
@@ -1,7 +1,8 @@
 use dozer_core::processor_record::{ProcessorRecord, ProcessorRecordStore};
+use dozer_sql_expression::execution::Expression;
 use dozer_types::types::{Field, Lifetime, Schema};

-use crate::pipeline::{errors::TableOperatorError, expression::execution::Expression};
+use crate::errors::TableOperatorError;

 use super::operator::{TableOperator, TableOperatorType};
diff --git a/dozer-sql/src/pipeline/table_operator/mod.rs b/dozer-sql/src/table_operator/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/table_operator/mod.rs
rename to dozer-sql/src/table_operator/mod.rs
diff --git a/dozer-sql/src/pipeline/table_operator/operator.rs b/dozer-sql/src/table_operator/operator.rs
similarity index 84%
rename from dozer-sql/src/pipeline/table_operator/operator.rs
rename to dozer-sql/src/table_operator/operator.rs
index 66c5edf166..387b2a9a60 100644
--- a/dozer-sql/src/pipeline/table_operator/operator.rs
+++ b/dozer-sql/src/table_operator/operator.rs
@@ -1,9 +1,9 @@
-use crate::pipeline::table_operator::lifetime::LifetimeTableOperator;
+use crate::table_operator::lifetime::LifetimeTableOperator;
 use dozer_core::processor_record::{ProcessorRecord, ProcessorRecordStore};
 use dozer_types::types::Schema;
 use enum_dispatch::enum_dispatch;

-use crate::pipeline::errors::TableOperatorError;
+use crate::errors::TableOperatorError;

 #[enum_dispatch]
 pub trait TableOperator: Send + Sync {
diff --git a/dozer-sql/src/pipeline/table_operator/processor.rs b/dozer-sql/src/table_operator/processor.rs
similarity index 98%
rename from dozer-sql/src/pipeline/table_operator/processor.rs
rename to dozer-sql/src/table_operator/processor.rs
index 2a3d5fa9b8..af334f7a42 100644
--- a/dozer-sql/src/pipeline/table_operator/processor.rs
+++ b/dozer-sql/src/table_operator/processor.rs
@@ -8,7 +8,7 @@ use dozer_core::DEFAULT_PORT_HANDLE;
 use dozer_types::errors::internal::BoxedError;
 use dozer_types::types::Schema;

-use crate::pipeline::errors::PipelineError;
+use crate::errors::PipelineError;

 use super::operator::{TableOperator, TableOperatorType};
diff --git a/dozer-sql/src/pipeline/table_operator/tests/mod.rs b/dozer-sql/src/table_operator/tests/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/table_operator/tests/mod.rs
rename to dozer-sql/src/table_operator/tests/mod.rs
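The table-operator module moves wholesale; its trait is wired to TableOperatorType through enum_dispatch, which turns the dynamic-dispatch enum into static dispatch. A schematic of that wiring (illustrative trait and types, not the crate's real method set):

    use enum_dispatch::enum_dispatch;

    #[enum_dispatch]
    trait Operator {
        fn name(&self) -> &'static str;
    }

    struct Lifetime;

    impl Operator for Lifetime {
        fn name(&self) -> &'static str {
            "lifetime"
        }
    }

    // Calls on OperatorType dispatch directly to the variant's impl,
    // without a vtable.
    #[enum_dispatch(Operator)]
    enum OperatorType {
        Lifetime(Lifetime),
    }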
+++ b/dozer-sql/src/table_operator/tests/operator_test.rs
@@ -1,15 +1,13 @@
 use std::time::Duration;
 
 use dozer_core::processor_record::ProcessorRecordStore;
+use dozer_sql_expression::execution::Expression;
 use dozer_types::{
     chrono::DateTime,
     types::{Field, FieldDefinition, FieldType, Lifetime, Record, Schema, SourceDefinition},
 };
 
-use crate::pipeline::{
-    expression::execution::Expression,
-    table_operator::{lifetime::LifetimeTableOperator, operator::TableOperator},
-};
+use crate::table_operator::{lifetime::LifetimeTableOperator, operator::TableOperator};
 
 #[test]
 fn test_lifetime() {
diff --git a/dozer-sql/src/pipeline/tests/builder_test.rs b/dozer-sql/src/tests/builder_test.rs
similarity index 99%
rename from dozer-sql/src/pipeline/tests/builder_test.rs
rename to dozer-sql/src/tests/builder_test.rs
index c614970575..6df1d62af3 100644
--- a/dozer-sql/src/pipeline/tests/builder_test.rs
+++ b/dozer-sql/src/tests/builder_test.rs
@@ -26,7 +26,7 @@ use std::collections::HashMap;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 
-use crate::pipeline::builder::statement_to_pipeline;
+use crate::builder::statement_to_pipeline;
 
 /// Test Source
 #[derive(Debug)]
diff --git a/dozer-sql/src/pipeline/tests/mod.rs b/dozer-sql/src/tests/mod.rs
similarity index 55%
rename from dozer-sql/src/pipeline/tests/mod.rs
rename to dozer-sql/src/tests/mod.rs
index 002ff1a173..0ed1750729 100644
--- a/dozer-sql/src/pipeline/tests/mod.rs
+++ b/dozer-sql/src/tests/mod.rs
@@ -1,5 +1,2 @@
-#[cfg(test)]
 mod builder_test;
-
-#[cfg(test)]
 pub mod utils;
diff --git a/dozer-sql/src/pipeline/tests/utils.rs b/dozer-sql/src/tests/utils.rs
similarity index 91%
rename from dozer-sql/src/pipeline/tests/utils.rs
rename to dozer-sql/src/tests/utils.rs
index db7b2e4e6c..a123fbfce1 100644
--- a/dozer-sql/src/pipeline/tests/utils.rs
+++ b/dozer-sql/src/tests/utils.rs
@@ -1,5 +1,5 @@
-use crate::pipeline::errors::PipelineError;
-use sqlparser::{
+use crate::errors::PipelineError;
+use dozer_sql_expression::sqlparser::{
     ast::{Query, Select, SetExpr, Statement},
     dialect::DozerDialect,
     parser::Parser,
diff --git a/dozer-sql/src/pipeline/utils/mod.rs b/dozer-sql/src/utils/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/utils/mod.rs
rename to dozer-sql/src/utils/mod.rs
diff --git a/dozer-sql/src/pipeline/utils/record_hashtable_key.rs b/dozer-sql/src/utils/record_hashtable_key.rs
similarity index 100%
rename from dozer-sql/src/pipeline/utils/record_hashtable_key.rs
rename to dozer-sql/src/utils/record_hashtable_key.rs
diff --git a/dozer-sql/src/pipeline/utils/serialize.rs b/dozer-sql/src/utils/serialize.rs
similarity index 100%
rename from dozer-sql/src/pipeline/utils/serialize.rs
rename to dozer-sql/src/utils/serialize.rs
diff --git a/dozer-sql/src/pipeline/window/builder.rs b/dozer-sql/src/window/builder.rs
similarity index 98%
rename from dozer-sql/src/pipeline/window/builder.rs
rename to dozer-sql/src/window/builder.rs
index adfeadf130..53ef7a02d9 100644
--- a/dozer-sql/src/pipeline/window/builder.rs
+++ b/dozer-sql/src/window/builder.rs
@@ -1,12 +1,14 @@
+use dozer_sql_expression::{
+    builder::ExpressionBuilder,
+    sqlparser::ast::{Expr, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value},
+};
 use dozer_types::{
     chrono::Duration,
     types::{FieldDefinition, Schema},
 };
-use sqlparser::ast::{Expr, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value};
 
-use crate::pipeline::{
+use crate::{
     errors::{JoinError, PipelineError, WindowError},
-    expression::builder::ExpressionBuilder,
     pipeline_builder::from_builder::TableOperatorDescriptor,
 };
diff --git a/dozer-sql/src/pipeline/window/factory.rs b/dozer-sql/src/window/factory.rs
similarity index 99%
rename from dozer-sql/src/pipeline/window/factory.rs
rename to dozer-sql/src/window/factory.rs
index 7d631a9ab9..3fbef96600 100644
--- a/dozer-sql/src/pipeline/window/factory.rs
+++ b/dozer-sql/src/window/factory.rs
@@ -7,7 +7,7 @@ use dozer_core::{
 };
 use dozer_types::{errors::internal::BoxedError, types::Schema};
 
-use crate::pipeline::{
+use crate::{
     errors::{PipelineError, WindowError},
     pipeline_builder::from_builder::TableOperatorDescriptor,
 };
diff --git a/dozer-sql/src/pipeline/window/mod.rs b/dozer-sql/src/window/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/window/mod.rs
rename to dozer-sql/src/window/mod.rs
diff --git a/dozer-sql/src/pipeline/window/operator.rs b/dozer-sql/src/window/operator.rs
similarity index 99%
rename from dozer-sql/src/pipeline/window/operator.rs
rename to dozer-sql/src/window/operator.rs
index 5a8ebfff96..9172f473f8 100644
--- a/dozer-sql/src/pipeline/window/operator.rs
+++ b/dozer-sql/src/window/operator.rs
@@ -4,7 +4,7 @@ use dozer_types::{
     types::{Field, FieldDefinition, FieldType, Record, Schema, SourceDefinition},
 };
 
-use crate::pipeline::errors::WindowError;
+use crate::errors::WindowError;
 
 #[derive(Clone, Debug)]
 pub enum WindowType {
diff --git a/dozer-sql/src/pipeline/window/processor.rs b/dozer-sql/src/window/processor.rs
similarity index 98%
rename from dozer-sql/src/pipeline/window/processor.rs
rename to dozer-sql/src/window/processor.rs
index 1de73a5f7e..4439fb88bf 100644
--- a/dozer-sql/src/pipeline/window/processor.rs
+++ b/dozer-sql/src/window/processor.rs
@@ -1,4 +1,4 @@
-use crate::pipeline::errors::PipelineError;
+use crate::errors::PipelineError;
 use dozer_core::channels::ProcessorChannelForwarder;
 use dozer_core::dozer_log::storage::Object;
 use dozer_core::epoch::Epoch;
diff --git a/dozer-sql/src/pipeline/window/tests/mod.rs b/dozer-sql/src/window/tests/mod.rs
similarity index 100%
rename from dozer-sql/src/pipeline/window/tests/mod.rs
rename to dozer-sql/src/window/tests/mod.rs
diff --git a/dozer-sql/src/pipeline/window/tests/operator_test.rs b/dozer-sql/src/window/tests/operator_test.rs
similarity index 98%
rename from dozer-sql/src/pipeline/window/tests/operator_test.rs
rename to dozer-sql/src/window/tests/operator_test.rs
index f513859b4a..21d8dd9e4d 100644
--- a/dozer-sql/src/pipeline/window/tests/operator_test.rs
+++ b/dozer-sql/src/window/tests/operator_test.rs
@@ -5,7 +5,7 @@ use dozer_types::{
     types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition},
 };
 
-use crate::pipeline::window::operator::WindowType;
+use crate::window::operator::WindowType;
 
 #[test]
 fn test_hop() {
diff --git a/dozer-tests/src/sql_tests/helper/pipeline.rs b/dozer-tests/src/sql_tests/helper/pipeline.rs
index fa6af0844e..cf07aa9283 100644
--- a/dozer-tests/src/sql_tests/helper/pipeline.rs
+++ b/dozer-tests/src/sql_tests/helper/pipeline.rs
@@ -18,7 +18,7 @@ use dozer_core::{Dag, DEFAULT_PORT_HANDLE};
 use dozer_core::executor::{DagExecutor, ExecutorOptions};
 
 use crossbeam::channel::{Receiver, Sender};
-use dozer_sql::pipeline::builder::statement_to_pipeline;
+use dozer_sql::builder::statement_to_pipeline;
 use dozer_types::errors::internal::BoxedError;
 use dozer_types::ingestion_types::IngestionMessage;
diff --git a/dozer-types/src/errors/types.rs b/dozer-types/src/errors/types.rs
index 3e7dc2c1bf..b29af7d35e 100644
--- a/dozer-types/src/errors/types.rs
+++ b/dozer-types/src/errors/types.rs
@@ -1,6 +1,5 @@
 use super::internal::BoxedError;
 use crate::types::FieldType;
-use geo::vincenty_distance::FailedToConvergeError;
 use serde_json::Number;
 use std::num::ParseIntError;
 use thiserror::Error;
@@ -20,16 +19,10 @@ pub enum TypeError {
         nullable: bool,
         value: String,
     },
-    #[error("Invalid timestamp")]
-    InvalidTimestamp,
-    #[error("Ambiguous timestamp")]
-    AmbiguousTimestamp,
     #[error("Serialization failed: {0}")]
     SerializationError(#[source] SerializationError),
     #[error("Failed to parse the field: {0}")]
     DeserializationError(#[source] DeserializationError),
-    #[error("Failed to calculate distance: {0}")]
-    DistanceCalculationError(#[source] FailedToConvergeError),
 }
 
 #[derive(Error, Debug)]
diff --git a/dozer-types/src/types/field.rs b/dozer-types/src/types/field.rs
index b675b91197..8ebd525083 100644
--- a/dozer-types/src/types/field.rs
+++ b/dozer-types/src/types/field.rs
@@ -1,4 +1,4 @@
-use crate::errors::types::{DeserializationError, TypeError};
+use crate::errors::types::DeserializationError;
 use crate::json_types::JsonValue;
 use crate::types::{
     DozerDuration, DozerPoint, FieldDefinition, Schema, SourceDefinition, TimeUnit,
@@ -136,12 +136,12 @@
 
                 match timestamp {
                     LocalResult::Single(v) => Ok(Field::Timestamp(DateTime::from(v))),
-                    LocalResult::Ambiguous(_, _) => Err(DeserializationError::Custom(Box::new(
-                        TypeError::AmbiguousTimestamp,
-                    ))),
-                    LocalResult::None => Err(DeserializationError::Custom(Box::new(
-                        TypeError::InvalidTimestamp,
-                    ))),
+                    LocalResult::Ambiguous(_, _) => Err(DeserializationError::Custom(
+                        "Ambiguous timestamp".to_string().into(),
+                    )),
+                    LocalResult::None => Err(DeserializationError::Custom(
+                        "Invalid timestamp".to_string().into(),
+                    )),
                 }
             }
             11 => Ok(Field::Date(NaiveDate::parse_from_str(
@@ -458,52 +458,38 @@
         }
     }
 
-    pub fn to_string(&self) -> Option<String> {
-        match self {
-            Field::UInt(u) => Some(format!("{u}")),
-            Field::U128(u) => Some(format!("{u}")),
-            Field::Int(i) => Some(format!("{i}")),
-            Field::I128(i) => Some(format!("{i}")),
-            Field::Float(f) => Some(format!("{f}")),
-            Field::Decimal(d) => Some(format!("{d}")),
-            Field::Boolean(i) => Some(if *i {
-                "TRUE".to_string()
-            } else {
-                "FALSE".to_string()
-            }),
-            Field::String(s) => Some(s.to_owned()),
-            Field::Text(t) => Some(t.to_owned()),
-            Field::Date(d) => Some(d.format("%Y-%m-%d").to_string()),
-            Field::Timestamp(t) => Some(t.to_rfc3339()),
-            Field::Binary(b) => Some(format!("{b:X?}")),
-            Field::Json(j) => Some(j.to_string()),
-            Field::Null => Some("".to_string()),
-            _ => None,
+    pub fn to_string(&self) -> String {
+        match self {
+            Field::UInt(u) => format!("{u}"),
+            Field::U128(u) => format!("{u}"),
+            Field::Int(i) => format!("{i}"),
+            Field::I128(i) => format!("{i}"),
+            Field::Float(f) => format!("{f}"),
+            Field::Decimal(d) => format!("{d}"),
+            Field::Boolean(i) => {
+                if *i {
+                    "TRUE".to_string()
+                } else {
+                    "FALSE".to_string()
+                }
+            }
+            Field::String(s) => s.to_owned(),
+            Field::Text(t) => t.to_owned(),
+            Field::Date(d) => d.format(DATE_FORMAT).to_string(),
+            Field::Timestamp(t) => t.to_rfc3339(),
+            Field::Binary(b) => format!("{b:X?}"),
+            Field::Json(j) => j.to_string(),
+            Field::Point(p) => {
+                let (x, y) = p.0.x_y();
+                format!("point({}, {})", x.0, y.0)
+            }
+            Field::Duration(d) => format!("{:?}", d.0),
+            Field::Null => "".to_string(),
         }
     }
 
-    pub fn to_text(&self) -> Option<String> {
-        match self {
-            Field::UInt(u) => Some(format!("{u}")),
-            Field::U128(u) => Some(format!("{u}")),
-            Field::Int(i) => Some(format!("{i}")),
-            Field::I128(i) => Some(format!("{i}")),
-            Field::Float(f) => Some(format!("{f}")),
-            Field::Decimal(d) => Some(format!("{d}")),
-            Field::Boolean(i) => Some(if *i {
-                "TRUE".to_string()
-            } else {
-                "FALSE".to_string()
-            }),
-            Field::String(s) => Some(s.to_owned()),
-            Field::Text(t) => Some(t.to_owned()),
-            Field::Date(d) => Some(d.format("%Y-%m-%d").to_string()),
-            Field::Timestamp(t) => Some(t.to_rfc3339()),
-            Field::Binary(b) => Some(format!("{b:X?}")),
-            Field::Json(j) => Some(j.to_string()),
-            Field::Null => Some("".to_string()),
-            _ => None,
-        }
-    }
+    pub fn to_text(&self) -> String {
+        self.to_string()
+    }
 
     pub fn to_binary(&self) -> Option<&[u8]> {
@@ -527,29 +513,24 @@
     }
 
-    pub fn to_timestamp(&self) -> Result<Option<DateTime<FixedOffset>>, TypeError> {
+    pub fn to_timestamp(&self) -> Option<DateTime<FixedOffset>> {
         match self {
-            Field::String(s) => Ok(DateTime::parse_from_rfc3339(s.as_str()).ok()),
-            Field::Timestamp(t) => Ok(Some(*t)),
-            Field::Null => match Utc.timestamp_millis_opt(0) {
-                LocalResult::None => Err(TypeError::InvalidTimestamp),
-                LocalResult::Single(v) => Ok(Some(DateTime::from(v))),
-                LocalResult::Ambiguous(_, _) => Err(TypeError::AmbiguousTimestamp),
+            Field::String(s) => DateTime::parse_from_rfc3339(s.as_str()).ok(),
+            Field::Text(s) => DateTime::parse_from_rfc3339(s.as_str()).ok(),
+            Field::Timestamp(t) => Some(*t),
+            Field::Date(d) => match Utc.with_ymd_and_hms(d.year(), d.month(), d.day(), 0, 0, 0) {
+                LocalResult::Single(v) => Some(v.into()),
+                _ => unreachable!(),
             },
-            _ => Ok(None),
+            _ => None,
         }
     }
 
-    pub fn to_date(&self) -> Result<Option<NaiveDate>, TypeError> {
+    pub fn to_date(&self) -> Option<NaiveDate> {
         match self {
-            Field::String(s) => Ok(NaiveDate::parse_from_str(s, "%Y-%m-%d").ok()),
-            Field::Date(d) => Ok(Some(*d)),
-            Field::Null => match Utc.timestamp_millis_opt(0) {
-                LocalResult::None => Err(TypeError::InvalidTimestamp),
-                LocalResult::Single(v) => Ok(Some(v.naive_utc().date())),
-                LocalResult::Ambiguous(_, _) => Err(TypeError::AmbiguousTimestamp),
-            },
-            _ => Ok(None),
+            Field::String(s) => NaiveDate::parse_from_str(s, DATE_FORMAT).ok(),
+            Field::Date(d) => Some(*d),
+            _ => None,
         }
     }
 
@@ -572,45 +553,38 @@
         }
     }
 
-    pub fn to_point(&self) -> Option<&DozerPoint> {
+    pub fn to_point(&self) -> Option<DozerPoint> {
         match self {
-            Field::Point(p) => Some(p),
+            Field::Point(p) => Some(*p),
             _ => None,
         }
     }
 
-    pub fn to_duration(&self) -> Result<Option<DozerDuration>, TypeError> {
+    pub fn to_duration(&self) -> Option<DozerDuration> {
         match self {
-            Field::UInt(d) => Ok(Some(
-                DozerDuration::from_str(d.to_string().as_str()).unwrap(),
+            Field::UInt(d) => Some(DozerDuration(
+                Duration::from_nanos(*d),
+                TimeUnit::Nanoseconds,
             )),
-            Field::U128(d) => Ok(Some(
-                DozerDuration::from_str(d.to_string().as_str()).unwrap(),
+            Field::U128(d) => Some(DozerDuration(
+                Duration::from_nanos(u64::try_from(*d).ok()?),
+                TimeUnit::Nanoseconds,
            )),
-            Field::Int(d) => Ok(Some(
-                DozerDuration::from_str(d.to_string().as_str()).unwrap(),
+            Field::Int(d) => Some(DozerDuration(
+                Duration::from_nanos(u64::try_from(*d).ok()?),
+                TimeUnit::Nanoseconds,
             )),
-            Field::I128(d) => Ok(Some(
-                DozerDuration::from_str(d.to_string().as_str()).unwrap(),
+            Field::I128(d) => Some(DozerDuration(
+                Duration::from_nanos(u64::try_from(*d).ok()?),
+                TimeUnit::Nanoseconds,
            )),
-            Field::Duration(d) => Ok(Some(*d)),
-            Field::String(d) | Field::Text(d) => {
-                if let Ok(dur) = DozerDuration::from_str(d.as_str()) {
-                    Ok(Some(dur))
-                } else {
-                    Err(TypeError::InvalidFieldValue {
-                        field_type: FieldType::Duration,
-                        nullable: false,
-                        value: format!("{:?}", self),
-                    })
-                }
-            }
-            Field::Null => Ok(Some(DozerDuration::from_str("0").unwrap())),
-            _ => Err(TypeError::InvalidFieldValue {
-                field_type: FieldType::Duration,
-                nullable: false,
-                value: format!("{:?}", self),
-            }),
+            Field::Duration(d) => Some(*d),
+            Field::String(d) | Field::Text(d) => DozerDuration::from_str(d.as_str()).ok(),
+            Field::Null => Some(DozerDuration(
+                Duration::from_nanos(0),
+                TimeUnit::Nanoseconds,
+            )),
+            _ => None,
         }
     }
 
@@ -624,28 +598,11 @@
 
 impl Display for Field {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Field::UInt(v) => f.write_str(&format!("{v} (64-bit unsigned int)")),
-            Field::U128(v) => f.write_str(&format!("{v} (128-bit unsigned int)")),
-            Field::Int(v) => f.write_str(&format!("{v} (64-bit signed int)")),
-            Field::I128(v) => f.write_str(&format!("{v} (128-bit signed int)")),
-            Field::Float(v) => f.write_str(&format!("{v} (Float)")),
-            Field::Decimal(v) => f.write_str(&format!("{v} (Decimal)")),
-            Field::Boolean(v) => f.write_str(&format!("{v}")),
-            Field::String(v) => f.write_str(&v.to_string()),
-            Field::Text(v) => f.write_str(&v.to_string()),
-            Field::Binary(v) => f.write_str(&format!("{v:x?}")),
-            Field::Timestamp(v) => f.write_str(&format!("{v}")),
-            Field::Date(v) => f.write_str(&format!("{v}")),
-            Field::Json(v) => f.write_str(&format!("{v}")),
-            Field::Point(v) => f.write_str(&format!("{v} (Point)")),
-            Field::Duration(d) => f.write_str(&format!("{:?} {:?} (Duration)", d.0, d.1)),
-            Field::Null => f.write_str("NULL"),
-        }
+        f.write_str(self.to_string().as_str())
     }
 }
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord)]
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 /// All field types supported in Dozer.
 pub enum FieldType {
     /// Unsigned 64-bit integer.
diff --git a/dozer-types/src/types/mod.rs b/dozer-types/src/types/mod.rs
index 81b7878c29..63d27a8d08 100644
--- a/dozer-types/src/types/mod.rs
+++ b/dozer-types/src/types/mod.rs
@@ -240,7 +240,7 @@
         let v = self
             .values
             .iter()
-            .map(|f| Cell::new(&f.to_string().unwrap_or_default()))
+            .map(|f| Cell::new(&f.to_string()))
             .collect::<Vec<_>>();
 
         let mut table = Table::new();
@@ -351,12 +351,6 @@
     }
 }
 
-impl Display for DozerDuration {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        f.write_str(&format!("{:?} {:?}", self.0, self.1))
-    }
-}
-
 impl DozerDuration {
     pub fn to_bytes(&self) -> [u8; 17] {
         let mut result = [0_u8; 17];
diff --git a/dozer-types/src/types/tests.rs b/dozer-types/src/types/tests.rs
index efc3726b64..480e307afa 100644
--- a/dozer-types/src/types/tests.rs
+++ b/dozer-types/src/types/tests.rs
@@ -315,15 +315,13 @@ fn test_to_conversion() {
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
    assert!(field.to_null().is_none());
 
     let field = Field::Int(1);
@@ -333,15 +331,13 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
     assert!(field.to_null().is_none());
 
     let field = Field::U128(1);
@@ -351,15 +347,13 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
     assert!(field.to_null().is_none());
 
     let field = Field::I128(1);
@@ -369,15 +363,13 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
     assert!(field.to_null().is_none());
 
     let field = Field::Float(OrderedFloat::from(1.0));
@@ -387,15 +379,13 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Boolean(true);
@@ -405,15 +395,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::String("".to_string());
@@ -423,15 +411,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Text("".to_string());
@@ -441,15 +427,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Binary(vec![]);
@@ -459,15 +443,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_some());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Decimal(Decimal::from(1));
@@ -477,15 +459,13 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Timestamp(DateTime::from(Utc.timestamp_millis_opt(0).unwrap()));
@@ -495,15 +475,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_some());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_some());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Date(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap());
@@ -513,15 +491,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_some());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_some());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Json(JsonValue::Array(vec![]));
@@ -531,15 +507,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Point(DozerPoint::from((0.0, 0.0)));
@@ -549,15 +523,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_none());
-    assert!(field.to_text().is_none());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_some());
-    assert!(field.to_duration().is_err());
+    assert!(field.to_duration().is_none());
     assert!(field.to_null().is_none());
 
     let field = Field::Duration(DozerDuration(
@@ -570,15 +542,13 @@
     assert!(field.to_i128().is_none());
     assert!(field.to_float().is_none());
     assert!(field.to_boolean().is_none());
-    assert!(field.to_string().is_none());
-    assert!(field.to_text().is_none());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_none());
-    assert!(field.to_timestamp().unwrap().is_none());
-    assert!(field.to_date().unwrap().is_none());
+    assert!(field.to_timestamp().is_none());
+    assert!(field.to_date().is_none());
     assert!(field.to_json().is_none());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
     assert!(field.to_null().is_none());
 
     let field = Field::Null;
@@ -588,14 +558,12 @@
     assert!(field.to_i128().is_some());
     assert!(field.to_float().is_some());
     assert!(field.to_boolean().is_some());
-    assert!(field.to_string().is_some());
-    assert!(field.to_text().is_some());
     assert!(field.to_binary().is_none());
     assert!(field.to_decimal().is_some());
-    assert!(field.to_timestamp().unwrap().is_some());
-    assert!(field.to_date().unwrap().is_some());
+    assert!(field.to_timestamp().is_some());
+    assert!(field.to_date().is_some());
     assert!(field.to_json().is_some());
     assert!(field.to_point().is_none());
-    assert!(field.to_duration().is_ok());
+    assert!(field.to_duration().is_some());
     assert!(field.to_null().is_some());
 }