Skip to content

Commit

Permalink
feat: support UnaryEvaluator (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould authored Apr 9, 2024
1 parent 50d3e6d commit 74ea950
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 86 deletions.
1 change: 1 addition & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
Ok(ScalarExpression::Unary {
op: (*op).into(),
expr,
evaluator: None,
ty,
})
}
Expand Down
7 changes: 4 additions & 3 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,11 @@ mod test {
}

function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| {
let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?;
let value = plus_evaluator.0.binary_eval(&v1, &v2);
let plus_binary_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?;
let value = plus_binary_evaluator.0.binary_eval(&v1, &v2);

DataValue::unary_op(&value, &UnaryOperator::Minus)
let plus_unary_evaluator = EvaluatorFactory::unary_create(LogicalType::Integer, UnaryOperator::Minus)?;
Ok(plus_unary_evaluator.0.unary_eval(&value))
}));

#[tokio::test]
Expand Down
4 changes: 3 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::expression::BinaryOperator;
use crate::expression::{BinaryOperator, UnaryOperator};
use crate::types::LogicalType;
use chrono::ParseError;
use kip_db::KernelError;
Expand Down Expand Up @@ -143,6 +143,8 @@ pub enum DatabaseError {
TooLong,
#[error("there are more buckets: {0} than elements: {1}")]
TooManyBuckets(usize, usize),
#[error("unsupported unary operator: {0} cannot support {1} for calculations")]
UnsupportedUnaryOperator(LogicalType, UnaryOperator),
#[error("unsupported binary operator: {0} cannot support {1} for calculations")]
UnsupportedBinaryOperator(LogicalType, BinaryOperator),
#[error("unsupported statement: {0}")]
Expand Down
12 changes: 10 additions & 2 deletions src/expression/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,18 @@ impl ScalarExpression {
}
Ok(Arc::new(DataValue::Boolean(Some(is_in))))
}
ScalarExpression::Unary { expr, op, .. } => {
ScalarExpression::Unary {
expr, evaluator, ..
} => {
let value = expr.eval(tuple, schema)?;

Ok(Arc::new(DataValue::unary_op(&value, op)?))
Ok(Arc::new(
evaluator
.as_ref()
.ok_or(DatabaseError::EvaluatorNotFound)?
.0
.unary_eval(&value),
))
}
ScalarExpression::AggCall { .. } => {
unreachable!("must use `NormalizationRuleImpl::ExpressionRemapper`")
Expand Down
30 changes: 25 additions & 5 deletions src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use self::agg::AggKind;
use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef};
use crate::errors::DatabaseError;
use crate::expression::function::ScalarFunction;
use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory};
use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory, UnaryEvaluatorBox};
use crate::types::value::ValueRef;
use crate::types::LogicalType;

Expand All @@ -22,7 +22,6 @@ mod evaluator;
pub mod function;
pub mod range_detacher;
pub mod simplify;
pub mod value_compute;

