Skip to content

Commit

Permalink
perf: select the corresponding type of BinaryEvaluator during optimization to avoid runtime type judgment overhead (#201)
Browse files Browse the repository at this point in the history

* perf: select the corresponding type of `BinaryEvaluator` during optimization to avoid runtime type judgment overhead.

* fix: return `test_udf`

* style: code fmt

* style: add `EvaluatorNotFound`

* style: remove `UnaryEvaluator`

* style: remove `UnaryEvaluator`
  • Loading branch information
KKould authored Apr 8, 2024
1 parent b1c7928 commit 50d3e6d
Show file tree
Hide file tree
Showing 44 changed files with 2,245 additions and 1,755 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ kip_db = { version = "0.1.2-alpha.25.fix2" }
lazy_static = { version = "1.4.0" }
log = { version = "0.4.21", optional = true }
ordered-float = { version = "4.2.0" }
paste = { version = "1.0.14" }
petgraph = { version = "0.6.4" }
pgwire = { version = "0.19.2", optional = true }
rand = { version = "0.9.0-alpha.0" }
Expand Down
16 changes: 14 additions & 2 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
op: expression::BinaryOperator::Eq,
left_expr,
right_expr: Box::new(alias_expr),
evaluator: None,
ty: LogicalType::Boolean,
})
}
Expand Down Expand Up @@ -272,6 +273,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
op,
left_expr,
right_expr,
evaluator: None,
ty: LogicalType::Boolean,
})
}
Expand Down Expand Up @@ -342,10 +344,19 @@ impl<'a, T: Transaction> Binder<'a, T> {
BinaryOperator::Plus
| BinaryOperator::Minus
| BinaryOperator::Multiply
| BinaryOperator::Divide
| BinaryOperator::Modulo => {
LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())?
}
BinaryOperator::Divide => {
if let LogicalType::Decimal(precision, scale) = LogicalType::max_logical_type(
&left_expr.return_type(),
&right_expr.return_type(),
)? {
LogicalType::Decimal(precision, scale)
} else {
LogicalType::Double
}
}
BinaryOperator::Gt
| BinaryOperator::Lt
| BinaryOperator::GtEq
Expand All @@ -356,13 +367,14 @@ impl<'a, T: Transaction> Binder<'a, T> {
| BinaryOperator::Or
| BinaryOperator::Xor => LogicalType::Boolean,
BinaryOperator::StringConcat => LogicalType::Varchar(None, CharLengthUnits::Characters),
_ => todo!(),
op => return Err(DatabaseError::UnsupportedStmt(format!("{}", op))),
};

Ok(ScalarExpression::Binary {
op: (op.clone()).into(),
left_expr,
right_expr,
evaluator: None,
ty,
})
}
Expand Down
8 changes: 8 additions & 0 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
op: BinaryOperator::And,
left_expr: Box::new(acc),
right_expr: Box::new(expr),
evaluator: None,
ty: LogicalType::Boolean,
});

