Skip to content

Commit

Permalink
Optimize expression mapping (#137)
Browse files Browse the repository at this point in the history
* fix: duplicate ColumnIds between multiple tables may cause constant relationship extraction errors

* feat: added `NormalizationRuleImpl::ExpressionRemapper` to optimize expression mapping and avoid repeated searches

* style: remove useless fields
  • Loading branch information
KKould authored Feb 12, 2024
1 parent 86aaf60 commit b6a3726
Show file tree
Hide file tree
Showing 20 changed files with 454 additions and 112 deletions.
4 changes: 3 additions & 1 deletion src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
) -> LogicalPlan {
self.context.step(QueryBindStep::Agg);

AggregateOperator::build(children, agg_calls, groupby_exprs)
AggregateOperator::build(children, agg_calls, groupby_exprs, false)
}

pub fn extract_select_aggregate(
Expand Down Expand Up @@ -137,6 +137,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Reference { .. } => unreachable!(),
}

Ok(())
Expand Down Expand Up @@ -310,6 +311,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Reference { .. } => unreachable!(),
}
}
}
2 changes: 1 addition & 1 deletion src/binder/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ impl<'a, T: Transaction> Binder<'a, T> {
) -> LogicalPlan {
self.context.step(QueryBindStep::Distinct);

AggregateOperator::build(children, vec![], select_list)
AggregateOperator::build(children, vec![], select_list, true)
}
}
1 change: 1 addition & 0 deletions src/catalog/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ impl TableCatalog {
let index = IndexMeta {
id: index_id,
column_ids,
table_name: self.name.clone(),
name,
is_unique,
is_primary,
Expand Down
5 changes: 5 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ impl<S: Storage> Database<S> {
NormalizationRuleImpl::EliminateLimits,
],
)
.batch(
"Expression Remapper".to_string(),
HepBatchStrategy::once_topdown(),
vec![NormalizationRuleImpl::ExpressionRemapper],
)
.implementations(vec![
// DQL
ImplementationRuleImpl::SimpleAggregate,
Expand Down
2 changes: 2 additions & 0 deletions src/execution/volcano/dql/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ impl From<(AggregateOperator, LogicalPlan)> for HashAggExecutor {
AggregateOperator {
agg_calls,
groupby_exprs,
..
},
input,
): (AggregateOperator, LogicalPlan),
Expand Down Expand Up @@ -197,6 +198,7 @@ mod test {
args: vec![ScalarExpression::ColumnRef(t1_columns[1].clone())],
ty: LogicalType::Integer,
}],
is_distinct: false,
};

let input = LogicalPlan {
Expand Down
4 changes: 2 additions & 2 deletions src/execution/volcano/dql/join/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct HashJoin {

impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin {
fn from(
(JoinOperator { on, join_type }, left_input, right_input): (
(JoinOperator { on, join_type, .. }, left_input, right_input): (
JoinOperator,
LogicalPlan,
LogicalPlan,
Expand Down Expand Up @@ -180,7 +180,7 @@ impl HashJoinStatus {
&filter,
join_tuples.is_empty() || matches!(ty, JoinType::Full | JoinType::Cross),
) {
let mut filter_tuples = Vec::with_capacity(join_tuples.len());
let mut filter_tuples = Vec::new();

for mut tuple in join_tuples {
if let DataValue::Boolean(option) = expr.eval(&tuple)?.as_ref() {
Expand Down
34 changes: 14 additions & 20 deletions src/expression/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::catalog::ColumnSummary;
use crate::errors::DatabaseError;
use crate::expression::value_compute::{binary_op, unary_op};
use crate::expression::{AliasType, ScalarExpression};
Expand Down Expand Up @@ -29,14 +28,14 @@ macro_rules! eval_to_num {

impl ScalarExpression {
pub fn eval(&self, tuple: &Tuple) -> Result<ValueRef, DatabaseError> {
if let Some(value) = Self::eval_with_summary(tuple, self.output_column().summary()) {
return Ok(value.clone());
}

match &self {
match self {
ScalarExpression::Constant(val) => Ok(val.clone()),
ScalarExpression::ColumnRef(col) => {
let value = Self::eval_with_summary(tuple, col.summary())
let value = tuple
.schema_ref
.iter()
.find_position(|tul_col| tul_col.summary() == col.summary())
.map(|(i, _)| &tuple.values[i])
.unwrap_or(&NULL_VALUE)
.clone();

Expand Down Expand Up @@ -116,11 +115,7 @@ impl ScalarExpression {
Ok(Arc::new(unary_op(&value, op)?))
}
ScalarExpression::AggCall { .. } => {
let value = Self::eval_with_summary(tuple, self.output_column().summary())
.unwrap_or(&NULL_VALUE)
.clone();

Ok(value)
unreachable!("must use `NormalizationRuleImpl::ExpressionRemapper`")
}
ScalarExpression::Between {
expr,
Expand Down Expand Up @@ -166,15 +161,14 @@ impl ScalarExpression {
Ok(Arc::new(DataValue::Utf8(None)))
}
}
ScalarExpression::Reference { pos, .. } => {
return Ok(tuple
.values
.get(*pos)
.unwrap_or_else(|| &NULL_VALUE)
.clone());
}
ScalarExpression::Empty => unreachable!(),
}
}

fn eval_with_summary<'a>(tuple: &'a Tuple, summary: &ColumnSummary) -> Option<&'a ValueRef> {
tuple
.schema_ref
.iter()
.find_position(|tul_col| tul_col.summary() == summary)
.map(|(i, _)| &tuple.values[i])
}
}
99 changes: 93 additions & 6 deletions src/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::{fmt, mem};

use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator};

Expand Down Expand Up @@ -77,6 +77,10 @@ pub enum ScalarExpression {
},
// Temporary expression used for expression substitution
Empty,
Reference {
expr: Box<ScalarExpression>,
pos: usize,
},
}

impl ScalarExpression {
Expand All @@ -87,6 +91,85 @@ impl ScalarExpression {
self
}
}
/// Strips every enclosing `ScalarExpression::Reference` layer and returns
/// the first inner expression that is not itself a `Reference`.
///
/// An expression that is not a `Reference` is returned unchanged.
pub fn unpack_reference(&self) -> &ScalarExpression {
    let mut current = self;
    // Walk down through nested references until a non-reference is reached.
    while let ScalarExpression::Reference { expr, .. } = current {
        current = &**expr;
    }
    current
}

/// Rewrites this expression (or its sub-expressions) into positional
/// `ScalarExpression::Reference`s against `output_exprs`.
///
/// If the `output_name()` of `self` equals the `output_name()` of an entry
/// in `output_exprs`, `self` is replaced in place by
/// `Reference { expr: <previous self>, pos }`, where `pos` is the index of
/// the first matching entry. Otherwise the same procedure is applied
/// recursively to every child expression. `Constant`, `ColumnRef` and
/// already-built `Reference` nodes that do not match are left untouched.
///
/// # Panics
///
/// Panics if an unmatched `ScalarExpression::Empty` is encountered.
pub fn try_reference(&mut self, output_exprs: &[ScalarExpression]) {
    // The name of `self` does not change while scanning the candidates, so
    // build it once instead of once per candidate.
    let own_name = self.output_name();

    if let Some(pos) = output_exprs
        .iter()
        .position(|candidate| candidate.output_name() == own_name)
    {
        // `Empty` is only a momentary placeholder so the original
        // expression can be moved into the box without cloning.
        let expr = Box::new(mem::replace(self, ScalarExpression::Empty));
        *self = ScalarExpression::Reference { expr, pos };
        return;
    }

    // No match at this level: descend into the children.
    match self {
        ScalarExpression::Alias { expr, .. }
        | ScalarExpression::TypeCast { expr, .. }
        | ScalarExpression::IsNull { expr, .. }
        | ScalarExpression::Unary { expr, .. } => {
            expr.try_reference(output_exprs);
        }
        ScalarExpression::Binary {
            left_expr,
            right_expr,
            ..
        } => {
            left_expr.try_reference(output_exprs);
            right_expr.try_reference(output_exprs);
        }
        ScalarExpression::AggCall { args, .. } => {
            for arg in args {
                arg.try_reference(output_exprs);
            }
        }
        ScalarExpression::In { expr, args, .. } => {
            expr.try_reference(output_exprs);
            for arg in args {
                arg.try_reference(output_exprs);
            }
        }
        ScalarExpression::Between {
            expr,
            left_expr,
            right_expr,
            ..
        } => {
            expr.try_reference(output_exprs);
            left_expr.try_reference(output_exprs);
            right_expr.try_reference(output_exprs);
        }
        ScalarExpression::SubString {
            expr,
            for_expr,
            from_expr,
        } => {
            expr.try_reference(output_exprs);
            if let Some(for_expr) = for_expr {
                for_expr.try_reference(output_exprs);
            }
            if let Some(from_expr) = from_expr {
                from_expr.try_reference(output_exprs);
            }
        }
        ScalarExpression::Empty => unreachable!(),
        // Leaves, and nodes that are already references, have nothing to
        // rewrite.
        ScalarExpression::Constant(_)
        | ScalarExpression::ColumnRef(_)
        | ScalarExpression::Reference { .. } => (),
    }
}

pub fn has_count_star(&self) -> bool {
match self {
Expand Down Expand Up @@ -124,7 +207,9 @@ impl ScalarExpression {
LogicalType::Boolean
}
Self::SubString { .. } => LogicalType::Varchar(None),
Self::Alias { expr, .. } => expr.return_type(),
Self::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => {
expr.return_type()
}
ScalarExpression::Empty => unreachable!(),
}
}
Expand Down Expand Up @@ -193,7 +278,8 @@ impl ScalarExpression {
columns_collect(from_expr, vec, only_column_ref);
}
}
ScalarExpression::Constant(_) | ScalarExpression::Empty => (),
ScalarExpression::Constant(_) => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
}
}
let mut exprs = Vec::new();
Expand Down Expand Up @@ -241,7 +327,7 @@ impl ScalarExpression {
Some(true)
)
}
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
}
}

Expand All @@ -261,7 +347,7 @@ impl ScalarExpression {
}
},
ScalarExpression::TypeCast { expr, ty } => {
format!("CAST({} as {})", expr.output_name(), ty)
format!("cast ({} as {})", expr.output_name(), ty)
}
ScalarExpression::IsNull { expr, negated } => {
let suffix = if *negated { "is not null" } else { "is null" };
Expand Down Expand Up @@ -289,7 +375,7 @@ impl ScalarExpression {
let args_str = args.iter().map(|expr| expr.output_name()).join(", ");
let op = |allow_distinct, distinct| {
if allow_distinct && distinct {
"DISTINCT "
"distinct "
} else {
""
}
Expand Down Expand Up @@ -344,6 +430,7 @@ impl ScalarExpression {
op("for", for_expr),
)
}
ScalarExpression::Reference { expr, .. } => expr.output_name(),
ScalarExpression::Empty => unreachable!(),
}
}
Expand Down
Loading

0 comments on commit b6a3726

Please sign in to comment.