diff --git a/Cargo.toml b/Cargo.toml index 3c2ef149..ef9de97f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ kip_db = { version = "0.1.2-alpha.25.fix2" } lazy_static = { version = "1.4.0" } log = { version = "0.4.21", optional = true } ordered-float = { version = "4.2.0" } +paste = { version = "1.0.14" } petgraph = { version = "0.6.4" } pgwire = { version = "0.19.2", optional = true } rand = { version = "0.9.0-alpha.0" } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 2ba8524a..0466c862 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -145,6 +145,7 @@ impl<'a, T: Transaction> Binder<'a, T> { op: expression::BinaryOperator::Eq, left_expr, right_expr: Box::new(alias_expr), + evaluator: None, ty: LogicalType::Boolean, }) } @@ -272,6 +273,7 @@ impl<'a, T: Transaction> Binder<'a, T> { op, left_expr, right_expr, + evaluator: None, ty: LogicalType::Boolean, }) } @@ -342,10 +344,19 @@ impl<'a, T: Transaction> Binder<'a, T> { BinaryOperator::Plus | BinaryOperator::Minus | BinaryOperator::Multiply - | BinaryOperator::Divide | BinaryOperator::Modulo => { LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())? } + BinaryOperator::Divide => { + if let LogicalType::Decimal(precision, scale) = LogicalType::max_logical_type( + &left_expr.return_type(), + &right_expr.return_type(), + )? { + LogicalType::Decimal(precision, scale) + } else { + LogicalType::Double + } + } BinaryOperator::Gt | BinaryOperator::Lt | BinaryOperator::GtEq @@ -356,13 +367,14 @@ impl<'a, T: Transaction> Binder<'a, T> { | BinaryOperator::Or | BinaryOperator::Xor => LogicalType::Boolean, BinaryOperator::StringConcat => LogicalType::Varchar(None, CharLengthUnits::Characters), - _ => todo!(), + op => return Err(DatabaseError::UnsupportedStmt(format!("{}", op))), }; Ok(ScalarExpression::Binary { op: (op.clone()).into(), left_expr, right_expr, + evaluator: None, ty, }) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 35e3cb3f..d794bb03 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -485,6 +485,7 @@ impl<'a, T: Transaction> Binder<'a, T> { op: BinaryOperator::And, left_expr: Box::new(acc), right_expr: Box::new(expr), + evaluator: None, ty: LogicalType::Boolean, }); @@ -654,6 +655,7 @@ impl<'a, T: Transaction> Binder<'a, T> { op: BinaryOperator::And, left_expr: Box::new(acc), right_expr: Box::new(expr), + evaluator: None, ty: LogicalType::Boolean, }); // TODO: handle cross join if on_keys is empty @@ -723,6 +725,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + .. } => { match op { BinaryOperator::Eq => { @@ -746,6 +749,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + evaluator: None, }); } } @@ -757,6 +761,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + evaluator: None, }); } } @@ -772,6 +777,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + evaluator: None, }); } } @@ -800,6 +806,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + evaluator: None, }); } _ => { @@ -813,6 +820,7 @@ impl<'a, T: Transaction> Binder<'a, T> { right_expr, op, ty, + evaluator: None, }); } } diff --git a/src/db.rs b/src/db.rs index 982c57c9..6edcbc2f 100644 --- a/src/db.rs +++ b/src/db.rs @@ -186,7 +186,11 @@ impl Database { .batch( "Expression Remapper".to_string(), HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::ExpressionRemapper], + vec![ + NormalizationRuleImpl::ExpressionRemapper, + // TIPS: This rule is necessary + NormalizationRuleImpl::EvaluatorBind, + ], ) .implementations(vec![ // DQL @@ -263,6 +267,7 @@ mod test { use crate::expression::{BinaryOperator, UnaryOperator}; use crate::function; use crate::storage::{Storage, Transaction}; + use crate::types::evaluator::EvaluatorFactory; use crate::types::tuple::{create_table, Tuple}; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; @@ -304,7 +309,9 @@ mod test { } function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { - let value = DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus)?; + let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; + let value = plus_evaluator.0.binary_eval(&v1, &v2); + DataValue::unary_op(&value, &UnaryOperator::Minus) })); diff --git a/src/errors.rs b/src/errors.rs index fc06f38f..85e260ba 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -43,6 +43,8 @@ pub enum DatabaseError { EmptyPlan, #[error("sql statement is empty")] EmptyStatement, + #[error("evaluator not found")] + EvaluatorNotFound, #[error("from utf8: {0}")] FromUtf8Error( #[source] diff --git a/src/execution/volcano/dql/aggregate/avg.rs b/src/execution/volcano/dql/aggregate/avg.rs index ac31a15d..a33343f4 100644 --- a/src/execution/volcano/dql/aggregate/avg.rs +++ b/src/execution/volcano/dql/aggregate/avg.rs @@ -2,6 +2,7 @@ use crate::errors::DatabaseError; use crate::execution::volcano::dql::aggregate::sum::SumAccumulator; use crate::execution::volcano::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; +use crate::types::evaluator::EvaluatorFactory; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; use std::sync::Arc; @@ -12,11 +13,11 @@ pub struct AvgAccumulator { } impl AvgAccumulator { - pub fn new(ty: &LogicalType) -> Self { - Self { - inner: SumAccumulator::new(ty), + pub fn new(ty: &LogicalType) -> Result { + Ok(Self { + inner: SumAccumulator::new(ty)?, count: 0, - } + }) } } @@ -31,21 +32,23 @@ impl Accumulator for AvgAccumulator { } fn evaluate(&self) -> Result { - let value = self.inner.evaluate()?; + let mut value = self.inner.evaluate()?; + let value_ty = value.logical_type(); - let quantity = if value.logical_type().is_signed_numeric() { + if self.count == 0 { + return Ok(Arc::new(DataValue::init(&value_ty))); + } + let quantity = if value_ty.is_signed_numeric() { DataValue::Int64(Some(self.count as i64)) } else { DataValue::UInt32(Some(self.count as u32)) }; - if self.count == 0 { - return Ok(Arc::new(DataValue::init(&value.logical_type()))); - } + let quantity_ty = quantity.logical_type(); - Ok(Arc::new(DataValue::binary_op( - &value, - &quantity, - &BinaryOperator::Divide, - )?)) + if value_ty != quantity_ty { + value = Arc::new(DataValue::clone(&value).cast(&quantity_ty)?) + } + let evaluator = EvaluatorFactory::binary_create(quantity_ty, BinaryOperator::Divide)?; + Ok(Arc::new(evaluator.0.binary_eval(&value, &quantity))) } } diff --git a/src/execution/volcano/dql/aggregate/hash_agg.rs b/src/execution/volcano/dql/aggregate/hash_agg.rs index 8b4a5550..40c84471 100644 --- a/src/execution/volcano/dql/aggregate/hash_agg.rs +++ b/src/execution/volcano/dql/aggregate/hash_agg.rs @@ -102,7 +102,7 @@ impl HashAggStatus { for (acc, value) in self .group_hash_accs .entry(group_keys) - .or_insert_with(|| create_accumulators(&self.agg_calls)) + .or_insert_with(|| create_accumulators(&self.agg_calls).unwrap()) .iter_mut() .zip_eq(values.iter()) { diff --git a/src/execution/volcano/dql/aggregate/min_max.rs b/src/execution/volcano/dql/aggregate/min_max.rs index 442d9804..14307449 100644 --- a/src/execution/volcano/dql/aggregate/min_max.rs +++ b/src/execution/volcano/dql/aggregate/min_max.rs @@ -1,6 +1,7 @@ use crate::errors::DatabaseError; use crate::execution::volcano::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; +use crate::types::evaluator::EvaluatorFactory; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; use std::sync::Arc; @@ -31,8 +32,9 @@ impl Accumulator for MinMaxAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), DatabaseError> { if !value.is_null() { if let Some(inner_value) = &self.inner { + let evaluator = EvaluatorFactory::binary_create(value.logical_type(), self.op)?; if let DataValue::Boolean(Some(result)) = - DataValue::binary_op(inner_value, value, &self.op)? + evaluator.0.binary_eval(inner_value, value) { result } else { diff --git a/src/execution/volcano/dql/aggregate/mod.rs b/src/execution/volcano/dql/aggregate/mod.rs index 7ab8f00d..46b27ce6 100644 --- a/src/execution/volcano/dql/aggregate/mod.rs +++ b/src/execution/volcano/dql/aggregate/mod.rs @@ -15,6 +15,7 @@ use crate::execution::volcano::dql::aggregate::sum::{DistinctSumAccumulator, Sum use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; use crate::types::value::ValueRef; +use itertools::Itertools; /// Tips: Idea for sqlrs /// An accumulator represents a stateful object that lives throughout the evaluation of multiple @@ -27,20 +28,20 @@ pub trait Accumulator: Send + Sync { fn evaluate(&self) -> Result; } -fn create_accumulator(expr: &ScalarExpression) -> Box { +fn create_accumulator(expr: &ScalarExpression) -> Result, DatabaseError> { if let ScalarExpression::AggCall { kind, ty, distinct, .. } = expr { - match (kind, distinct) { + Ok(match (kind, distinct) { (AggKind::Count, false) => Box::new(CountAccumulator::new()), (AggKind::Count, true) => Box::new(DistinctCountAccumulator::new()), - (AggKind::Sum, false) => Box::new(SumAccumulator::new(ty)), - (AggKind::Sum, true) => Box::new(DistinctSumAccumulator::new(ty)), + (AggKind::Sum, false) => Box::new(SumAccumulator::new(ty)?), + (AggKind::Sum, true) => Box::new(DistinctSumAccumulator::new(ty)?), (AggKind::Min, _) => Box::new(MinMaxAccumulator::new(ty, false)), (AggKind::Max, _) => Box::new(MinMaxAccumulator::new(ty, true)), - (AggKind::Avg, _) => Box::new(AvgAccumulator::new(ty)), - } + (AggKind::Avg, _) => Box::new(AvgAccumulator::new(ty)?), + }) } else { unreachable!( "create_accumulator called with non-aggregate expression {}", @@ -49,6 +50,8 @@ fn create_accumulator(expr: &ScalarExpression) -> Box { } } -pub(crate) fn create_accumulators(exprs: &[ScalarExpression]) -> Vec> { - exprs.iter().map(create_accumulator).collect() +pub(crate) fn create_accumulators( + exprs: &[ScalarExpression], +) -> Result>, DatabaseError> { + exprs.iter().map(create_accumulator).try_collect() } diff --git a/src/execution/volcano/dql/aggregate/simple_agg.rs b/src/execution/volcano/dql/aggregate/simple_agg.rs index 08092963..fb2db157 100644 --- a/src/execution/volcano/dql/aggregate/simple_agg.rs +++ b/src/execution/volcano/dql/aggregate/simple_agg.rs @@ -36,7 +36,7 @@ impl SimpleAggExecutor { agg_calls, mut input, } = self; - let mut accs = create_accumulators(&agg_calls); + let mut accs = create_accumulators(&agg_calls)?; let schema = input.output_schema().clone(); #[for_await] diff --git a/src/execution/volcano/dql/aggregate/sum.rs b/src/execution/volcano/dql/aggregate/sum.rs index ab1bfd2b..d5d2ab53 100644 --- a/src/execution/volcano/dql/aggregate/sum.rs +++ b/src/execution/volcano/dql/aggregate/sum.rs @@ -1,6 +1,7 @@ use crate::errors::DatabaseError; use crate::execution::volcano::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; +use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; use ahash::RandomState; @@ -9,22 +10,24 @@ use std::sync::Arc; pub struct SumAccumulator { result: DataValue, + evaluator: BinaryEvaluatorBox, } impl SumAccumulator { - pub fn new(ty: &LogicalType) -> Self { + pub fn new(ty: &LogicalType) -> Result { assert!(ty.is_numeric()); - Self { + Ok(Self { result: DataValue::init(ty), - } + evaluator: EvaluatorFactory::binary_create(*ty, BinaryOperator::Plus)?, + }) } } impl Accumulator for SumAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), DatabaseError> { if !value.is_null() { - self.result = DataValue::binary_op(&self.result, value, &BinaryOperator::Plus)?; + self.result = self.evaluator.0.binary_eval(&self.result, value); } Ok(()) @@ -41,11 +44,11 @@ pub struct DistinctSumAccumulator { } impl DistinctSumAccumulator { - pub fn new(ty: &LogicalType) -> Self { - Self { + pub fn new(ty: &LogicalType) -> Result { + Ok(Self { distinct_values: HashSet::default(), - inner: SumAccumulator::new(ty), - } + inner: SumAccumulator::new(ty)?, + }) } } diff --git a/src/execution/volcano/dql/join/nested_loop_join.rs b/src/execution/volcano/dql/join/nested_loop_join.rs index 21dce92a..fd8e7627 100644 --- a/src/execution/volcano/dql/join/nested_loop_join.rs +++ b/src/execution/volcano/dql/join/nested_loop_join.rs @@ -330,6 +330,8 @@ mod test { use crate::planner::operator::Operator; use crate::storage::kipdb::KipStorage; use crate::storage::Storage; + use crate::types::evaluator::int32::Int32GtBinaryEvaluator; + use crate::types::evaluator::BinaryEvaluatorBox; use crate::types::value::DataValue; use crate::types::LogicalType; use std::collections::HashSet; @@ -441,7 +443,8 @@ mod test { true, desc.clone(), )))), - ty: LogicalType::Integer, + evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))), + ty: LogicalType::Boolean, }; (on_keys, values_t1, values_t2, filter) diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 6eea7aee..2499a214 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -2,6 +2,7 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::function::ScalarFunction; use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; +use crate::types::evaluator::EvaluatorFactory; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, Utf8Type, ValueRef}; use crate::types::LogicalType; @@ -80,13 +81,19 @@ impl ScalarExpression { ScalarExpression::Binary { left_expr, right_expr, - op, + evaluator, .. } => { let left = left_expr.eval(tuple, schema)?; let right = right_expr.eval(tuple, schema)?; - Ok(Arc::new(DataValue::binary_op(&left, &right, op)?)) + Ok(Arc::new( + evaluator + .as_ref() + .ok_or(DatabaseError::EvaluatorNotFound)? + .0 + .binary_eval(&left, &right), + )) } ScalarExpression::IsNull { expr, negated } => { let mut is_null = expr.eval(tuple, schema)?.is_null(); @@ -292,10 +299,17 @@ impl ScalarExpression { operand_value = Some(expr.eval(tuple, schema)?); } for (when_expr, result_expr) in expr_pairs { - let when_value = when_expr.eval(tuple, schema)?; + let mut when_value = when_expr.eval(tuple, schema)?; let is_true = if let Some(operand_value) = &operand_value { - operand_value - .binary_op(&when_value, &BinaryOperator::Eq)? + let ty = operand_value.logical_type(); + let evaluator = EvaluatorFactory::binary_create(ty, BinaryOperator::Eq)?; + + if when_value.logical_type() != ty { + when_value = Arc::new(DataValue::clone(&when_value).cast(&ty)?); + } + evaluator + .0 + .binary_eval(operand_value, &when_value) .is_true()? } else { when_value.is_true()? diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 3a7dc0c9..1f3d3ec2 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -11,7 +11,9 @@ use sqlparser::ast::{ use self::agg::AggKind; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; +use crate::errors::DatabaseError; use crate::expression::function::ScalarFunction; +use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; use crate::types::value::ValueRef; use crate::types::LogicalType; @@ -57,6 +59,7 @@ pub enum ScalarExpression { op: BinaryOperator, left_expr: Box, right_expr: Box, + evaluator: Option, ty: LogicalType, }, AggCall { @@ -275,6 +278,141 @@ impl ScalarExpression { } } + pub fn bind_evaluator(&mut self) -> Result<(), DatabaseError> { + match self { + ScalarExpression::Binary { + left_expr, + right_expr, + op, + evaluator, + .. + } => { + left_expr.bind_evaluator()?; + right_expr.bind_evaluator()?; + + let ty = LogicalType::max_logical_type( + &left_expr.return_type(), + &right_expr.return_type(), + )?; + let fn_cast = |expr: &mut ScalarExpression, ty: LogicalType| { + if expr.return_type() != ty { + *expr = ScalarExpression::TypeCast { + expr: Box::new(mem::replace(expr, ScalarExpression::Empty)), + ty, + } + } + }; + fn_cast(left_expr, ty); + fn_cast(right_expr, ty); + + *evaluator = Some(EvaluatorFactory::binary_create(ty, *op)?); + } + ScalarExpression::Alias { expr, .. } => { + expr.bind_evaluator()?; + } + ScalarExpression::TypeCast { expr, .. } => { + expr.bind_evaluator()?; + } + ScalarExpression::IsNull { expr, .. } => { + expr.bind_evaluator()?; + } + ScalarExpression::Unary { expr, .. } => { + expr.bind_evaluator()?; + } + ScalarExpression::AggCall { args, .. } + | ScalarExpression::Coalesce { exprs: args, .. } + | ScalarExpression::Tuple(args) => { + for arg in args { + arg.bind_evaluator()?; + } + } + ScalarExpression::In { expr, args, .. } => { + expr.bind_evaluator()?; + for arg in args { + arg.bind_evaluator()?; + } + } + ScalarExpression::Between { + expr, + left_expr, + right_expr, + .. + } => { + expr.bind_evaluator()?; + left_expr.bind_evaluator()?; + right_expr.bind_evaluator()?; + } + ScalarExpression::SubString { + expr, + for_expr, + from_expr, + } => { + expr.bind_evaluator()?; + if let Some(expr) = for_expr { + expr.bind_evaluator()?; + } + if let Some(expr) = from_expr { + expr.bind_evaluator()?; + } + } + ScalarExpression::Position { expr, in_expr } => { + expr.bind_evaluator()?; + in_expr.bind_evaluator()?; + } + ScalarExpression::Empty => unreachable!(), + ScalarExpression::Constant(_) + | ScalarExpression::ColumnRef(_) + | ScalarExpression::Reference { .. } => (), + ScalarExpression::Function(function) => { + for expr in function.args.iter_mut() { + expr.bind_evaluator()?; + } + } + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + condition.bind_evaluator()?; + left_expr.bind_evaluator()?; + right_expr.bind_evaluator()?; + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + left_expr.bind_evaluator()?; + right_expr.bind_evaluator()?; + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + if let Some(expr) = operand_expr { + expr.bind_evaluator()?; + } + for (expr_1, expr_2) in expr_pairs { + expr_1.bind_evaluator()?; + expr_2.bind_evaluator()?; + } + if let Some(expr) = else_expr { + expr.bind_evaluator()?; + } + } + } + + Ok(()) + } + pub fn has_count_star(&self) -> bool { match self { ScalarExpression::Alias { expr, .. } => expr.has_count_star(), diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 54e73237..755b7d28 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -2,6 +2,7 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::function::ScalarFunction; use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; +use crate::types::evaluator::EvaluatorFactory; use crate::types::value::{DataValue, ValueRef}; use crate::types::{ColumnId, LogicalType}; use std::mem; @@ -157,12 +158,20 @@ impl ScalarExpression { left_expr, right_expr, op, + ty, .. } => { - let left = left_expr.unpack_val()?; - let right = right_expr.unpack_val()?; + let mut left = left_expr.unpack_val()?; + let mut right = right_expr.unpack_val()?; + let evaluator = EvaluatorFactory::binary_create(*ty, *op).ok()?; - DataValue::binary_op(&left, &right, op).ok().map(Arc::new) + if left.logical_type() != *ty { + left = Arc::new(DataValue::clone(&left).cast(ty).ok()?); + } + if right.logical_type() != *ty { + right = Arc::new(DataValue::clone(&right).cast(ty).ok()?); + } + Some(Arc::new(evaluator.0.binary_eval(&left, &right))) } _ => None, } @@ -210,15 +219,27 @@ impl ScalarExpression { op, .. } => { + let ty = LogicalType::max_logical_type( + &left_expr.return_type(), + &right_expr.return_type(), + )?; left_expr.constant_calculation()?; right_expr.constant_calculation()?; if let ( ScalarExpression::Constant(left_val), ScalarExpression::Constant(right_val), - ) = (left_expr.as_ref(), right_expr.as_ref()) + ) = (left_expr.as_mut(), right_expr.as_mut()) { - let value = DataValue::binary_op(left_val, right_val, op)?; + let evaluator = EvaluatorFactory::binary_create(ty, *op)?; + + if left_val.logical_type() != ty { + *left_val = Arc::new(DataValue::clone(left_val).cast(&ty)?); + } + if right_val.logical_type() != ty { + *right_val = Arc::new(DataValue::clone(right_val).cast(&ty)?); + } + let value = evaluator.0.binary_eval(left_val, right_val); let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value))); } } @@ -326,6 +347,7 @@ impl ScalarExpression { right_expr, op, ty, + .. } => { Self::fix_expr(replaces, left_expr, right_expr, op)?; @@ -399,7 +421,7 @@ impl ScalarExpression { ); } } - ScalarExpression::Unary { expr, op, ty } => { + ScalarExpression::Unary { expr, op, ty, .. } => { if let Some(val) = expr.unpack_val() { let new_expr = ScalarExpression::Constant(Arc::new(DataValue::unary_op(&val, op)?)); @@ -430,6 +452,7 @@ impl ScalarExpression { op: op_1, left_expr: expr.clone(), right_expr: Box::new(args.remove(0)), + evaluator: None, ty: LogicalType::Boolean, }; @@ -440,9 +463,11 @@ impl ScalarExpression { op: op_1, left_expr: expr.clone(), right_expr: Box::new(arg), + evaluator: None, ty: LogicalType::Boolean, }), right_expr: Box::new(new_expr), + evaluator: None, ty: LogicalType::Boolean, } } @@ -470,14 +495,17 @@ impl ScalarExpression { op: left_op, left_expr: expr.clone(), right_expr: mem::replace(left_expr, Box::new(ScalarExpression::Empty)), + evaluator: None, ty: LogicalType::Boolean, }), right_expr: Box::new(ScalarExpression::Binary { op: right_op, left_expr: mem::replace(expr, Box::new(ScalarExpression::Empty)), right_expr: mem::replace(right_expr, Box::new(ScalarExpression::Empty)), + evaluator: None, ty: LogicalType::Boolean, }), + evaluator: None, ty: LogicalType::Boolean, }; @@ -616,6 +644,7 @@ impl ScalarExpression { op: fixed_op, left_expr: fixed_left_expr, right_expr: fixed_right_expr, + evaluator: None, ty: fix_ty, }), ); diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index 297f2905..cc1acb4c 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -1,151 +1,10 @@ use crate::errors::DatabaseError; -use crate::expression::{BinaryOperator, UnaryOperator}; -use crate::types::value::{DataValue, Utf8Type, ValueRef}; +use crate::expression::UnaryOperator; +use crate::types::value::DataValue; use crate::types::LogicalType; -use regex::Regex; -use sqlparser::ast::CharLengthUnits; -use std::cmp::Ordering; - -fn unpack_bool(value: DataValue) -> Option { - match value { - DataValue::Boolean(inner) => inner, - _ => None, - } -} - -fn unpack_utf8(value: DataValue) -> Option { - match value { - DataValue::Utf8 { value: inner, .. } => inner, - _ => None, - } -} - -fn unpack_tuple(value: DataValue) -> Option> { - match value { - DataValue::Tuple(inner) => inner, - _ => None, - } -} - -macro_rules! numeric_binary_compute { - ($compute_type:path, $left:expr, $right:expr, $op:expr, $unified_type:expr) => { - match $op { - BinaryOperator::Plus => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 + v2) - } else { - None - }; - - $compute_type(value) - } - BinaryOperator::Minus => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 - v2) - } else { - None - }; - - $compute_type(value) - } - BinaryOperator::Multiply => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 * v2) - } else { - None - }; - - $compute_type(value) - } - BinaryOperator::Divide => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 as f64 / v2 as f64) - } else { - None - }; - - DataValue::Float64(value) - } - - BinaryOperator::Gt => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 > v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Lt => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 < v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::GtEq => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 >= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::LtEq => { - let value = if let ($compute_type(Some(v1)), $compute_type(Some(v2))) = - ($left.cast($unified_type)?, $right.cast($unified_type)?) - { - Some(v1 <= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Eq => { - let value = match ($left.cast($unified_type)?, $right.cast($unified_type)?) { - ($compute_type(Some(v1)), $compute_type(Some(v2))) => Some(v1 == v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::NotEq => { - let value = match ($left.cast($unified_type)?, $right.cast($unified_type)?) { - ($compute_type(Some(v1)), $compute_type(Some(v2))) => Some(v1 != v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - _ => { - return Err(DatabaseError::UnsupportedBinaryOperator( - *$unified_type, - *$op, - )) - } - } - }; -} impl DataValue { + // FIXME: like BinaryEvaluator pub fn unary_op(&self, op: &UnaryOperator) -> Result { let mut value_type = self.logical_type(); let mut value = self.clone(); @@ -186,1548 +45,4 @@ impl DataValue { Err(DatabaseError::InvalidType) } } - /// Tips: - /// - Null values operate as null values - pub fn binary_op( - &self, - right: &DataValue, - op: &BinaryOperator, - ) -> Result { - if let BinaryOperator::Like(escape_char) | BinaryOperator::NotLike(escape_char) = op { - let value_option = unpack_utf8( - self.clone() - .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?, - ); - let pattern_option = unpack_utf8( - right - .clone() - .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?, - ); - - let mut is_match = if let (Some(value), Some(pattern)) = (value_option, pattern_option) - { - let mut regex_pattern = String::new(); - let mut chars = pattern.chars().peekable(); - while let Some(c) = chars.next() { - if matches!(escape_char.map(|escape_c| escape_c == c), Some(true)) { - if let Some(next_char) = chars.next() { - regex_pattern.push(next_char); - } - } else if c == '%' { - regex_pattern.push_str(".*"); - } else if c == '_' { - regex_pattern.push('.'); - } else { - regex_pattern.push(c); - } - } - - Regex::new(®ex_pattern).unwrap().is_match(&value) - } else { - return Ok(DataValue::Boolean(None)); - }; - if matches!(op, BinaryOperator::NotLike(_)) { - is_match = !is_match; - } - return Ok(DataValue::Boolean(Some(is_match))); - } - let unified_type = - LogicalType::max_logical_type(&self.logical_type(), &right.logical_type())?; - - let value = match &unified_type { - LogicalType::Tinyint => { - numeric_binary_compute!( - DataValue::Int8, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Smallint => { - numeric_binary_compute!( - DataValue::Int16, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Integer => { - numeric_binary_compute!( - DataValue::Int32, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Bigint => { - numeric_binary_compute!( - DataValue::Int64, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::UTinyint => { - numeric_binary_compute!( - DataValue::UInt8, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::USmallint => { - numeric_binary_compute!( - DataValue::UInt16, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::UInteger => { - numeric_binary_compute!( - DataValue::UInt32, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::UBigint => { - numeric_binary_compute!( - DataValue::UInt64, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Float => { - numeric_binary_compute!( - DataValue::Float32, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Double => { - numeric_binary_compute!( - DataValue::Float64, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Date => { - numeric_binary_compute!( - DataValue::Date32, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::DateTime => { - numeric_binary_compute!( - DataValue::Date64, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Time => { - numeric_binary_compute!( - DataValue::Time, - self.clone(), - right.clone(), - op, - &unified_type - ) - } - LogicalType::Decimal(_, _) => { - let left_value = self.clone().cast(&unified_type)?; - let right_value = right.clone().cast(&unified_type)?; - - match op { - BinaryOperator::Plus => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 + v2) - } else { - None - }; - - DataValue::Decimal(value) - } - BinaryOperator::Minus => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 - v2) - } else { - None - }; - - DataValue::Decimal(value) - } - BinaryOperator::Multiply => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 * v2) - } else { - None - }; - - DataValue::Decimal(value) - } - BinaryOperator::Divide => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 / v2) - } else { - None - }; - - DataValue::Decimal(value) - } - - BinaryOperator::Gt => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 > v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Lt => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 < v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::GtEq => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 >= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::LtEq => { - let value = - if let (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) = - (left_value, right_value) - { - Some(v1 <= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Eq => { - let value = match (left_value, right_value) { - (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) => { - Some(v1 == v2) - } - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::NotEq => { - let value = match (left_value, right_value) { - (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) => { - Some(v1 != v2) - } - (_, _) => None, - }; - - DataValue::Boolean(value) - } - _ => return Err(DatabaseError::UnsupportedBinaryOperator(unified_type, *op)), - } - } - LogicalType::Boolean => { - let left_value = unpack_bool(self.clone().cast(&unified_type)?); - let right_value = unpack_bool(right.clone().cast(&unified_type)?); - - match op { - BinaryOperator::And => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 && v2), - (Some(false), _) | (_, Some(false)) => Some(false), - _ => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::Or => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 || v2), - (Some(true), _) | (_, Some(true)) => Some(true), - _ => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::Eq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 == v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::NotEq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 != v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - _ => return Err(DatabaseError::UnsupportedBinaryOperator(unified_type, *op)), - } - } - LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => { - let left_value = unpack_utf8(self.clone().cast(&unified_type)?); - let right_value = unpack_utf8(right.clone().cast(&unified_type)?); - - match op { - BinaryOperator::Gt => { - let value = if let (Some(v1), Some(v2)) = (left_value, right_value) { - Some(v1 > v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Lt => { - let value = if let (Some(v1), Some(v2)) = (left_value, right_value) { - Some(v1 < v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::GtEq => { - let value = if let (Some(v1), Some(v2)) = (left_value, right_value) { - Some(v1 >= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::LtEq => { - let value = if let (Some(v1), Some(v2)) = (left_value, right_value) { - Some(v1 <= v2) - } else { - None - }; - - DataValue::Boolean(value) - } - BinaryOperator::Eq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 == v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::NotEq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 != v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::StringConcat => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 + &v2), - _ => None, - }; - - DataValue::Utf8 { - value, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - } - } - _ => return Err(DatabaseError::UnsupportedBinaryOperator(unified_type, *op)), - } - } - LogicalType::SqlNull => return Ok(DataValue::Null), - LogicalType::Invalid => return Err(DatabaseError::InvalidType), - LogicalType::Tuple => { - let left_value = unpack_tuple(self.clone().cast(&unified_type)?); - let right_value = unpack_tuple(right.clone().cast(&unified_type)?); - - match op { - BinaryOperator::Eq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 == v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::NotEq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Some(v1 != v2), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - BinaryOperator::Gt - | BinaryOperator::GtEq - | BinaryOperator::Lt - | BinaryOperator::LtEq => { - let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => Self::tuple_cmp(v1, v2).map(|order| match op { - BinaryOperator::Gt => order.is_gt(), - BinaryOperator::Lt => order.is_lt(), - BinaryOperator::GtEq => order.is_ge(), - BinaryOperator::LtEq => order.is_le(), - _ => unreachable!(), - }), - (_, _) => None, - }; - - DataValue::Boolean(value) - } - _ => return Err(DatabaseError::UnsupportedBinaryOperator(unified_type, *op)), - } - } - }; - - Ok(value) - } - - fn tuple_cmp(v1: Vec, v2: Vec) -> Option { - let mut order = Ordering::Equal; - let mut v1_iter = v1.iter(); - let mut v2_iter = v2.iter(); - - while order == Ordering::Equal { - order = match (v1_iter.next(), v2_iter.next()) { - (Some(v1), Some(v2)) => v1.partial_cmp(v2)?, - (Some(_), None) => Ordering::Greater, - (None, Some(_)) => Ordering::Less, - (None, None) => break, - } - } - Some(order) - } -} - -#[cfg(test)] -mod test { - use crate::errors::DatabaseError; - use crate::expression::BinaryOperator; - use crate::types::value::{DataValue, Utf8Type}; - use sqlparser::ast::CharLengthUnits; - - #[test] - fn test_binary_op_arithmetic_plus() -> Result<(), DatabaseError> { - let plus_i32_1 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(None), - &BinaryOperator::Plus, - )?; - let plus_i32_2 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(None), - &BinaryOperator::Plus, - )?; - let plus_i32_3 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::Plus, - )?; - let plus_i32_4 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Plus, - )?; - - assert_eq!(plus_i32_1, plus_i32_2); - assert_eq!(plus_i32_2, plus_i32_3); - assert_eq!(plus_i32_4, DataValue::Int32(Some(2))); - - let plus_i64_1 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(None), - &BinaryOperator::Plus, - )?; - let plus_i64_2 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(None), - &BinaryOperator::Plus, - )?; - let plus_i64_3 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::Plus, - )?; - let plus_i64_4 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::Plus, - )?; - - assert_eq!(plus_i64_1, plus_i64_2); - assert_eq!(plus_i64_2, plus_i64_3); - assert_eq!(plus_i64_4, DataValue::Int64(Some(2))); - - let plus_f64_1 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(None), - &BinaryOperator::Plus, - )?; - let plus_f64_2 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(None), - &BinaryOperator::Plus, - )?; - let plus_f64_3 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Plus, - )?; - let plus_f64_4 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Plus, - )?; - - assert_eq!(plus_f64_1, plus_f64_2); - assert_eq!(plus_f64_2, plus_f64_3); - assert_eq!(plus_f64_4, DataValue::Float64(Some(2.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_minus() -> Result<(), DatabaseError> { - let minus_i32_1 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(None), - &BinaryOperator::Minus, - )?; - let minus_i32_2 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(None), - &BinaryOperator::Minus, - )?; - let minus_i32_3 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::Minus, - )?; - let minus_i32_4 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Minus, - )?; - - assert_eq!(minus_i32_1, minus_i32_2); - assert_eq!(minus_i32_2, minus_i32_3); - assert_eq!(minus_i32_4, DataValue::Int32(Some(0))); - - let minus_i64_1 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(None), - &BinaryOperator::Minus, - )?; - let minus_i64_2 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(None), - &BinaryOperator::Minus, - )?; - let minus_i64_3 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::Minus, - )?; - let minus_i64_4 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::Minus, - )?; - - assert_eq!(minus_i64_1, minus_i64_2); - assert_eq!(minus_i64_2, minus_i64_3); - assert_eq!(minus_i64_4, DataValue::Int64(Some(0))); - - let minus_f64_1 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(None), - &BinaryOperator::Minus, - )?; - let minus_f64_2 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(None), - &BinaryOperator::Minus, - )?; - let minus_f64_3 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Minus, - )?; - let minus_f64_4 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Minus, - )?; - - assert_eq!(minus_f64_1, minus_f64_2); - assert_eq!(minus_f64_2, minus_f64_3); - assert_eq!(minus_f64_4, DataValue::Float64(Some(0.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_multiply() -> Result<(), DatabaseError> { - let multiply_i32_1 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(None), - &BinaryOperator::Multiply, - )?; - let multiply_i32_2 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(None), - &BinaryOperator::Multiply, - )?; - let multiply_i32_3 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::Multiply, - )?; - let multiply_i32_4 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Multiply, - )?; - - assert_eq!(multiply_i32_1, multiply_i32_2); - assert_eq!(multiply_i32_2, multiply_i32_3); - assert_eq!(multiply_i32_4, DataValue::Int32(Some(1))); - - let multiply_i64_1 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(None), - &BinaryOperator::Multiply, - )?; - let multiply_i64_2 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(None), - &BinaryOperator::Multiply, - )?; - let multiply_i64_3 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::Multiply, - )?; - let multiply_i64_4 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::Multiply, - )?; - - assert_eq!(multiply_i64_1, multiply_i64_2); - assert_eq!(multiply_i64_2, multiply_i64_3); - assert_eq!(multiply_i64_4, DataValue::Int64(Some(1))); - - let multiply_f64_1 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(None), - &BinaryOperator::Multiply, - )?; - let multiply_f64_2 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(None), - &BinaryOperator::Multiply, - )?; - let multiply_f64_3 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Multiply, - )?; - let multiply_f64_4 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Multiply, - )?; - - assert_eq!(multiply_f64_1, multiply_f64_2); - assert_eq!(multiply_f64_2, multiply_f64_3); - assert_eq!(multiply_f64_4, DataValue::Float64(Some(1.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_arithmetic_divide() -> Result<(), DatabaseError> { - let divide_i32_1 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(None), - &BinaryOperator::Divide, - )?; - let divide_i32_2 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(None), - &BinaryOperator::Divide, - )?; - let divide_i32_3 = DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::Divide, - )?; - let divide_i32_4 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Divide, - )?; - - assert_eq!(divide_i32_1, divide_i32_2); - assert_eq!(divide_i32_2, divide_i32_3); - assert_eq!(divide_i32_4, DataValue::Float64(Some(1.0))); - - let divide_i64_1 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(None), - &BinaryOperator::Divide, - )?; - let divide_i64_2 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(None), - &BinaryOperator::Divide, - )?; - let divide_i64_3 = DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::Divide, - )?; - let divide_i64_4 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::Divide, - )?; - - assert_eq!(divide_i64_1, divide_i64_2); - assert_eq!(divide_i64_2, divide_i64_3); - assert_eq!(divide_i64_4, DataValue::Float64(Some(1.0))); - - let divide_f64_1 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(None), - &BinaryOperator::Divide, - )?; - let divide_f64_2 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(None), - &BinaryOperator::Divide, - )?; - let divide_f64_3 = DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Divide, - )?; - let divide_f64_4 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Divide, - )?; - - assert_eq!(divide_f64_1, divide_f64_2); - assert_eq!(divide_f64_2, divide_f64_3); - assert_eq!(divide_f64_4, DataValue::Float64(Some(1.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_cast() -> Result<(), DatabaseError> { - let i32_cast_1 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int8(Some(1)), - &BinaryOperator::Plus, - )?; - let i32_cast_2 = DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int16(Some(1)), - &BinaryOperator::Plus, - )?; - - assert_eq!(i32_cast_1, i32_cast_2); - - let i64_cast_1 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int8(Some(1)), - &BinaryOperator::Plus, - )?; - let i64_cast_2 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int16(Some(1)), - &BinaryOperator::Plus, - )?; - let i64_cast_3 = DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Plus, - )?; - - assert_eq!(i64_cast_1, i64_cast_2); - assert_eq!(i64_cast_2, i64_cast_3); - - let f64_cast_1 = DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::Plus, - )?; - assert_eq!(f64_cast_1, DataValue::Float64(Some(2.0))); - - Ok(()) - } - - #[test] - fn test_binary_op_i32_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(Some(1)), - &DataValue::Int32(Some(1)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(Some(true)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(None) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(Some(1)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int32(None), - &DataValue::Int32(None), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_i64_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(Some(1)), - &DataValue::Int64(Some(1)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(Some(true)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(None) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(Some(1)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Int64(None), - &DataValue::Int64(None), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_f64_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(0.0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(0.0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(Some(1.0)), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(Some(true)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(0.0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(0.0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(None) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(Some(1.0)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float64(None), - &DataValue::Float64(None), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_f32_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(0.0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(0.0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(Some(1.0)), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(Some(true)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(0.0)), - &BinaryOperator::Gt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(0.0)), - &BinaryOperator::Lt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::GtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::LtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::NotEq - )?, - DataValue::Boolean(None) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(Some(1.0)), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Float32(None), - &DataValue::Float32(None), - &BinaryOperator::Eq - )?, - DataValue::Boolean(None) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_bool_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(true)), - &DataValue::Boolean(Some(true)), - &BinaryOperator::And - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(false)), - &DataValue::Boolean(Some(true)), - &BinaryOperator::And - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(false)), - &DataValue::Boolean(Some(false)), - &BinaryOperator::And - )?, - DataValue::Boolean(Some(false)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(None), - &DataValue::Boolean(Some(true)), - &BinaryOperator::And - )?, - DataValue::Boolean(None) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(true)), - &DataValue::Boolean(Some(true)), - &BinaryOperator::Or - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(false)), - &DataValue::Boolean(Some(true)), - &BinaryOperator::Or - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(Some(false)), - &DataValue::Boolean(Some(false)), - &BinaryOperator::Or - )?, - DataValue::Boolean(Some(false)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Boolean(None), - &DataValue::Boolean(Some(true)), - &BinaryOperator::Or - )?, - DataValue::Boolean(Some(true)) - ); - - Ok(()) - } - - #[test] - fn test_binary_op_utf8_compare() -> Result<(), DatabaseError> { - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("b".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::Gt - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("b".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::Lt - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::GtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::LtEq - )?, - DataValue::Boolean(Some(true)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::NotEq - )?, - DataValue::Boolean(Some(false)) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::Eq - )?, - DataValue::Boolean(Some(true)) - ); - - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: None, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::Gt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: None, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::Lt - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: None, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::GtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: None, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::LtEq - )?, - DataValue::Boolean(None) - ); - assert_eq!( - DataValue::binary_op( - &DataValue::Utf8 { - value: None, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &DataValue::Utf8 { - value: Some("a".to_string()), - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - }, - &BinaryOperator::NotEq - )?, - DataValue::Boolean(None) - ); - - Ok(()) - } } diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs index 8042d41d..8d1b8729 100644 --- a/src/marcos/mod.rs +++ b/src/marcos/mod.rs @@ -132,6 +132,7 @@ mod test { use crate::expression::function::{FuncMonotonicity, FunctionSummary, ScalarFunctionImpl}; use crate::expression::BinaryOperator; use crate::expression::ScalarExpression; + use crate::types::evaluator::EvaluatorFactory; use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::{DataValue, Utf8Type, ValueRef}; use crate::types::LogicalType; @@ -203,7 +204,9 @@ mod test { } function!(MyFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { - DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus) + let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; + + Ok(plus_evaluator.0.binary_eval(&v1, &v2)) })); #[test] diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index 6f1b9954..f3adb672 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -3,6 +3,7 @@ use crate::execution::volcano::dql::sort::radix_sort; use crate::expression::range_detacher::Range; use crate::expression::BinaryOperator; use crate::optimizer::core::cm_sketch::CountMinSketch; +use crate::types::evaluator::EvaluatorFactory; use crate::types::index::{IndexId, IndexMeta}; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; @@ -159,15 +160,16 @@ fn is_under( is_min: bool, ) -> Result { let _is_under = |value: &ValueRef, target: &ValueRef, is_min: bool| { - let res = value.binary_op( - target, - &if is_min { + let evaluator = EvaluatorFactory::binary_create( + value.logical_type(), + if is_min { BinaryOperator::Lt } else { BinaryOperator::LtEq }, )?; - Ok::(matches!(res, DataValue::Boolean(Some(true)))) + let value = evaluator.0.binary_eval(value, target); + Ok::(matches!(value, DataValue::Boolean(Some(true)))) }; Ok(match target { @@ -183,15 +185,16 @@ fn is_above( is_min: bool, ) -> Result { let _is_above = |value: &ValueRef, target: &ValueRef, is_min: bool| { - let res = value.binary_op( - target, - &if is_min { + let evaluator = EvaluatorFactory::binary_create( + value.logical_type(), + if is_min { BinaryOperator::GtEq } else { BinaryOperator::Gt }, )?; - Ok::(matches!(res, DataValue::Boolean(Some(true)))) + let value = evaluator.0.binary_eval(value, target); + Ok::(matches!(value, DataValue::Boolean(Some(true)))) }; Ok(match target { Bound::Included(target) => _is_above(value, target, is_min)?, diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index cf926398..2df1ee44 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -126,8 +126,8 @@ mod tests { BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))), None, ); + // where: c1 => 2, (40, +inf) let stmt = crate::parser::parse_sql( - // FIXME: Only by bracketing (c1 > 40 or c1 = 2) can the filter be pushed down below the join "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", )?; let plan = binder.bind(&stmt[0])?; diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index 5e63c07a..0aff8a08 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -88,6 +88,7 @@ impl NormalizationRule for CombineFilter { op: BinaryOperator::And, left_expr: Box::new(op.predicate), right_expr: Box::new(child_op.predicate.clone()), + evaluator: None, ty: LogicalType::Boolean, }; child_op.having = op.having || child_op.having; @@ -212,6 +213,7 @@ mod tests { op: BinaryOperator::Eq, left_expr: Box::new(Constant(Arc::new(DataValue::Int8(Some(1))))), right_expr: Box::new(Constant(Arc::new(DataValue::Int8(Some(1))))), + evaluator: None, ty: LogicalType::Boolean, } } else { diff --git a/src/optimizer/rule/normalization/expression_remapper.rs b/src/optimizer/rule/normalization/compilation_in_advance.rs similarity index 56% rename from src/optimizer/rule/normalization/expression_remapper.rs rename to src/optimizer/rule/normalization/compilation_in_advance.rs index 7817030b..0f3ba62e 100644 --- a/src/optimizer/rule/normalization/expression_remapper.rs +++ b/src/optimizer/rule/normalization/compilation_in_advance.rs @@ -16,6 +16,15 @@ lazy_static! { }; } +lazy_static! { + static ref EVALUATOR_BIND_RULE: Pattern = { + Pattern { + predicate: |_| true, + children: PatternChildrenPredicate::None, + } + }; +} + #[derive(Clone)] pub struct ExpressionRemapper; @@ -119,3 +128,96 @@ impl NormalizationRule for ExpressionRemapper { Ok(()) } } + +#[derive(Clone)] +pub struct EvaluatorBind; + +impl EvaluatorBind { + fn _apply(node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::_apply(child_id, graph)?; + } + // for join + if let Operator::Join(_) = graph.operator(node_id) { + if let Some(child_id) = graph.youngest_child_at(node_id) { + Self::_apply(child_id, graph)?; + } + } + let operator = graph.operator_mut(node_id); + + match operator { + Operator::Join(op) => { + match &mut op.on { + JoinCondition::On { on, filter } => { + for (left_expr, right_expr) in on { + left_expr.bind_evaluator()?; + right_expr.bind_evaluator()?; + } + if let Some(expr) = filter { + expr.bind_evaluator()?; + } + } + JoinCondition::None => {} + } + + return Ok(()); + } + Operator::Aggregate(op) => { + for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { + expr.bind_evaluator()?; + } + } + Operator::Filter(op) => { + op.predicate.bind_evaluator()?; + } + Operator::Project(op) => { + for expr in op.exprs.iter_mut() { + expr.bind_evaluator()?; + } + } + Operator::Sort(op) => { + for sort_field in op.sort_fields.iter_mut() { + sort_field.expr.bind_evaluator()?; + } + } + Operator::Dummy + | Operator::Scan(_) + | Operator::Limit(_) + | Operator::Values(_) + | Operator::Show + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::DropTable(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::Union(_) => (), + } + + Ok(()) + } +} + +impl MatchPattern for EvaluatorBind { + fn pattern(&self) -> &Pattern { + &EVALUATOR_BIND_RULE + } +} + +impl NormalizationRule for EvaluatorBind { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + Self::_apply(node_id, graph)?; + // mark changed to skip this rule batch + graph.version += 1; + + Ok(()) + } +} diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index f83155af..d0ee07f6 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -7,7 +7,9 @@ use crate::optimizer::rule::normalization::column_pruning::ColumnPruning; use crate::optimizer::rule::normalization::combine_operators::{ CollapseGroupByAgg, CollapseProject, CombineFilter, }; -use crate::optimizer::rule::normalization::expression_remapper::ExpressionRemapper; +use crate::optimizer::rule::normalization::compilation_in_advance::{ + EvaluatorBind, ExpressionRemapper, +}; use crate::optimizer::rule::normalization::pushdown_limit::{ EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -18,7 +20,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; -mod expression_remapper; +mod compilation_in_advance; mod pushdown_limit; mod pushdown_predicates; mod simplification; @@ -42,8 +44,9 @@ pub enum NormalizationRuleImpl { // Simplification SimplifyFilter, ConstantCalculation, - // ColumnRemapper + // CompilationInAdvance ExpressionRemapper, + EvaluatorBind, } impl MatchPattern for NormalizationRuleImpl { @@ -62,6 +65,7 @@ impl MatchPattern for NormalizationRuleImpl { NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.pattern(), NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.pattern(), NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.pattern(), + NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.pattern(), } } } @@ -92,6 +96,7 @@ impl NormalizationRule for NormalizationRuleImpl { } NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.apply(node_id, graph), + NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(node_id, graph), } } } diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index b1c6e34a..f55b07f0 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -71,6 +71,7 @@ fn reduce_filters(filters: Vec, having: bool) -> Option 0 let plan_8 = select_sql_run("select * from t1 where 1 < c1 + 1").await?; - // c1 < 24 let plan_9 = select_sql_run("select * from t1 where (-1 - c1) + 1 > 24").await?; - // c1 < 24 let plan_10 = select_sql_run("select * from t1 where 24 < (-1 - c1) + 1").await?; @@ -289,11 +287,13 @@ mod test { right_expr: Box::new(ScalarExpression::Constant(Arc::new( DataValue::Int32(Some(1)) ))), + evaluator: None, ty: LogicalType::Integer, }), ty: LogicalType::Integer, }), right_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c2_col))), + evaluator: None, ty: LogicalType::Boolean, } ) diff --git a/src/types/evaluator/boolean.rs b/src/types/evaluator/boolean.rs new file mode 100644 index 00000000..7bb9e8a3 --- /dev/null +++ b/src/types/evaluator/boolean.rs @@ -0,0 +1,91 @@ +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use serde::{Deserialize, Serialize}; +use std::hint; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct BooleanAndBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct BooleanOrBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct BooleanEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct BooleanNotEqBinaryEvaluator; + +#[typetag::serde] +impl BinaryEvaluator for BooleanAndBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(*v1 && *v2), + (Some(false), _) | (_, Some(false)) => Some(false), + _ => None, + }; + DataValue::Boolean(value) + } +} + +#[typetag::serde] +impl BinaryEvaluator for BooleanOrBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(*v1 || *v2), + (Some(true), _) | (_, Some(true)) => Some(true), + _ => None, + }; + DataValue::Boolean(value) + } +} + +#[typetag::serde] +impl BinaryEvaluator for BooleanEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(v1 == v2), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} + +#[typetag::serde] +impl BinaryEvaluator for BooleanNotEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(v1 != v2), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} diff --git a/src/types/evaluator/date.rs b/src/types/evaluator/date.rs new file mode 100644 index 00000000..db345d70 --- /dev/null +++ b/src/types/evaluator/date.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Date, DataValue::Date32); diff --git a/src/types/evaluator/datetime.rs b/src/types/evaluator/datetime.rs new file mode 100644 index 00000000..c5dda53b --- /dev/null +++ b/src/types/evaluator/datetime.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(DateTime, DataValue::Date64); diff --git a/src/types/evaluator/decimal.rs b/src/types/evaluator/decimal.rs new file mode 100644 index 00000000..9139f877 --- /dev/null +++ b/src/types/evaluator/decimal.rs @@ -0,0 +1,216 @@ +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use serde::{Deserialize, Serialize}; +use std::hint; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalPlusBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalMinusBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalMultiplyBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalDivideBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalGtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalGtEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalLtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalLtEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct DecimalNotEqBinaryEvaluator; + +#[typetag::serde] +impl BinaryEvaluator for DecimalPlusBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 + v2) + } else { + None + }; + DataValue::Decimal(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalMinusBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 - v2) + } else { + None + }; + DataValue::Decimal(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalMultiplyBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 * v2) + } else { + None + }; + DataValue::Decimal(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalDivideBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 / v2) + } else { + None + }; + DataValue::Decimal(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalGtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 > v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalGtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 >= v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalLtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 < v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalLtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 <= v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 == v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for DecimalNotEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Decimal(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 != v2) + } else { + None + }; + DataValue::Boolean(value) + } +} diff --git a/src/types/evaluator/float32.rs b/src/types/evaluator/float32.rs new file mode 100644 index 00000000..c36fe516 --- /dev/null +++ b/src/types/evaluator/float32.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Float32, DataValue::Float32); diff --git a/src/types/evaluator/float64.rs b/src/types/evaluator/float64.rs new file mode 100644 index 00000000..cf352316 --- /dev/null +++ b/src/types/evaluator/float64.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Float64, DataValue::Float64); diff --git a/src/types/evaluator/int16.rs b/src/types/evaluator/int16.rs new file mode 100644 index 00000000..1e9fe272 --- /dev/null +++ b/src/types/evaluator/int16.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Int16, DataValue::Int16); diff --git a/src/types/evaluator/int32.rs b/src/types/evaluator/int32.rs new file mode 100644 index 00000000..9181b991 --- /dev/null +++ b/src/types/evaluator/int32.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Int32, DataValue::Int32); diff --git a/src/types/evaluator/int64.rs b/src/types/evaluator/int64.rs new file mode 100644 index 00000000..f5df3bb5 --- /dev/null +++ b/src/types/evaluator/int64.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Int64, DataValue::Int64); diff --git a/src/types/evaluator/int8.rs b/src/types/evaluator/int8.rs new file mode 100644 index 00000000..4e7daff0 --- /dev/null +++ b/src/types/evaluator/int8.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Int8, DataValue::Int8); diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs new file mode 100644 index 00000000..34a7a4ac --- /dev/null +++ b/src/types/evaluator/mod.rs @@ -0,0 +1,1035 @@ +pub mod boolean; +pub mod date; +pub mod datetime; +pub mod decimal; +pub mod float32; +pub mod float64; +pub mod int16; +pub mod int32; +pub mod int64; +pub mod int8; +pub mod null; +pub mod time; +pub mod tuple; +pub mod uint16; +pub mod uint32; +pub mod uint64; +pub mod uint8; +pub mod utf8; + +use crate::errors::DatabaseError; +use crate::expression::BinaryOperator; +use crate::types::evaluator::boolean::*; +use crate::types::evaluator::date::*; +use crate::types::evaluator::datetime::*; +use crate::types::evaluator::decimal::*; +use crate::types::evaluator::float32::*; +use crate::types::evaluator::float64::*; +use crate::types::evaluator::int16::*; +use crate::types::evaluator::int32::*; +use crate::types::evaluator::int64::*; +use crate::types::evaluator::int8::*; +use crate::types::evaluator::null::NullBinaryEvaluator; +use crate::types::evaluator::time::*; +use crate::types::evaluator::tuple::{ + TupleEqBinaryEvaluator, TupleGtBinaryEvaluator, TupleGtEqBinaryEvaluator, + TupleLtBinaryEvaluator, TupleLtEqBinaryEvaluator, TupleNotEqBinaryEvaluator, +}; +use crate::types::evaluator::uint16::*; +use crate::types::evaluator::uint32::*; +use crate::types::evaluator::uint64::*; +use crate::types::evaluator::uint8::*; +use crate::types::evaluator::utf8::*; +use crate::types::evaluator::utf8::{ + Utf8EqBinaryEvaluator, Utf8GtBinaryEvaluator, Utf8GtEqBinaryEvaluator, Utf8LtBinaryEvaluator, + Utf8LtEqBinaryEvaluator, Utf8NotEqBinaryEvaluator, Utf8StringConcatBinaryEvaluator, +}; +use crate::types::value::DataValue; +use crate::types::LogicalType; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +#[typetag::serde(tag = "type")] +pub trait BinaryEvaluator: Send + Sync + Debug { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BinaryEvaluatorBox(pub Arc); + +impl PartialEq for BinaryEvaluatorBox { + fn eq(&self, _: &Self) -> bool { + // FIXME + true + } +} + +impl Eq for BinaryEvaluatorBox {} + +impl Hash for BinaryEvaluatorBox { + fn hash(&self, state: &mut H) { + state.write_i8(42) + } +} + +macro_rules! numeric_binary_evaluator { + ($value_type:ident, $op:expr, $ty:expr) => { + paste! { + match $op { + BinaryOperator::Plus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type PlusBinaryEvaluator>]))), + BinaryOperator::Minus => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MinusBinaryEvaluator>]))), + BinaryOperator::Multiply => Ok(BinaryEvaluatorBox(Arc::new([<$value_type MultiplyBinaryEvaluator>]))), + BinaryOperator::Divide => Ok(BinaryEvaluatorBox(Arc::new([<$value_type DivideBinaryEvaluator>]))), + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtBinaryEvaluator>]))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type GtEqBinaryEvaluator>]))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtBinaryEvaluator>]))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type LtEqBinaryEvaluator>]))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type EqBinaryEvaluator>]))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new([<$value_type NotEqBinaryEvaluator>]))), + _ => { + return Err(DatabaseError::UnsupportedBinaryOperator( + $ty, + $op, + )) + } + } + } + }; +} + +pub struct EvaluatorFactory; + +impl EvaluatorFactory { + pub fn binary_create( + ty: LogicalType, + op: BinaryOperator, + ) -> Result { + match ty { + LogicalType::Tinyint => numeric_binary_evaluator!(Int8, op, LogicalType::Tinyint), + LogicalType::Smallint => numeric_binary_evaluator!(Int16, op, LogicalType::Smallint), + LogicalType::Integer => numeric_binary_evaluator!(Int32, op, LogicalType::Integer), + LogicalType::Bigint => numeric_binary_evaluator!(Int64, op, LogicalType::Bigint), + LogicalType::UTinyint => numeric_binary_evaluator!(UInt8, op, LogicalType::UTinyint), + LogicalType::USmallint => numeric_binary_evaluator!(UInt16, op, LogicalType::USmallint), + LogicalType::UInteger => numeric_binary_evaluator!(UInt32, op, LogicalType::UInteger), + LogicalType::UBigint => numeric_binary_evaluator!(UInt64, op, LogicalType::UBigint), + LogicalType::Float => numeric_binary_evaluator!(Float32, op, LogicalType::Float), + LogicalType::Double => numeric_binary_evaluator!(Float64, op, LogicalType::Double), + LogicalType::Date => numeric_binary_evaluator!(Date, op, LogicalType::Date), + LogicalType::DateTime => numeric_binary_evaluator!(DateTime, op, LogicalType::DateTime), + LogicalType::Time => numeric_binary_evaluator!(Time, op, LogicalType::Time), + LogicalType::Decimal(_, _) => numeric_binary_evaluator!(Decimal, op, ty), + LogicalType::Boolean => match op { + BinaryOperator::And => Ok(BinaryEvaluatorBox(Arc::new(BooleanAndBinaryEvaluator))), + BinaryOperator::Or => Ok(BinaryEvaluatorBox(Arc::new(BooleanOrBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(BooleanEqBinaryEvaluator))), + BinaryOperator::NotEq => { + Ok(BinaryEvaluatorBox(Arc::new(BooleanNotEqBinaryEvaluator))) + } + _ => Err(DatabaseError::UnsupportedBinaryOperator( + LogicalType::Boolean, + op, + )), + }, + LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => match op { + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8GtEqBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8LtEqBinaryEvaluator))), + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(Utf8EqBinaryEvaluator))), + BinaryOperator::NotEq => Ok(BinaryEvaluatorBox(Arc::new(Utf8NotEqBinaryEvaluator))), + BinaryOperator::StringConcat => Ok(BinaryEvaluatorBox(Arc::new( + Utf8StringConcatBinaryEvaluator, + ))), + BinaryOperator::Like(escape_char) => { + Ok(BinaryEvaluatorBox(Arc::new(Utf8LikeBinaryEvaluator { + escape_char, + }))) + } + BinaryOperator::NotLike(escape_char) => { + Ok(BinaryEvaluatorBox(Arc::new(Utf8NotLikeBinaryEvaluator { + escape_char, + }))) + } + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), + }, + LogicalType::SqlNull => Ok(BinaryEvaluatorBox(Arc::new(NullBinaryEvaluator))), + LogicalType::Invalid => Err(DatabaseError::InvalidType), + LogicalType::Tuple => match op { + BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TupleEqBinaryEvaluator))), + BinaryOperator::NotEq => { + Ok(BinaryEvaluatorBox(Arc::new(TupleNotEqBinaryEvaluator))) + } + BinaryOperator::Gt => Ok(BinaryEvaluatorBox(Arc::new(TupleGtBinaryEvaluator))), + BinaryOperator::GtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleGtEqBinaryEvaluator))), + BinaryOperator::Lt => Ok(BinaryEvaluatorBox(Arc::new(TupleLtBinaryEvaluator))), + BinaryOperator::LtEq => Ok(BinaryEvaluatorBox(Arc::new(TupleLtEqBinaryEvaluator))), + _ => Err(DatabaseError::UnsupportedBinaryOperator(ty, op)), + }, + } + } +} + +#[macro_export] +macro_rules! numeric_binary_evaluator_definition { + ($value_type:ident, $compute_type:path) => { + paste! { + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type PlusBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type MinusBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type MultiplyBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type DivideBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type GtBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type GtEqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type LtBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type LtEqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type EqBinaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type NotEqBinaryEvaluator>]; + + #[typetag::serde] + impl BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 + v2) + } else { + None + }; + $compute_type(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 - v2) + } else { + None + }; + $compute_type(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 * v2) + } else { + None + }; + $compute_type(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(*v1 as f64 / *v2 as f64) + } else { + None + }; + DataValue::Float64(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 > v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 >= v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 < v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 <= v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 == v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + #[typetag::serde] + impl BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 != v2) + } else { + None + }; + DataValue::Boolean(value) + } + } + } + }; +} + +#[cfg(test)] +mod test { + use crate::errors::DatabaseError; + use crate::expression::BinaryOperator; + use crate::types::evaluator::EvaluatorFactory; + use crate::types::value::{DataValue, Utf8Type}; + use crate::types::LogicalType; + use sqlparser::ast::CharLengthUnits; + + #[test] + fn test_binary_op_arithmetic_plus() -> Result<(), DatabaseError> { + let plus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; + let plus_i32_1 = plus_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(None)); + let plus_i32_2 = plus_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(None)); + let plus_i32_3 = plus_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1))); + let plus_i32_4 = plus_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1))); + + assert_eq!(plus_i32_1, plus_i32_2); + assert_eq!(plus_i32_2, plus_i32_3); + assert_eq!(plus_i32_4, DataValue::Int32(Some(2))); + + let plus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Plus)?; + let plus_i64_1 = plus_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(None)); + let plus_i64_2 = plus_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(None)); + let plus_i64_3 = plus_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(Some(1))); + let plus_i64_4 = plus_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1))); + + assert_eq!(plus_i64_1, plus_i64_2); + assert_eq!(plus_i64_2, plus_i64_3); + assert_eq!(plus_i64_4, DataValue::Int64(Some(2))); + + let plus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Plus)?; + let plus_f64_1 = plus_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(None)); + let plus_f64_2 = plus_evaluator + .0 + .binary_eval(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None)); + let plus_f64_3 = plus_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(Some(1.0))); + let plus_f64_4 = plus_evaluator.0.binary_eval( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + ); + + assert_eq!(plus_f64_1, plus_f64_2); + assert_eq!(plus_f64_2, plus_f64_3); + assert_eq!(plus_f64_4, DataValue::Float64(Some(2.0))); + + Ok(()) + } + + #[test] + fn test_binary_op_arithmetic_minus() -> Result<(), DatabaseError> { + let minus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Minus)?; + let minus_i32_1 = minus_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(None)); + let minus_i32_2 = minus_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(None)); + let minus_i32_3 = minus_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1))); + let minus_i32_4 = minus_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1))); + + assert_eq!(minus_i32_1, minus_i32_2); + assert_eq!(minus_i32_2, minus_i32_3); + assert_eq!(minus_i32_4, DataValue::Int32(Some(0))); + + let minus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Minus)?; + let minus_i64_1 = minus_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(None)); + let minus_i64_2 = minus_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(None)); + let minus_i64_3 = minus_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(Some(1))); + let minus_i64_4 = minus_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1))); + + assert_eq!(minus_i64_1, minus_i64_2); + assert_eq!(minus_i64_2, minus_i64_3); + assert_eq!(minus_i64_4, DataValue::Int64(Some(0))); + + let minus_evaluator = + EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Minus)?; + let minus_f64_1 = minus_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(None)); + let minus_f64_2 = minus_evaluator + .0 + .binary_eval(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None)); + let minus_f64_3 = minus_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(Some(1.0))); + let minus_f64_4 = minus_evaluator.0.binary_eval( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + ); + + assert_eq!(minus_f64_1, minus_f64_2); + assert_eq!(minus_f64_2, minus_f64_3); + assert_eq!(minus_f64_4, DataValue::Float64(Some(0.0))); + + Ok(()) + } + + #[test] + fn test_binary_op_arithmetic_multiply() -> Result<(), DatabaseError> { + let multiply_evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Multiply)?; + let multiply_i32_1 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(None)); + let multiply_i32_2 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(None)); + let multiply_i32_3 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1))); + let multiply_i32_4 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1))); + + assert_eq!(multiply_i32_1, multiply_i32_2); + assert_eq!(multiply_i32_2, multiply_i32_3); + assert_eq!(multiply_i32_4, DataValue::Int32(Some(1))); + + let multiply_evaluator = + EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Multiply)?; + let multiply_i64_1 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(None)); + let multiply_i64_2 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(None)); + let multiply_i64_3 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(Some(1))); + let multiply_i64_4 = multiply_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1))); + + assert_eq!(multiply_i64_1, multiply_i64_2); + assert_eq!(multiply_i64_2, multiply_i64_3); + assert_eq!(multiply_i64_4, DataValue::Int64(Some(1))); + + let multiply_evaluator = + EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Multiply)?; + let multiply_f64_1 = multiply_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(None)); + let multiply_f64_2 = multiply_evaluator + .0 + .binary_eval(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None)); + let multiply_f64_3 = multiply_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(Some(1.0))); + let multiply_f64_4 = multiply_evaluator.0.binary_eval( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + ); + + assert_eq!(multiply_f64_1, multiply_f64_2); + assert_eq!(multiply_f64_2, multiply_f64_3); + assert_eq!(multiply_f64_4, DataValue::Float64(Some(1.0))); + + Ok(()) + } + + #[test] + fn test_binary_op_arithmetic_divide() -> Result<(), DatabaseError> { + let divide_evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Divide)?; + let divide_i32_1 = divide_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(None)); + let divide_i32_2 = divide_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(None)); + let divide_i32_3 = divide_evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1))); + let divide_i32_4 = divide_evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1))); + + assert_eq!(divide_i32_1, divide_i32_2); + assert_eq!(divide_i32_2, divide_i32_3); + assert_eq!(divide_i32_4, DataValue::Float64(Some(1.0))); + + let divide_evaluator = + EvaluatorFactory::binary_create(LogicalType::Bigint, BinaryOperator::Divide)?; + let divide_i64_1 = divide_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(None)); + let divide_i64_2 = divide_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(None)); + let divide_i64_3 = divide_evaluator + .0 + .binary_eval(&DataValue::Int64(None), &DataValue::Int64(Some(1))); + let divide_i64_4 = divide_evaluator + .0 + .binary_eval(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1))); + + assert_eq!(divide_i64_1, divide_i64_2); + assert_eq!(divide_i64_2, divide_i64_3); + assert_eq!(divide_i64_4, DataValue::Float64(Some(1.0))); + + let divide_evaluator = + EvaluatorFactory::binary_create(LogicalType::Double, BinaryOperator::Divide)?; + let divide_f64_1 = divide_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(None)); + let divide_f64_2 = divide_evaluator + .0 + .binary_eval(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None)); + let divide_f64_3 = divide_evaluator + .0 + .binary_eval(&DataValue::Float64(None), &DataValue::Float64(Some(1.0))); + let divide_f64_4 = divide_evaluator.0.binary_eval( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + ); + + assert_eq!(divide_f64_1, divide_f64_2); + assert_eq!(divide_f64_2, divide_f64_3); + assert_eq!(divide_f64_4, DataValue::Float64(Some(1.0))); + + Ok(()) + } + + #[test] + fn test_binary_op_i32_compare() -> Result<(), DatabaseError> { + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Gt)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(0)),), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Lt)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(0)),), + DataValue::Boolean(Some(false)) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::GtEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)),), + DataValue::Boolean(Some(true)) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::LtEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)),), + DataValue::Boolean(Some(true)) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::NotEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)),), + DataValue::Boolean(Some(false)) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)),), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Gt)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(0)),), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Lt)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(0)),), + DataValue::Boolean(None) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::GtEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1)),), + DataValue::Boolean(None) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::LtEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1)),), + DataValue::Boolean(None) + ); + let evaluator = + EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::NotEq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1)),), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(Some(1)),), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Eq)?; + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Int32(None), &DataValue::Int32(None),), + DataValue::Boolean(None) + ); + + Ok(()) + } + + #[test] + fn test_binary_op_bool_compare() -> Result<(), DatabaseError> { + let evaluator = EvaluatorFactory::binary_create(LogicalType::Boolean, BinaryOperator::And)?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(true)), + &DataValue::Boolean(Some(true)), + ), + DataValue::Boolean(Some(true)) + ); + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(true)), + ), + DataValue::Boolean(Some(false)) + ); + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(false)), + ), + DataValue::Boolean(Some(false)) + ); + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Boolean(None), &DataValue::Boolean(Some(true)),), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create(LogicalType::Boolean, BinaryOperator::Or)?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(true)), + &DataValue::Boolean(Some(true)), + ), + DataValue::Boolean(Some(true)) + ); + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(true)), + ), + DataValue::Boolean(Some(true)) + ); + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(false)), + ), + DataValue::Boolean(Some(false)) + ); + assert_eq!( + evaluator + .0 + .binary_eval(&DataValue::Boolean(None), &DataValue::Boolean(Some(true)),), + DataValue::Boolean(Some(true)) + ); + + Ok(()) + } + + #[test] + fn test_binary_op_utf8_compare() -> Result<(), DatabaseError> { + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::Gt, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("b".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(false)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::Lt, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("b".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::GtEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::LtEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::NotEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(false)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::Eq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(Some(true)) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::Gt, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: None, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::Lt, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: None, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::GtEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: None, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::LtEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: None, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(None) + ); + let evaluator = EvaluatorFactory::binary_create( + LogicalType::Varchar(None, CharLengthUnits::Characters), + BinaryOperator::NotEq, + )?; + assert_eq!( + evaluator.0.binary_eval( + &DataValue::Utf8 { + value: None, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + &DataValue::Utf8 { + value: Some("a".to_string()), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }, + ), + DataValue::Boolean(None) + ); + + Ok(()) + } +} diff --git a/src/types/evaluator/null.rs b/src/types/evaluator/null.rs new file mode 100644 index 00000000..1632a644 --- /dev/null +++ b/src/types/evaluator/null.rs @@ -0,0 +1,15 @@ +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use serde::{Deserialize, Serialize}; + +/// Tips: +/// - Null values operate as null values +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct NullBinaryEvaluator; + +#[typetag::serde] +impl BinaryEvaluator for NullBinaryEvaluator { + fn binary_eval(&self, _: &DataValue, _: &DataValue) -> DataValue { + DataValue::Null + } +} diff --git a/src/types/evaluator/time.rs b/src/types/evaluator/time.rs new file mode 100644 index 00000000..6264eed1 --- /dev/null +++ b/src/types/evaluator/time.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(Time, DataValue::Time); diff --git a/src/types/evaluator/tuple.rs b/src/types/evaluator/tuple.rs new file mode 100644 index 00000000..c7818c5d --- /dev/null +++ b/src/types/evaluator/tuple.rs @@ -0,0 +1,144 @@ +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use crate::types::value::ValueRef; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::hint; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleNotEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleGtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleGtEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleLtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct TupleLtEqBinaryEvaluator; + +fn tuple_cmp(v1: &[ValueRef], v2: &[ValueRef]) -> Option { + let mut order = Ordering::Equal; + let mut v1_iter = v1.iter(); + let mut v2_iter = v2.iter(); + + while order == Ordering::Equal { + order = match (v1_iter.next(), v2_iter.next()) { + (Some(v1), Some(v2)) => v1.partial_cmp(v2)?, + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => break, + } + } + Some(order) +} + +#[typetag::serde] +impl BinaryEvaluator for TupleEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(v1 == v2), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for TupleNotEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(v1 != v2), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for TupleGtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => tuple_cmp(v1, v2).map(|order| order.is_gt()), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for TupleGtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => tuple_cmp(v1, v2).map(|order| order.is_ge()), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for TupleLtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => tuple_cmp(v1, v2).map(|order| order.is_lt()), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for TupleLtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Tuple(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => tuple_cmp(v1, v2).map(|order| order.is_le()), + (_, _) => None, + }; + DataValue::Boolean(value) + } +} diff --git a/src/types/evaluator/uint16.rs b/src/types/evaluator/uint16.rs new file mode 100644 index 00000000..e8a4bbab --- /dev/null +++ b/src/types/evaluator/uint16.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(UInt16, DataValue::UInt16); diff --git a/src/types/evaluator/uint32.rs b/src/types/evaluator/uint32.rs new file mode 100644 index 00000000..b21e011d --- /dev/null +++ b/src/types/evaluator/uint32.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(UInt32, DataValue::UInt32); diff --git a/src/types/evaluator/uint64.rs b/src/types/evaluator/uint64.rs new file mode 100644 index 00000000..0a3a1274 --- /dev/null +++ b/src/types/evaluator/uint64.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(UInt64, DataValue::UInt64); diff --git a/src/types/evaluator/uint8.rs b/src/types/evaluator/uint8.rs new file mode 100644 index 00000000..15530d46 --- /dev/null +++ b/src/types/evaluator/uint8.rs @@ -0,0 +1,8 @@ +use crate::numeric_binary_evaluator_definition; +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use paste::paste; +use serde::{Deserialize, Serialize}; +use std::hint; + +numeric_binary_evaluator_definition!(UInt8, DataValue::UInt8); diff --git a/src/types/evaluator/utf8.rs b/src/types/evaluator/utf8.rs new file mode 100644 index 00000000..61c8ba22 --- /dev/null +++ b/src/types/evaluator/utf8.rs @@ -0,0 +1,226 @@ +use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::DataValue; +use crate::types::value::Utf8Type; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use sqlparser::ast::CharLengthUnits; +use std::hint; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8GtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8GtEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8LtBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8LtEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8EqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8NotEqBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8StringConcatBinaryEvaluator; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8LikeBinaryEvaluator { + pub(crate) escape_char: Option, +} +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct Utf8NotLikeBinaryEvaluator { + pub(crate) escape_char: Option, +} + +#[typetag::serde] +impl BinaryEvaluator for Utf8GtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 > v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8GtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 >= v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8LtBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 < v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8LtEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 <= v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8EqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 == v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8NotEqBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = if let (Some(v1), Some(v2)) = (left, right) { + Some(v1 != v2) + } else { + None + }; + DataValue::Boolean(value) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8StringConcatBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let left = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let right = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let value = match (left, right) { + (Some(v1), Some(v2)) => Some(v1.clone() + v2), + _ => None, + }; + DataValue::Utf8 { + value, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + } + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8LikeBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let value = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let pattern = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let is_match = if let (Some(value), Some(pattern)) = (value, pattern) { + string_like(value, pattern, self.escape_char) + } else { + return DataValue::Boolean(None); + }; + + DataValue::Boolean(Some(is_match)) + } +} +#[typetag::serde] +impl BinaryEvaluator for Utf8NotLikeBinaryEvaluator { + fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { + let value = match left { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let pattern = match right { + DataValue::Utf8 { value, .. } => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + let is_match = if let (Some(value), Some(pattern)) = (value, pattern) { + string_like(value, pattern, self.escape_char) + } else { + return DataValue::Boolean(None); + }; + + DataValue::Boolean(Some(!is_match)) + } +} + +fn string_like(value: &str, pattern: &str, escape_char: Option) -> bool { + let mut regex_pattern = String::new(); + let mut chars = pattern.chars().peekable(); + while let Some(c) = chars.next() { + if matches!(escape_char.map(|escape_c| escape_c == c), Some(true)) { + if let Some(next_char) = chars.next() { + regex_pattern.push(next_char); + } + } else if c == '%' { + regex_pattern.push_str(".*"); + } else if c == '_' { + regex_pattern.push('.'); + } else { + regex_pattern.push(c); + } + } + Regex::new(®ex_pattern).unwrap().is_match(value) +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 1ec599b7..b8b5734e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,4 @@ +pub mod evaluator; pub mod index; pub mod tuple; pub mod tuple_builder; @@ -202,10 +203,12 @@ impl LogicalType { ) { return Ok(LogicalType::DateTime); } - if let (LogicalType::Char(..), LogicalType::Varchar(len, ..)) - | (LogicalType::Varchar(len, ..), LogicalType::Char(..)) = (left, right) + if let (LogicalType::Char(..), LogicalType::Varchar(..)) + | (LogicalType::Varchar(..), LogicalType::Char(..)) + | (LogicalType::Char(..), LogicalType::Char(..)) + | (LogicalType::Varchar(..), LogicalType::Varchar(..)) = (left, right) { - return Ok(LogicalType::Varchar(*len, CharLengthUnits::Characters)); + return Ok(LogicalType::Varchar(None, CharLengthUnits::Characters)); } Err(DatabaseError::Incomparable(*left, *right)) }