Expand Down Expand Up @@ -654,6 +655,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
op: BinaryOperator::And,
left_expr: Box::new(acc),
right_expr: Box::new(expr),
evaluator: None,
ty: LogicalType::Boolean,
});
// TODO: handle cross join if on_keys is empty
Expand Down Expand Up @@ -723,6 +725,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
..
} => {
match op {
BinaryOperator::Eq => {
Expand All @@ -746,6 +749,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
evaluator: None,
});
}
}
Expand All @@ -757,6 +761,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
evaluator: None,
});
}
}
Expand All @@ -772,6 +777,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
evaluator: None,
});
}
}
Expand Down Expand Up @@ -800,6 +806,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
evaluator: None,
});
}
_ => {
Expand All @@ -813,6 +820,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
right_expr,
op,
ty,
evaluator: None,
});
}
}
Expand Down
11 changes: 9 additions & 2 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ impl<S: Storage> Database<S> {
.batch(
"Expression Remapper".to_string(),
HepBatchStrategy::once_topdown(),
vec![NormalizationRuleImpl::ExpressionRemapper],
vec![
NormalizationRuleImpl::ExpressionRemapper,
// TIPS: This rule is necessary
NormalizationRuleImpl::EvaluatorBind,
],
)
.implementations(vec![
// DQL
Expand Down Expand Up @@ -263,6 +267,7 @@ mod test {
use crate::expression::{BinaryOperator, UnaryOperator};
use crate::function;
use crate::storage::{Storage, Transaction};
use crate::types::evaluator::EvaluatorFactory;
use crate::types::tuple::{create_table, Tuple};
use crate::types::value::{DataValue, ValueRef};
use crate::types::LogicalType;
Expand Down Expand Up @@ -304,7 +309,9 @@ mod test {
}

function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| {
let value = DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus)?;
let plus_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?;
let value = plus_evaluator.0.binary_eval(&v1, &v2);

DataValue::unary_op(&value, &UnaryOperator::Minus)
}));

Expand Down
2 changes: 2 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ pub enum DatabaseError {
EmptyPlan,
#[error("sql statement is empty")]
EmptyStatement,
#[error("evaluator not found")]
EvaluatorNotFound,
#[error("from utf8: {0}")]
FromUtf8Error(
#[source]
Expand Down
31 changes: 17 additions & 14 deletions src/execution/volcano/dql/aggregate/avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::errors::DatabaseError;
use crate::execution::volcano::dql::aggregate::sum::SumAccumulator;
use crate::execution::volcano::dql::aggregate::Accumulator;
use crate::expression::BinaryOperator;
use crate::types::evaluator::EvaluatorFactory;
use crate::types::value::{DataValue, ValueRef};
use crate::types::LogicalType;
use std::sync::Arc;
Expand All @@ -12,11 +13,11 @@ pub struct AvgAccumulator {
}

impl AvgAccumulator {
pub fn new(ty: &LogicalType) -> Self {
Self {
inner: SumAccumulator::new(ty),
pub fn new(ty: &LogicalType) -> Result<Self, DatabaseError> {
Ok(Self {
inner: SumAccumulator::new(ty)?,
count: 0,
}
})
}
}

Expand All @@ -31,21 +32,23 @@ impl Accumulator for AvgAccumulator {
}

fn evaluate(&self) -> Result<ValueRef, DatabaseError> {
let value = self.inner.evaluate()?;
let mut value = self.inner.evaluate()?;
let value_ty = value.logical_type();

let quantity = if value.logical_type().is_signed_numeric() {
if self.count == 0 {
return Ok(Arc::new(DataValue::init(&value_ty)));
}
let quantity = if value_ty.is_signed_numeric() {
DataValue::Int64(Some(self.count as i64))
} else {
DataValue::UInt32(Some(self.count as u32))
};
if self.count == 0 {
return Ok(Arc::new(DataValue::init(&value.logical_type())));
}
let quantity_ty = quantity.logical_type();

Ok(Arc::new(DataValue::binary_op(
&value,
&quantity,
&BinaryOperator::Divide,
)?))
if value_ty != quantity_ty {
value = Arc::new(DataValue::clone(&value).cast(&quantity_ty)?)
}
let evaluator = EvaluatorFactory::binary_create(quantity_ty, BinaryOperator::Divide)?;
Ok(Arc::new(evaluator.0.binary_eval(&value, &quantity)))
}
}
2 changes: 1 addition & 1 deletion src/execution/volcano/dql/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl HashAggStatus {
for (acc, value) in self
.group_hash_accs
.entry(group_keys)
.or_insert_with(|| create_accumulators(&self.agg_calls))
.or_insert_with(|| create_accumulators(&self.agg_calls).unwrap())
.iter_mut()
.zip_eq(values.iter())
{
Expand Down
4 changes: 3 additions & 1 deletion src/execution/volcano/dql/aggregate/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::errors::DatabaseError;
use crate::execution::volcano::dql::aggregate::Accumulator;
use crate::expression::BinaryOperator;
use crate::types::evaluator::EvaluatorFactory;
use crate::types::value::{DataValue, ValueRef};
use crate::types::LogicalType;
use std::sync::Arc;
Expand Down Expand Up @@ -31,8 +32,9 @@ impl Accumulator for MinMaxAccumulator {
fn update_value(&mut self, value: &ValueRef) -> Result<(), DatabaseError> {
if !value.is_null() {
if let Some(inner_value) = &self.inner {
let evaluator = EvaluatorFactory::binary_create(value.logical_type(), self.op)?;
if let DataValue::Boolean(Some(result)) =
DataValue::binary_op(inner_value, value, &self.op)?
evaluator.0.binary_eval(inner_value, value)
{
result
} else {
Expand Down
19 changes: 11 additions & 8 deletions src/execution/volcano/dql/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::execution::volcano::dql::aggregate::sum::{DistinctSumAccumulator, Sum
use crate::expression::agg::AggKind;
use crate::expression::ScalarExpression;
use crate::types::value::ValueRef;
use itertools::Itertools;

/// Tips: Idea for sqlrs
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple
Expand All @@ -27,20 +28,20 @@ pub trait Accumulator: Send + Sync {
fn evaluate(&self) -> Result<ValueRef, DatabaseError>;
}

fn create_accumulator(expr: &ScalarExpression) -> Box<dyn Accumulator> {
fn create_accumulator(expr: &ScalarExpression) -> Result<Box<dyn Accumulator>, DatabaseError> {
if let ScalarExpression::AggCall {
kind, ty, distinct, ..
} = expr
{
match (kind, distinct) {
Ok(match (kind, distinct) {
(AggKind::Count, false) => Box::new(CountAccumulator::new()),
(AggKind::Count, true) => Box::new(DistinctCountAccumulator::new()),
(AggKind::Sum, false) => Box::new(SumAccumulator::new(ty)),
(AggKind::Sum, true) => Box::new(DistinctSumAccumulator::new(ty)),
(AggKind::Sum, false) => Box::new(SumAccumulator::new(ty)?),
(AggKind::Sum, true) => Box::new(DistinctSumAccumulator::new(ty)?),
(AggKind::Min, _) => Box::new(MinMaxAccumulator::new(ty, false)),
(AggKind::Max, _) => Box::new(MinMaxAccumulator::new(ty, true)),
(AggKind::Avg, _) => Box::new(AvgAccumulator::new(ty)),
}
(AggKind::Avg, _) => Box::new(AvgAccumulator::new(ty)?),
})
} else {
unreachable!(
"create_accumulator called with non-aggregate expression {}",
Expand All @@ -49,6 +50,8 @@ fn create_accumulator(expr: &ScalarExpression) -> Box<dyn Accumulator> {
}
}

pub(crate) fn create_accumulators(exprs: &[ScalarExpression]) -> Vec<Box<dyn Accumulator>> {
exprs.iter().map(create_accumulator).collect()
pub(crate) fn create_accumulators(
exprs: &[ScalarExpression],
) -> Result<Vec<Box<dyn Accumulator>>, DatabaseError> {
exprs.iter().map(create_accumulator).try_collect()
}
2 changes: 1 addition & 1 deletion src/execution/volcano/dql/aggregate/simple_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl SimpleAggExecutor {
agg_calls,
mut input,
} = self;
let mut accs = create_accumulators(&agg_calls);
let mut accs = create_accumulators(&agg_calls)?;
let schema = input.output_schema().clone();

#[for_await]
Expand Down
19 changes: 11 additions & 8 deletions src/execution/volcano/dql/aggregate/sum.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::errors::DatabaseError;
use crate::execution::volcano::dql::aggregate::Accumulator;
use crate::expression::BinaryOperator;
use crate::types::evaluator::{BinaryEvaluatorBox, EvaluatorFactory};
use crate::types::value::{DataValue, ValueRef};
use crate::types::LogicalType;
use ahash::RandomState;
Expand All @@ -9,22 +10,24 @@ use std::sync::Arc;

pub struct SumAccumulator {
result: DataValue,
evaluator: BinaryEvaluatorBox,
}

impl SumAccumulator {
pub fn new(ty: &LogicalType) -> Self {
pub fn new(ty: &LogicalType) -> Result<Self, DatabaseError> {
assert!(ty.is_numeric());

Self {
Ok(Self {
result: DataValue::init(ty),
}
evaluator: EvaluatorFactory::binary_create(*ty, BinaryOperator::Plus)?,
})
}
}

impl Accumulator for SumAccumulator {
fn update_value(&mut self, value: &ValueRef) -> Result<(), DatabaseError> {
if !value.is_null() {
self.result = DataValue::binary_op(&self.result, value, &BinaryOperator::Plus)?;
self.result = self.evaluator.0.binary_eval(&self.result, value);
}

Ok(())
Expand All @@ -41,11 +44,11 @@ pub struct DistinctSumAccumulator {
}

impl DistinctSumAccumulator {
pub fn new(ty: &LogicalType) -> Self {
Self {
pub fn new(ty: &LogicalType) -> Result<Self, DatabaseError> {
Ok(Self {
distinct_values: HashSet::default(),
inner: SumAccumulator::new(ty),
}
inner: SumAccumulator::new(ty)?,
})
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/execution/volcano/dql/join/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ mod test {
use crate::planner::operator::Operator;
use crate::storage::kipdb::KipStorage;
use crate::storage::Storage;
use crate::types::evaluator::int32::Int32GtBinaryEvaluator;
use crate::types::evaluator::BinaryEvaluatorBox;
use crate::types::value::DataValue;
use crate::types::LogicalType;
use std::collections::HashSet;
Expand Down Expand Up @@ -441,7 +443,8 @@ mod test {
true,
desc.clone(),
)))),
ty: LogicalType::Integer,
evaluator: Some(BinaryEvaluatorBox(Arc::new(Int32GtBinaryEvaluator))),
ty: LogicalType::Boolean,
};

(on_keys, values_t1, values_t2, filter)
Expand Down
Loading

0 comments on commit 50d3e6d

Please sign in to comment.