diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 0466c862..ae9e7890 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -394,6 +394,7 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(ScalarExpression::Unary { op: (*op).into(), expr, + evaluator: None, ty, }) } diff --git a/src/db.rs b/src/db.rs index 6edcbc2f..a3a77309 100644 --- a/src/db.rs +++ b/src/db.rs @@ -309,10 +309,11 @@ mod test { } function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { - let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; - let value = plus_evaluator.0.binary_eval(&v1, &v2); + let plus_binary_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?; + let value = plus_binary_evaluator.0.binary_eval(&v1, &v2); - DataValue::unary_op(&value, &UnaryOperator::Minus) + let plus_unary_evaluator = EvaluatorFactory::unary_create(LogicalType::Integer, UnaryOperator::Minus)?; + Ok(plus_unary_evaluator.0.unary_eval(&value)) })); #[tokio::test] diff --git a/src/errors.rs b/src/errors.rs index 85e260ba..79acfcb0 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,4 +1,4 @@ -use crate::expression::BinaryOperator; +use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::LogicalType; use chrono::ParseError; use kip_db::KernelError; @@ -143,6 +143,8 @@ pub enum DatabaseError { TooLong, #[error("there are more buckets: {0} than elements: {1}")] TooManyBuckets(usize, usize), + #[error("unsupported unary operator: {0} cannot support {1} for calculations")] + UnsupportedUnaryOperator(LogicalType, UnaryOperator), #[error("unsupported binary operator: {0} cannot support {1} for calculations")] UnsupportedBinaryOperator(LogicalType, BinaryOperator), #[error("unsupported statement: {0}")] diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 2499a214..02a5681b 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -128,10 +128,18 @@ impl ScalarExpression { } Ok(Arc::new(DataValue::Boolean(Some(is_in)))) } - ScalarExpression::Unary { expr, op, .. } => { + ScalarExpression::Unary { + expr, evaluator, .. + } => { let value = expr.eval(tuple, schema)?; - Ok(Arc::new(DataValue::unary_op(&value, op)?)) + Ok(Arc::new( + evaluator + .as_ref() + .ok_or(DatabaseError::EvaluatorNotFound)? + .0 + .unary_eval(&value), + )) } ScalarExpression::AggCall { .. } => { unreachable!("must use `NormalizationRuleImpl::ExpressionRemapper`") diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 1f3d3ec2..9988cd45 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -13,7 +13,7 @@ use self::agg::AggKind; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::errors::DatabaseError; use crate::expression::function::ScalarFunction; -use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory}; +use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory, UnaryEvaluatorBox}; use crate::types::value::ValueRef; use crate::types::LogicalType; @@ -22,7 +22,6 @@ mod evaluator; pub mod function; pub mod range_detacher; pub mod simplify; -pub mod value_compute; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum AliasType { @@ -53,6 +52,7 @@ pub enum ScalarExpression { Unary { op: UnaryOperator, expr: Box, + evaluator: Option, ty: LogicalType, }, Binary { @@ -307,6 +307,29 @@ impl ScalarExpression { *evaluator = Some(EvaluatorFactory::binary_create(ty, *op)?); } + ScalarExpression::Unary { + expr, + op, + evaluator, + .. + } => { + expr.bind_evaluator()?; + + let ty = expr.return_type(); + if ty.is_unsigned_numeric() { + *expr.as_mut() = ScalarExpression::TypeCast { + expr: Box::new(mem::replace(expr, ScalarExpression::Empty)), + ty: match ty { + LogicalType::UTinyint => LogicalType::Tinyint, + LogicalType::USmallint => LogicalType::Smallint, + LogicalType::UInteger => LogicalType::Integer, + LogicalType::UBigint => LogicalType::Bigint, + _ => unreachable!(), + }, + } + } + *evaluator = Some(EvaluatorFactory::unary_create(ty, *op)?); + } ScalarExpression::Alias { expr, .. } => { expr.bind_evaluator()?; } @@ -316,9 +339,6 @@ impl ScalarExpression { ScalarExpression::IsNull { expr, .. } => { expr.bind_evaluator()?; } - ScalarExpression::Unary { expr, .. } => { - expr.bind_evaluator()?; - } ScalarExpression::AggCall { args, .. } | ScalarExpression::Coalesce { exprs: args, .. } | ScalarExpression::Tuple(args) => { diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 755b7d28..4a1efb05 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -149,29 +149,49 @@ impl ScalarExpression { Some(Arc::new(DataValue::Boolean(is_null))) } - ScalarExpression::Unary { expr, op, .. } => { - let val = expr.unpack_val()?; - - DataValue::unary_op(&val, op).ok().map(Arc::new) + ScalarExpression::Unary { + expr, + op, + evaluator, + ty, + .. + } => { + let value = expr.unpack_val()?; + let unary_value = if let Some(evaluator) = evaluator { + evaluator.0.unary_eval(&value) + } else { + EvaluatorFactory::unary_create(*ty, *op) + .ok()? + .0 + .unary_eval(&value) + }; + Some(Arc::new(unary_value)) } ScalarExpression::Binary { left_expr, right_expr, op, ty, + evaluator, .. } => { let mut left = left_expr.unpack_val()?; let mut right = right_expr.unpack_val()?; - let evaluator = EvaluatorFactory::binary_create(*ty, *op).ok()?; - if left.logical_type() != *ty { left = Arc::new(DataValue::clone(&left).cast(ty).ok()?); } if right.logical_type() != *ty { right = Arc::new(DataValue::clone(&right).cast(ty).ok()?); } - Some(Arc::new(evaluator.0.binary_eval(&left, &right))) + let binary_value = if let Some(evaluator) = evaluator { + evaluator.0.binary_eval(&left, &right) + } else { + EvaluatorFactory::binary_create(*ty, *op) + .ok()? + .0 + .binary_eval(&left, &right) + }; + Some(Arc::new(binary_value)) } _ => None, } @@ -205,11 +225,23 @@ impl ScalarExpression { pub fn constant_calculation(&mut self) -> Result<(), DatabaseError> { match self { - ScalarExpression::Unary { expr, op, .. } => { + ScalarExpression::Unary { + expr, + op, + evaluator, + ty, + .. + } => { expr.constant_calculation()?; if let ScalarExpression::Constant(unary_val) = expr.as_ref() { - let value = DataValue::unary_op(unary_val, op)?; + let value = if let Some(evaluator) = evaluator { + evaluator.0.unary_eval(unary_val) + } else { + EvaluatorFactory::unary_create(*ty, *op)? + .0 + .unary_eval(unary_val) + }; let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value))); } } @@ -421,10 +453,22 @@ impl ScalarExpression { ); } } - ScalarExpression::Unary { expr, op, ty, .. } => { - if let Some(val) = expr.unpack_val() { - let new_expr = - ScalarExpression::Constant(Arc::new(DataValue::unary_op(&val, op)?)); + ScalarExpression::Unary { + expr, + op, + ty, + evaluator, + .. + } => { + if let Some(value) = expr.unpack_val() { + let value = if let Some(evaluator) = evaluator { + evaluator.0.unary_eval(&value) + } else { + EvaluatorFactory::unary_create(*ty, *op)? + .0 + .unary_eval(&value) + }; + let new_expr = ScalarExpression::Constant(Arc::new(value)); let _ = mem::replace(self, new_expr); } else { replaces.push(Replace::Unary(ReplaceUnary { @@ -571,6 +615,7 @@ impl ScalarExpression { Box::new(ScalarExpression::Unary { op: fix_op, expr, + evaluator: None, ty: fix_ty, }), ); diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs deleted file mode 100644 index cc1acb4c..00000000 --- a/src/expression/value_compute.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::errors::DatabaseError; -use crate::expression::UnaryOperator; -use crate::types::value::DataValue; -use crate::types::LogicalType; - -impl DataValue { - // FIXME: like BinaryEvaluator - pub fn unary_op(&self, op: &UnaryOperator) -> Result { - let mut value_type = self.logical_type(); - let mut value = self.clone(); - - if value_type.is_numeric() && matches!(op, UnaryOperator::Plus | UnaryOperator::Minus) { - if value_type.is_unsigned_numeric() { - match value_type { - LogicalType::UTinyint => value_type = LogicalType::Tinyint, - LogicalType::USmallint => value_type = LogicalType::Smallint, - LogicalType::UInteger => value_type = LogicalType::Integer, - LogicalType::UBigint => value_type = LogicalType::Bigint, - _ => unreachable!(), - }; - value = value.cast(&value_type)?; - } - - let result = match op { - UnaryOperator::Plus => value, - UnaryOperator::Minus => match value { - DataValue::Float32(option) => DataValue::Float32(option.map(|v| -v)), - DataValue::Float64(option) => DataValue::Float64(option.map(|v| -v)), - DataValue::Int8(option) => DataValue::Int8(option.map(|v| -v)), - DataValue::Int16(option) => DataValue::Int16(option.map(|v| -v)), - DataValue::Int32(option) => DataValue::Int32(option.map(|v| -v)), - DataValue::Int64(option) => DataValue::Int64(option.map(|v| -v)), - _ => unreachable!(), - }, - _ => unreachable!(), - }; - - Ok(result) - } else if matches!((value_type, op), (LogicalType::Boolean, UnaryOperator::Not)) { - match value { - DataValue::Boolean(option) => Ok(DataValue::Boolean(option.map(|v| !v))), - _ => unreachable!(), - } - } else { - Err(DatabaseError::InvalidType) - } - } -} diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 3ab88448..67ed9f7a 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -290,6 +290,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), + evaluator: None, ty: LogicalType::Integer, }), right_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c2_col))), diff --git a/src/types/evaluator/boolean.rs b/src/types/evaluator/boolean.rs index 7bb9e8a3..eb671a75 100644 --- a/src/types/evaluator/boolean.rs +++ b/src/types/evaluator/boolean.rs @@ -1,8 +1,10 @@ -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; use serde::{Deserialize, Serialize}; use std::hint; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub struct BooleanNotUnaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct BooleanAndBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] @@ -12,6 +14,16 @@ pub struct BooleanEqBinaryEvaluator; #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct BooleanNotEqBinaryEvaluator; +#[typetag::serde] +impl UnaryEvaluator for BooleanNotUnaryEvaluator { + fn unary_eval(&self, value: &DataValue) -> DataValue { + let value = match value { + DataValue::Boolean(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + DataValue::Boolean(value.map(|v| !v)) + } +} #[typetag::serde] impl BinaryEvaluator for BooleanAndBinaryEvaluator { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue { diff --git a/src/types/evaluator/float32.rs b/src/types/evaluator/float32.rs index c36fe516..2328cae7 100644 --- a/src/types/evaluator/float32.rs +++ b/src/types/evaluator/float32.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Float32, DataValue::Float32); numeric_binary_evaluator_definition!(Float32, DataValue::Float32); diff --git a/src/types/evaluator/float64.rs b/src/types/evaluator/float64.rs index cf352316..0dcb953d 100644 --- a/src/types/evaluator/float64.rs +++ b/src/types/evaluator/float64.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Float64, DataValue::Float64); numeric_binary_evaluator_definition!(Float64, DataValue::Float64); diff --git a/src/types/evaluator/int16.rs b/src/types/evaluator/int16.rs index 1e9fe272..b2c4a748 100644 --- a/src/types/evaluator/int16.rs +++ b/src/types/evaluator/int16.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Int16, DataValue::Int16); numeric_binary_evaluator_definition!(Int16, DataValue::Int16); diff --git a/src/types/evaluator/int32.rs b/src/types/evaluator/int32.rs index 9181b991..cc7b0e86 100644 --- a/src/types/evaluator/int32.rs +++ b/src/types/evaluator/int32.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Int32, DataValue::Int32); numeric_binary_evaluator_definition!(Int32, DataValue::Int32); diff --git a/src/types/evaluator/int64.rs b/src/types/evaluator/int64.rs index f5df3bb5..eafac83e 100644 --- a/src/types/evaluator/int64.rs +++ b/src/types/evaluator/int64.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Int64, DataValue::Int64); numeric_binary_evaluator_definition!(Int64, DataValue::Int64); diff --git a/src/types/evaluator/int8.rs b/src/types/evaluator/int8.rs index 4e7daff0..b3abe5f7 100644 --- a/src/types/evaluator/int8.rs +++ b/src/types/evaluator/int8.rs @@ -1,8 +1,9 @@ -use crate::numeric_binary_evaluator_definition; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; +use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition}; use paste::paste; use serde::{Deserialize, Serialize}; use std::hint; +numeric_unary_evaluator_definition!(Int8, DataValue::Int8); numeric_binary_evaluator_definition!(Int8, DataValue::Int8); diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs index 34a7a4ac..1a80c77a 100644 --- a/src/types/evaluator/mod.rs +++ b/src/types/evaluator/mod.rs @@ -18,7 +18,7 @@ pub mod uint8; pub mod utf8; use crate::errors::DatabaseError; -use crate::expression::BinaryOperator; +use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::evaluator::boolean::*; use crate::types::evaluator::date::*; use crate::types::evaluator::datetime::*; @@ -57,9 +57,17 @@ pub trait BinaryEvaluator: Send + Sync + Debug { fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue; } +#[typetag::serde(tag = "type")] +pub trait UnaryEvaluator: Send + Sync + Debug { + fn unary_eval(&self, value: &DataValue) -> DataValue; +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BinaryEvaluatorBox(pub Arc); +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnaryEvaluatorBox(pub Arc); + impl PartialEq for BinaryEvaluatorBox { fn eq(&self, _: &Self) -> bool { // FIXME @@ -75,6 +83,21 @@ impl Hash for BinaryEvaluatorBox { } } +impl PartialEq for UnaryEvaluatorBox { + fn eq(&self, _: &Self) -> bool { + // FIXME + true + } +} + +impl Eq for UnaryEvaluatorBox {} + +impl Hash for UnaryEvaluatorBox { + fn hash(&self, state: &mut H) { + state.write_i8(42) + } +} + macro_rules! numeric_binary_evaluator { ($value_type:ident, $op:expr, $ty:expr) => { paste! { @@ -100,9 +123,44 @@ macro_rules! numeric_binary_evaluator { }; } +macro_rules! numeric_unary_evaluator { + ($value_type:ident, $op:expr, $ty:expr) => { + paste! { + match $op { + UnaryOperator::Plus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type PlusUnaryEvaluator>]))), + UnaryOperator::Minus => Ok(UnaryEvaluatorBox(Arc::new([<$value_type MinusUnaryEvaluator>]))), + _ => { + return Err(DatabaseError::UnsupportedUnaryOperator( + $ty, + $op, + )) + } + } + } + }; +} + pub struct EvaluatorFactory; impl EvaluatorFactory { + pub fn unary_create( + ty: LogicalType, + op: UnaryOperator, + ) -> Result { + match ty { + LogicalType::Tinyint => numeric_unary_evaluator!(Int8, op, LogicalType::Tinyint), + LogicalType::Smallint => numeric_unary_evaluator!(Int16, op, LogicalType::Smallint), + LogicalType::Integer => numeric_unary_evaluator!(Int32, op, LogicalType::Integer), + LogicalType::Bigint => numeric_unary_evaluator!(Int64, op, LogicalType::Bigint), + LogicalType::Boolean => match op { + UnaryOperator::Not => Ok(UnaryEvaluatorBox(Arc::new(BooleanNotUnaryEvaluator))), + _ => Err(DatabaseError::UnsupportedUnaryOperator(ty, op)), + }, + LogicalType::Float => numeric_unary_evaluator!(Float32, op, LogicalType::Float), + LogicalType::Double => numeric_unary_evaluator!(Float64, op, LogicalType::Double), + _ => Err(DatabaseError::UnsupportedUnaryOperator(ty, op)), + } + } pub fn binary_create( ty: LogicalType, op: BinaryOperator, @@ -173,6 +231,35 @@ impl EvaluatorFactory { } } +#[macro_export] +macro_rules! numeric_unary_evaluator_definition { + ($value_type:ident, $compute_type:path) => { + paste! { + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type PlusUnaryEvaluator>]; + #[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] + pub struct [<$value_type MinusUnaryEvaluator>]; + + #[typetag::serde] + impl UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { + fn unary_eval(&self, value: &DataValue) -> DataValue { + value.clone() + } + } + #[typetag::serde] + impl UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { + fn unary_eval(&self, value: &DataValue) -> DataValue { + let value = match value { + $compute_type(value) => value, + _ => unsafe { hint::unreachable_unchecked() }, + }; + $compute_type(value.map(|v| -v)) + } + } + } + }; +} + #[macro_export] macro_rules! numeric_binary_evaluator_definition { ($value_type:ident, $compute_type:path) => {