#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub enum AliasType {
Expand Down Expand Up @@ -53,6 +52,7 @@ pub enum ScalarExpression {
Unary {
op: UnaryOperator,
expr: Box<ScalarExpression>,
evaluator: Option<UnaryEvaluatorBox>,
ty: LogicalType,
},
Binary {
Expand Down Expand Up @@ -307,6 +307,29 @@ impl ScalarExpression {

*evaluator = Some(EvaluatorFactory::binary_create(ty, *op)?);
}
ScalarExpression::Unary {
expr,
op,
evaluator,
..
} => {
expr.bind_evaluator()?;

let ty = expr.return_type();
if ty.is_unsigned_numeric() {
*expr.as_mut() = ScalarExpression::TypeCast {
expr: Box::new(mem::replace(expr, ScalarExpression::Empty)),
ty: match ty {
LogicalType::UTinyint => LogicalType::Tinyint,
LogicalType::USmallint => LogicalType::Smallint,
LogicalType::UInteger => LogicalType::Integer,
LogicalType::UBigint => LogicalType::Bigint,
_ => unreachable!(),
},
}
}
*evaluator = Some(EvaluatorFactory::unary_create(ty, *op)?);
}
ScalarExpression::Alias { expr, .. } => {
expr.bind_evaluator()?;
}
Expand All @@ -316,9 +339,6 @@ impl ScalarExpression {
ScalarExpression::IsNull { expr, .. } => {
expr.bind_evaluator()?;
}
ScalarExpression::Unary { expr, .. } => {
expr.bind_evaluator()?;
}
ScalarExpression::AggCall { args, .. }
| ScalarExpression::Coalesce { exprs: args, .. }
| ScalarExpression::Tuple(args) => {
Expand Down
71 changes: 58 additions & 13 deletions src/expression/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,49 @@ impl ScalarExpression {

Some(Arc::new(DataValue::Boolean(is_null)))
}
ScalarExpression::Unary { expr, op, .. } => {
let val = expr.unpack_val()?;

DataValue::unary_op(&val, op).ok().map(Arc::new)
ScalarExpression::Unary {
expr,
op,
evaluator,
ty,
..
} => {
let value = expr.unpack_val()?;
let unary_value = if let Some(evaluator) = evaluator {
evaluator.0.unary_eval(&value)
} else {
EvaluatorFactory::unary_create(*ty, *op)
.ok()?
.0
.unary_eval(&value)
};
Some(Arc::new(unary_value))
}
ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
evaluator,
..
} => {
let mut left = left_expr.unpack_val()?;
let mut right = right_expr.unpack_val()?;
let evaluator = EvaluatorFactory::binary_create(*ty, *op).ok()?;

if left.logical_type() != *ty {
left = Arc::new(DataValue::clone(&left).cast(ty).ok()?);
}
if right.logical_type() != *ty {
right = Arc::new(DataValue::clone(&right).cast(ty).ok()?);
}
Some(Arc::new(evaluator.0.binary_eval(&left, &right)))
let binary_value = if let Some(evaluator) = evaluator {
evaluator.0.binary_eval(&left, &right)
} else {
EvaluatorFactory::binary_create(*ty, *op)
.ok()?
.0
.binary_eval(&left, &right)
};
Some(Arc::new(binary_value))
}
_ => None,
}
Expand Down Expand Up @@ -205,11 +225,23 @@ impl ScalarExpression {

pub fn constant_calculation(&mut self) -> Result<(), DatabaseError> {
match self {
ScalarExpression::Unary { expr, op, .. } => {
ScalarExpression::Unary {
expr,
op,
evaluator,
ty,
..
} => {
expr.constant_calculation()?;

if let ScalarExpression::Constant(unary_val) = expr.as_ref() {
let value = DataValue::unary_op(unary_val, op)?;
let value = if let Some(evaluator) = evaluator {
evaluator.0.unary_eval(unary_val)
} else {
EvaluatorFactory::unary_create(*ty, *op)?
.0
.unary_eval(unary_val)
};
let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value)));
}
}
Expand Down Expand Up @@ -421,10 +453,22 @@ impl ScalarExpression {
);
}
}
ScalarExpression::Unary { expr, op, ty, .. } => {
if let Some(val) = expr.unpack_val() {
let new_expr =
ScalarExpression::Constant(Arc::new(DataValue::unary_op(&val, op)?));
ScalarExpression::Unary {
expr,
op,
ty,
evaluator,
..
} => {
if let Some(value) = expr.unpack_val() {
let value = if let Some(evaluator) = evaluator {
evaluator.0.unary_eval(&value)
} else {
EvaluatorFactory::unary_create(*ty, *op)?
.0
.unary_eval(&value)
};
let new_expr = ScalarExpression::Constant(Arc::new(value));
let _ = mem::replace(self, new_expr);
} else {
replaces.push(Replace::Unary(ReplaceUnary {
Expand Down Expand Up @@ -571,6 +615,7 @@ impl ScalarExpression {
Box::new(ScalarExpression::Unary {
op: fix_op,
expr,
evaluator: None,
ty: fix_ty,
}),
);
Expand Down
48 changes: 0 additions & 48 deletions src/expression/value_compute.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/optimizer/rule/normalization/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ mod test {
evaluator: None,
ty: LogicalType::Integer,
}),
evaluator: None,
ty: LogicalType::Integer,
}),
right_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c2_col))),
Expand Down
14 changes: 13 additions & 1 deletion src/types/evaluator/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use serde::{Deserialize, Serialize};
use std::hint;

#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub struct BooleanNotUnaryEvaluator;
#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub struct BooleanAndBinaryEvaluator;
#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
Expand All @@ -12,6 +14,16 @@ pub struct BooleanEqBinaryEvaluator;
#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)]
pub struct BooleanNotEqBinaryEvaluator;

#[typetag::serde]
impl UnaryEvaluator for BooleanNotUnaryEvaluator {
fn unary_eval(&self, value: &DataValue) -> DataValue {
let value = match value {
DataValue::Boolean(value) => value,
_ => unsafe { hint::unreachable_unchecked() },
};
DataValue::Boolean(value.map(|v| !v))
}
}
#[typetag::serde]
impl BinaryEvaluator for BooleanAndBinaryEvaluator {
fn binary_eval(&self, left: &DataValue, right: &DataValue) -> DataValue {
Expand Down
5 changes: 3 additions & 2 deletions src/types/evaluator/float32.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Float32, DataValue::Float32);
numeric_binary_evaluator_definition!(Float32, DataValue::Float32);
5 changes: 3 additions & 2 deletions src/types/evaluator/float64.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Float64, DataValue::Float64);
numeric_binary_evaluator_definition!(Float64, DataValue::Float64);
5 changes: 3 additions & 2 deletions src/types/evaluator/int16.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Int16, DataValue::Int16);
numeric_binary_evaluator_definition!(Int16, DataValue::Int16);
5 changes: 3 additions & 2 deletions src/types/evaluator/int32.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Int32, DataValue::Int32);
numeric_binary_evaluator_definition!(Int32, DataValue::Int32);
5 changes: 3 additions & 2 deletions src/types/evaluator/int64.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Int64, DataValue::Int64);
numeric_binary_evaluator_definition!(Int64, DataValue::Int64);
5 changes: 3 additions & 2 deletions src/types/evaluator/int8.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::numeric_binary_evaluator_definition;
use crate::types::evaluator::BinaryEvaluator;
use crate::types::evaluator::DataValue;
use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator};
use crate::{numeric_binary_evaluator_definition, numeric_unary_evaluator_definition};
use paste::paste;
use serde::{Deserialize, Serialize};
use std::hint;

numeric_unary_evaluator_definition!(Int8, DataValue::Int8);
numeric_binary_evaluator_definition!(Int8, DataValue::Int8);
Loading

0 comments on commit 74ea950

Please sign in to comment.