diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 1483fe40..b68eefac 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -22,7 +22,7 @@ impl<'a, T: Transaction> Binder<'a, T> { ) -> LogicalPlan { self.context.step(QueryBindStep::Agg); - AggregateOperator::build(children, agg_calls, groupby_exprs) + AggregateOperator::build(children, agg_calls, groupby_exprs, false) } pub fn extract_select_aggregate( @@ -137,6 +137,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (), ScalarExpression::Empty => unreachable!(), + ScalarExpression::Reference { .. } => unreachable!(), } Ok(()) @@ -310,6 +311,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } ScalarExpression::Constant(_) => Ok(()), ScalarExpression::Empty => unreachable!(), + ScalarExpression::Reference { .. } => unreachable!(), } } } diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index fbdff1bf..8495e8f8 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -12,6 +12,6 @@ impl<'a, T: Transaction> Binder<'a, T> { ) -> LogicalPlan { self.context.step(QueryBindStep::Distinct); - AggregateOperator::build(children, vec![], select_list) + AggregateOperator::build(children, vec![], select_list, true) } } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 19bd7d68..293f3792 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -120,6 +120,7 @@ impl TableCatalog { let index = IndexMeta { id: index_id, column_ids, + table_name: self.name.clone(), name, is_unique, is_primary, diff --git a/src/db.rs b/src/db.rs index d5691d02..06908ea0 100644 --- a/src/db.rs +++ b/src/db.rs @@ -172,6 +172,11 @@ impl Database { NormalizationRuleImpl::EliminateLimits, ], ) + .batch( + "Expression Remapper".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::ExpressionRemapper], + ) .implementations(vec![ // DQL ImplementationRuleImpl::SimpleAggregate, diff --git 
a/src/execution/volcano/dql/aggregate/hash_agg.rs b/src/execution/volcano/dql/aggregate/hash_agg.rs index 15b351c4..76d397cd 100644 --- a/src/execution/volcano/dql/aggregate/hash_agg.rs +++ b/src/execution/volcano/dql/aggregate/hash_agg.rs @@ -26,6 +26,7 @@ impl From<(AggregateOperator, LogicalPlan)> for HashAggExecutor { AggregateOperator { agg_calls, groupby_exprs, + .. }, input, ): (AggregateOperator, LogicalPlan), @@ -197,6 +198,7 @@ mod test { args: vec![ScalarExpression::ColumnRef(t1_columns[1].clone())], ty: LogicalType::Integer, }], + is_distinct: false, }; let input = LogicalPlan { diff --git a/src/execution/volcano/dql/join/hash_join.rs b/src/execution/volcano/dql/join/hash_join.rs index d552f014..58d51339 100644 --- a/src/execution/volcano/dql/join/hash_join.rs +++ b/src/execution/volcano/dql/join/hash_join.rs @@ -23,7 +23,7 @@ pub struct HashJoin { impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { fn from( - (JoinOperator { on, join_type }, left_input, right_input): ( + (JoinOperator { on, join_type, .. }, left_input, right_input): ( JoinOperator, LogicalPlan, LogicalPlan, @@ -180,7 +180,7 @@ impl HashJoinStatus { &filter, join_tuples.is_empty() || matches!(ty, JoinType::Full | JoinType::Cross), ) { - let mut filter_tuples = Vec::with_capacity(join_tuples.len()); + let mut filter_tuples = Vec::new(); for mut tuple in join_tuples { if let DataValue::Boolean(option) = expr.eval(&tuple)?.as_ref() { diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 8179c132..b0a49230 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -1,4 +1,3 @@ -use crate::catalog::ColumnSummary; use crate::errors::DatabaseError; use crate::expression::value_compute::{binary_op, unary_op}; use crate::expression::{AliasType, ScalarExpression}; @@ -29,14 +28,14 @@ macro_rules! 
eval_to_num { impl ScalarExpression { pub fn eval(&self, tuple: &Tuple) -> Result { - if let Some(value) = Self::eval_with_summary(tuple, self.output_column().summary()) { - return Ok(value.clone()); - } - - match &self { + match self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { - let value = Self::eval_with_summary(tuple, col.summary()) + let value = tuple + .schema_ref + .iter() + .find_position(|tul_col| tul_col.summary() == col.summary()) + .map(|(i, _)| &tuple.values[i]) .unwrap_or(&NULL_VALUE) .clone(); @@ -116,11 +115,7 @@ impl ScalarExpression { Ok(Arc::new(unary_op(&value, op)?)) } ScalarExpression::AggCall { .. } => { - let value = Self::eval_with_summary(tuple, self.output_column().summary()) - .unwrap_or(&NULL_VALUE) - .clone(); - - Ok(value) + unreachable!("must use `NormalizationRuleImpl::ExpressionRemapper`") } ScalarExpression::Between { expr, @@ -166,15 +161,14 @@ impl ScalarExpression { Ok(Arc::new(DataValue::Utf8(None))) } } + ScalarExpression::Reference { pos, .. 
} => { + return Ok(tuple + .values + .get(*pos) + .unwrap_or_else(|| &NULL_VALUE) + .clone()); + } ScalarExpression::Empty => unreachable!(), } } - - fn eval_with_summary<'a>(tuple: &'a Tuple, summary: &ColumnSummary) -> Option<&'a ValueRef> { - tuple - .schema_ref - .iter() - .find_position(|tul_col| tul_col.summary() == summary) - .map(|(i, _)| &tuple.values[i]) - } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 7f46b8f1..2fd11a90 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,8 +1,8 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; -use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +use std::{fmt, mem}; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; @@ -77,6 +77,10 @@ pub enum ScalarExpression { }, // Temporary expression used for expression substitution Empty, + Reference { + expr: Box, + pos: usize, + }, } impl ScalarExpression { @@ -87,6 +91,85 @@ impl ScalarExpression { self } } + pub fn unpack_reference(&self) -> &ScalarExpression { + if let ScalarExpression::Reference { expr, .. } = self { + expr.unpack_reference() + } else { + self + } + } + + pub fn try_reference(&mut self, output_exprs: &[ScalarExpression]) { + if let Some((pos, _)) = output_exprs + .iter() + .find_position(|expr| self.output_name() == expr.output_name()) + { + let expr = Box::new(mem::replace(self, ScalarExpression::Empty)); + *self = ScalarExpression::Reference { expr, pos }; + return; + } + + match self { + ScalarExpression::Alias { expr, .. } => { + expr.try_reference(output_exprs); + } + ScalarExpression::TypeCast { expr, .. } => { + expr.try_reference(output_exprs); + } + ScalarExpression::IsNull { expr, .. } => { + expr.try_reference(output_exprs); + } + ScalarExpression::Unary { expr, .. } => { + expr.try_reference(output_exprs); + } + ScalarExpression::Binary { + left_expr, + right_expr, + .. 
+ } => { + left_expr.try_reference(output_exprs); + right_expr.try_reference(output_exprs); + } + ScalarExpression::AggCall { args, .. } => { + for arg in args { + arg.try_reference(output_exprs); + } + } + ScalarExpression::In { expr, args, .. } => { + expr.try_reference(output_exprs); + for arg in args { + arg.try_reference(output_exprs); + } + } + ScalarExpression::Between { + expr, + left_expr, + right_expr, + .. + } => { + expr.try_reference(output_exprs); + left_expr.try_reference(output_exprs); + right_expr.try_reference(output_exprs); + } + ScalarExpression::SubString { + expr, + for_expr, + from_expr, + } => { + expr.try_reference(output_exprs); + if let Some(expr) = for_expr { + expr.try_reference(output_exprs); + } + if let Some(expr) = from_expr { + expr.try_reference(output_exprs); + } + } + ScalarExpression::Empty => unreachable!(), + ScalarExpression::Constant(_) + | ScalarExpression::ColumnRef(_) + | ScalarExpression::Reference { .. } => (), + } + } pub fn has_count_star(&self) -> bool { match self { @@ -124,7 +207,9 @@ impl ScalarExpression { LogicalType::Boolean } Self::SubString { .. } => LogicalType::Varchar(None), - Self::Alias { expr, .. } => expr.return_type(), + Self::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => { + expr.return_type() + } ScalarExpression::Empty => unreachable!(), } } @@ -193,7 +278,8 @@ impl ScalarExpression { columns_collect(from_expr, vec, only_column_ref); } } - ScalarExpression::Constant(_) | ScalarExpression::Empty => (), + ScalarExpression::Constant(_) => (), + ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), } } let mut exprs = Vec::new(); @@ -241,7 +327,7 @@ impl ScalarExpression { Some(true) ) } - ScalarExpression::Empty => unreachable!(), + ScalarExpression::Reference { .. 
} | ScalarExpression::Empty => unreachable!(), } } @@ -261,7 +347,7 @@ impl ScalarExpression { } }, ScalarExpression::TypeCast { expr, ty } => { - format!("CAST({} as {})", expr.output_name(), ty) + format!("cast ({} as {})", expr.output_name(), ty) } ScalarExpression::IsNull { expr, negated } => { let suffix = if *negated { "is not null" } else { "is null" }; @@ -289,7 +375,7 @@ impl ScalarExpression { let args_str = args.iter().map(|expr| expr.output_name()).join(", "); let op = |allow_distinct, distinct| { if allow_distinct && distinct { - "DISTINCT " + "distinct " } else { "" } @@ -344,6 +430,7 @@ impl ScalarExpression { op("for", for_expr), ) } + ScalarExpression::Reference { expr, .. } => expr.output_name(), ScalarExpression::Empty => unreachable!(), } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index defb71b1..9cdcaae0 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -409,19 +409,59 @@ struct ReplaceUnary { } impl ScalarExpression { - pub fn exist_column(&self, col_id: &ColumnId) -> bool { + pub fn exist_column(&self, table_name: &str, col_id: &ColumnId) -> bool { match self { - ScalarExpression::ColumnRef(col) => col.id() == Some(*col_id), - ScalarExpression::Alias { expr, .. } => expr.exist_column(col_id), - ScalarExpression::TypeCast { expr, .. } => expr.exist_column(col_id), - ScalarExpression::IsNull { expr, .. } => expr.exist_column(col_id), - ScalarExpression::Unary { expr, .. } => expr.exist_column(col_id), + ScalarExpression::ColumnRef(col) => { + Self::_is_belong(table_name, col) && col.id() == Some(*col_id) + } + ScalarExpression::Alias { expr, .. } => expr.exist_column(table_name, col_id), + ScalarExpression::TypeCast { expr, .. } => expr.exist_column(table_name, col_id), + ScalarExpression::IsNull { expr, .. } => expr.exist_column(table_name, col_id), + ScalarExpression::Unary { expr, .. } => expr.exist_column(table_name, col_id), ScalarExpression::Binary { left_expr, right_expr, .. 
- } => left_expr.exist_column(col_id) || right_expr.exist_column(col_id), - _ => false, + } => { + left_expr.exist_column(table_name, col_id) + || right_expr.exist_column(table_name, col_id) + } + ScalarExpression::AggCall { args, .. } => args + .iter() + .any(|expr| expr.exist_column(table_name, col_id)), + ScalarExpression::In { expr, args, .. } => { + expr.exist_column(table_name, col_id) + || args + .iter() + .any(|expr| expr.exist_column(table_name, col_id)) + } + ScalarExpression::Between { + expr, + left_expr, + right_expr, + .. + } => { + expr.exist_column(table_name, col_id) + || left_expr.exist_column(table_name, col_id) + || right_expr.exist_column(table_name, col_id) + } + ScalarExpression::SubString { + expr, + for_expr, + from_expr, + } => { + expr.exist_column(table_name, col_id) + || for_expr + .as_ref() + .map(|expr| expr.exist_column(table_name, col_id)) + == Some(true) + || from_expr + .as_ref() + .map(|expr| expr.exist_column(table_name, col_id)) + == Some(true) + } + ScalarExpression::Constant(_) => false, + ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), } } @@ -832,7 +872,8 @@ impl ScalarExpression { /// - `ConstantBinary::Or`: Rearrange and sort the range of each OR data pub fn convert_binary( &self, - col_id: &ColumnId, + table_name: &str, + id: &ColumnId, ) -> Result, DatabaseError> { match self { ScalarExpression::Binary { @@ -842,8 +883,8 @@ impl ScalarExpression { .. 
} => { match ( - left_expr.convert_binary(col_id)?, - right_expr.convert_binary(col_id)?, + left_expr.convert_binary(table_name, id)?, + right_expr.convert_binary(table_name, id)?, ) { (Some(left_binary), Some(right_binary)) => match (left_binary, right_binary) { (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) => match op @@ -902,18 +943,22 @@ impl ScalarExpression { if let (Some(col), Some(val)) = (left_expr.unpack_col(false), right_expr.unpack_val()) { - return Ok(Self::new_binary(col_id, *op, col, val, false)); + return Ok(Self::new_binary(table_name, id, *op, col, val, false)); } if let (Some(val), Some(col)) = (left_expr.unpack_val(), right_expr.unpack_col(false)) { - return Ok(Self::new_binary(col_id, *op, col, val, true)); + return Ok(Self::new_binary(table_name, id, *op, col, val, true)); } Ok(None) } - (Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)), - (None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)), + (Some(binary), None) => { + Ok(Self::check_or(table_name, id, right_expr, op, binary)) + } + (None, Some(binary)) => { + Ok(Self::check_or(table_name, id, left_expr, op, binary)) + } } } ScalarExpression::Alias { expr, .. } @@ -921,16 +966,20 @@ impl ScalarExpression { | ScalarExpression::Unary { expr, .. } | ScalarExpression::In { expr, .. } | ScalarExpression::Between { expr, .. } - | ScalarExpression::SubString { expr, .. } => expr.convert_binary(col_id), + | ScalarExpression::SubString { expr, .. } => expr.convert_binary(table_name, id), ScalarExpression::IsNull { expr, negated, .. 
} => match expr.as_ref() { ScalarExpression::ColumnRef(column) => { - Ok(column.id().is_some_and(|id| col_id == &id).then(|| { - if *negated { - ConstantBinary::NotEq(NULL_VALUE.clone()) - } else { - ConstantBinary::Eq(NULL_VALUE.clone()) + if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { + if id == &col_id && col_table.as_str() == table_name { + return Ok(Some(if *negated { + ConstantBinary::NotEq(NULL_VALUE.clone()) + } else { + ConstantBinary::Eq(NULL_VALUE.clone()) + })); } - })) + } + + Ok(None) } ScalarExpression::Constant(_) | ScalarExpression::Alias { .. } @@ -941,25 +990,26 @@ impl ScalarExpression { | ScalarExpression::AggCall { .. } | ScalarExpression::In { .. } | ScalarExpression::Between { .. } - | ScalarExpression::SubString { .. } => expr.convert_binary(col_id), - ScalarExpression::Empty => unreachable!(), + | ScalarExpression::SubString { .. } => expr.convert_binary(table_name, id), + ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), }, ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) | ScalarExpression::AggCall { .. } => Ok(None), - ScalarExpression::Empty => unreachable!(), + ScalarExpression::Reference { .. 
} | ScalarExpression::Empty => unreachable!(), } } /// check if: c1 > c2 or c1 > 1 /// this case it makes no sense to just extract c1 > 1 fn check_or( + table_name: &str, col_id: &ColumnId, right_expr: &ScalarExpression, op: &BinaryOperator, binary: ConstantBinary, ) -> Option { - if matches!(op, BinaryOperator::Or) && right_expr.exist_column(col_id) { + if matches!(op, BinaryOperator::Or) && right_expr.exist_column(table_name, col_id) { return None; } @@ -967,13 +1017,14 @@ impl ScalarExpression { } fn new_binary( + table_name: &str, col_id: &ColumnId, mut op: BinaryOperator, col: ColumnRef, val: ValueRef, is_flip: bool, ) -> Option { - if col.id() != Some(*col_id) { + if !Self::_is_belong(table_name, &col) || col.id() != Some(*col_id) { return None; } @@ -1009,6 +1060,13 @@ impl ScalarExpression { _ => None, } } + + fn _is_belong(table_name: &str, col: &ColumnRef) -> bool { + matches!( + col.table_name().map(|name| table_name == name.as_str()), + Some(true) + ) + } } impl fmt::Display for ConstantBinary { @@ -1056,7 +1114,7 @@ mod test { summary: ColumnSummary { id: Some(0), name: "c1".to_string(), - table_name: None, + table_name: Some(Arc::new("t1".to_string())), }, nullable: false, desc: ColumnDesc { @@ -1074,7 +1132,7 @@ mod test { right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? .unwrap(); assert_eq!(binary_eq, ConstantBinary::Eq(val_1.clone())); @@ -1085,7 +1143,7 @@ mod test { right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? .unwrap(); assert_eq!(binary_not_eq, ConstantBinary::NotEq(val_1.clone())); @@ -1096,7 +1154,7 @@ mod test { right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? 
.unwrap(); assert_eq!( @@ -1113,7 +1171,7 @@ mod test { right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? .unwrap(); assert_eq!( @@ -1130,7 +1188,7 @@ mod test { right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? .unwrap(); assert_eq!( @@ -1147,7 +1205,7 @@ mod test { right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, } - .convert_binary(&0)? + .convert_binary("t1", &0)? .unwrap(); assert_eq!( diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 627547c8..17945eec 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -182,6 +182,13 @@ impl HepGraph { .map(|edge| edge.target()) } + pub fn youngest_child_at(&self, id: HepNodeId) -> Option { + self.graph + .edges(id) + .max_by_key(|edge| edge.weight()) + .map(|edge| edge.target()) + } + pub fn into_plan(mut self, memo: Option<&Memo>) -> Option { self.build_childrens(self.root_index, memo) } diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index 4e0f6b46..c052b7c3 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -1,4 +1,4 @@ -use crate::catalog::{ColumnRef, ColumnSummary}; +use crate::catalog::ColumnSummary; use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; @@ -25,8 +25,18 @@ lazy_static! { #[derive(Clone)] pub struct ColumnPruning; +macro_rules! 
trans_references { + ($columns:expr) => {{ + let mut column_references = HashSet::with_capacity($columns.len()); + for column in $columns { + column_references.insert(column.summary()); + } + column_references + }}; +} + impl ColumnPruning { - fn clear_exprs(column_references: HashSet<&ColumnSummary>, exprs: &mut Vec) { + fn clear_exprs(column_references: &HashSet<&ColumnSummary>, exprs: &mut Vec) { exprs.retain(|expr| { if column_references.contains(expr.output_column().summary()) { return true; @@ -48,7 +58,7 @@ impl ColumnPruning { match operator { Operator::Aggregate(op) => { if !all_referenced { - Self::clear_exprs(column_references, &mut op.agg_calls); + Self::clear_exprs(&column_references, &mut op.agg_calls); if op.agg_calls.is_empty() && op.groupby_exprs.is_empty() { let value = Arc::new(DataValue::Utf8(Some("*".to_string()))); @@ -62,28 +72,28 @@ impl ColumnPruning { }) } } - let op_ref_columns = operator.referenced_columns(false); + let is_distinct = op.is_distinct; + let referenced_columns = operator.referenced_columns(false); + let mut new_column_references = trans_references!(&referenced_columns); + // on distinct + if is_distinct { + for summary in column_references { + new_column_references.insert(summary); + } + } - Self::recollect_apply(op_ref_columns, false, node_id, graph); + Self::recollect_apply(new_column_references, false, node_id, graph); } Operator::Project(op) => { let has_count_star = op.exprs.iter().any(ScalarExpression::has_count_star); if !has_count_star { if !all_referenced { - Self::clear_exprs(column_references, &mut op.exprs); + Self::clear_exprs(&column_references, &mut op.exprs); } - let op_ref_columns = operator.referenced_columns(false); + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); - Self::recollect_apply(op_ref_columns, false, node_id, graph); - } - } - Operator::Sort(_op) => { - if !all_referenced { - // Todo: Order Project - // 
https://github.com/duckdb/duckdb/blob/main/src/optimizer/remove_unused_columns.cpp#L174 - } - if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::_apply(column_references, true, child_id, graph); + Self::recollect_apply(new_column_references, false, node_id, graph); } } Operator::Scan(op) => { @@ -92,7 +102,7 @@ impl ColumnPruning { .retain(|(_, column)| column_references.contains(column.summary())); } } - Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => { + Operator::Sort(_) | Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => { let temp_columns = operator.referenced_columns(false); // why? let mut column_references = column_references; @@ -119,10 +129,11 @@ impl ColumnPruning { | Operator::Update(_) | Operator::Delete(_) | Operator::Analyze(_) => { - let op_ref_columns = operator.referenced_columns(false); + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::recollect_apply(op_ref_columns, true, child_id, graph); + Self::recollect_apply(new_column_references, true, child_id, graph); } else { unreachable!(); } @@ -141,18 +152,15 @@ impl ColumnPruning { } fn recollect_apply( - referenced_columns: Vec, + referenced_columns: HashSet<&ColumnSummary>, all_referenced: bool, node_id: HepNodeId, graph: &mut HepGraph, ) { for child_id in graph.children_at(node_id).collect_vec() { - let new_references: HashSet<&ColumnSummary> = referenced_columns - .iter() - .map(|column| column.summary()) - .collect(); + let copy_references: HashSet<&ColumnSummary> = referenced_columns.clone(); - Self::_apply(new_references, all_referenced, child_id, graph); + Self::_apply(copy_references, all_referenced, child_id, graph); } } } diff --git a/src/optimizer/rule/normalization/expression_remapper.rs b/src/optimizer/rule/normalization/expression_remapper.rs new file mode 100644 index 00000000..6d97730a --- 
/dev/null +++ b/src/optimizer/rule/normalization/expression_remapper.rs @@ -0,0 +1,119 @@ +use crate::errors::DatabaseError; +use crate::expression::ScalarExpression; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::planner::operator::join::JoinCondition; +use crate::planner::operator::Operator; +use lazy_static::lazy_static; + +lazy_static! { + static ref EXPRESSION_REMAPPER_RULE: Pattern = { + Pattern { + predicate: |_| true, + children: PatternChildrenPredicate::None, + } + }; +} + +#[derive(Clone)] +pub struct ExpressionRemapper; + +impl ExpressionRemapper { + fn _apply( + output_exprs: &mut Vec, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result<(), DatabaseError> { + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::_apply(output_exprs, child_id, graph)?; + } + // for join + let mut left_len = 0; + if let Operator::Join(_) = graph.operator(node_id) { + let mut second_output_exprs = Vec::new(); + if let Some(child_id) = graph.youngest_child_at(node_id) { + Self::_apply(&mut second_output_exprs, child_id, graph)?; + } + left_len = output_exprs.len(); + output_exprs.append(&mut second_output_exprs); + } + let operator = graph.operator_mut(node_id); + + match operator { + Operator::Join(op) => { + match &mut op.on { + JoinCondition::On { on, filter } => { + for (left_expr, right_expr) in on { + left_expr.try_reference(&output_exprs[0..left_len]); + right_expr.try_reference(&output_exprs[left_len..]); + } + if let Some(expr) = filter { + expr.try_reference(output_exprs); + } + } + JoinCondition::None => {} + } + + return Ok(()); + } + Operator::Aggregate(op) => { + for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { + expr.try_reference(output_exprs); + } + } + Operator::Filter(op) => { + op.predicate.try_reference(output_exprs); + } + 
Operator::Project(op) => { + for expr in op.exprs.iter_mut() { + expr.try_reference(output_exprs); + } + } + Operator::Sort(op) => { + for sort_field in op.sort_fields.iter_mut() { + sort_field.expr.try_reference(output_exprs); + } + } + Operator::Dummy + | Operator::Scan(_) + | Operator::Limit(_) + | Operator::Values(_) + | Operator::Show + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::DropTable(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => (), + } + if let Some(exprs) = operator.output_exprs() { + *output_exprs = exprs; + } + + Ok(()) + } +} + +impl MatchPattern for ExpressionRemapper { + fn pattern(&self) -> &Pattern { + &EXPRESSION_REMAPPER_RULE + } +} + +impl NormalizationRule for ExpressionRemapper { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + Self::_apply(&mut Vec::new(), node_id, graph)?; + // mark changed to skip this rule batch + graph.version += 1; + + Ok(()) + } +} diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 1d7b895d..f83155af 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -7,6 +7,7 @@ use crate::optimizer::rule::normalization::column_pruning::ColumnPruning; use crate::optimizer::rule::normalization::combine_operators::{ CollapseGroupByAgg, CollapseProject, CombineFilter, }; +use crate::optimizer::rule::normalization::expression_remapper::ExpressionRemapper; use crate::optimizer::rule::normalization::pushdown_limit::{ EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -17,6 +18,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; +mod expression_remapper; 
mod pushdown_limit; mod pushdown_predicates; mod simplification; @@ -40,6 +42,8 @@ pub enum NormalizationRuleImpl { // Simplification SimplifyFilter, ConstantCalculation, + // ColumnRemapper + ExpressionRemapper, } impl MatchPattern for NormalizationRuleImpl { @@ -57,6 +61,7 @@ impl MatchPattern for NormalizationRuleImpl { NormalizationRuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(), NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.pattern(), NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.pattern(), + NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.pattern(), } } } @@ -86,6 +91,7 @@ impl NormalizationRule for NormalizationRuleImpl { PushPredicateIntoScan.apply(node_id, graph) } NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), + NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.apply(node_id, graph), } } } diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index 1ecafd26..3dfe6cd2 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -215,7 +215,9 @@ impl NormalizationRule for PushPredicateIntoScan { if let Operator::Scan(child_op) = graph.operator_mut(child_id) { //FIXME: now only support unique for IndexInfo { meta, binaries } in &mut child_op.index_infos { - let mut option = op.predicate.convert_binary(&meta.column_ids[0])?; + let mut option = op + .predicate + .convert_binary(meta.table_name.as_str(), &meta.column_ids[0])?; if let Some(mut binary) = option.take() { binary.scope_aggregation()?; diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index b3b4bac9..08e4dc69 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -157,7 +157,7 @@ mod test { 
unreachable!(); } if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - let column_binary = filter_op.predicate.convert_binary(&0).unwrap(); + let column_binary = filter_op.predicate.convert_binary("t1", &0).unwrap(); let final_binary = ConstantBinary::Scope { min: Bound::Unbounded, max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))), @@ -207,10 +207,10 @@ mod test { if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { println!( "{expr}: {:#?}", - filter_op.predicate.convert_binary(&0).unwrap() + filter_op.predicate.convert_binary("t1", &0).unwrap() ); - Ok(filter_op.predicate.convert_binary(&0).unwrap()) + Ok(filter_op.predicate.convert_binary("t1", &0).unwrap()) } else { Ok(None) } @@ -341,7 +341,7 @@ mod test { let op_3 = op(plan_3, "-c1 > 1 and c2 + 1 > 1")?.unwrap(); let op_4 = op(plan_4, "c1 + 1 > 1 and -c2 > 1")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!( cb_1_c1, @@ -351,7 +351,7 @@ mod test { }) ); - let cb_1_c2 = op_1.predicate.convert_binary(&1).unwrap(); + let cb_1_c2 = op_1.predicate.convert_binary("t1", &1).unwrap(); println!("op_1 => c2: {:#?}", cb_1_c2); assert_eq!( cb_1_c2, @@ -361,7 +361,7 @@ mod test { }) ); - let cb_2_c1 = op_2.predicate.convert_binary(&0).unwrap(); + let cb_2_c1 = op_2.predicate.convert_binary("t1", &0).unwrap(); println!("op_2 => c1: {:#?}", cb_2_c1); assert_eq!( cb_2_c1, @@ -371,7 +371,7 @@ mod test { }) ); - let cb_2_c2 = op_2.predicate.convert_binary(&1).unwrap(); + let cb_2_c2 = op_2.predicate.convert_binary("t1", &1).unwrap(); println!("op_2 => c2: {:#?}", cb_2_c2); assert_eq!( cb_1_c1, @@ -381,7 +381,7 @@ mod test { }) ); - let cb_3_c1 = op_3.predicate.convert_binary(&0).unwrap(); + let cb_3_c1 = op_3.predicate.convert_binary("t1", &0).unwrap(); println!("op_3 => c1: {:#?}", cb_3_c1); assert_eq!( cb_3_c1, @@ -391,7 
+391,7 @@ mod test { }) ); - let cb_3_c2 = op_3.predicate.convert_binary(&1).unwrap(); + let cb_3_c2 = op_3.predicate.convert_binary("t1", &1).unwrap(); println!("op_3 => c2: {:#?}", cb_3_c2); assert_eq!( cb_3_c2, @@ -401,7 +401,7 @@ mod test { }) ); - let cb_4_c1 = op_4.predicate.convert_binary(&0).unwrap(); + let cb_4_c1 = op_4.predicate.convert_binary("t1", &0).unwrap(); println!("op_4 => c1: {:#?}", cb_4_c1); assert_eq!( cb_4_c1, @@ -411,7 +411,7 @@ mod test { }) ); - let cb_4_c2 = op_4.predicate.convert_binary(&1).unwrap(); + let cb_4_c2 = op_4.predicate.convert_binary("t1", &1).unwrap(); println!("op_4 => c2: {:#?}", cb_4_c2); assert_eq!( cb_4_c2, @@ -448,7 +448,7 @@ mod test { let op_1 = op(plan_1, "c1 > c2 or c1 > 1")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!(cb_1_c1, None); @@ -479,7 +479,7 @@ mod test { let op_1 = op(plan_1, "c1 = 4 and c2 > c1 or c1 > 1")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!( cb_1_c1, @@ -518,7 +518,7 @@ mod test { let op_1 = op(plan_1, "c1 is null")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!(cb_1_c1, Some(ConstantBinary::Eq(Arc::new(DataValue::Null)))); @@ -548,7 +548,7 @@ mod test { let op_1 = op(plan_1, "c1 is not null")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!( cb_1_c1, @@ -581,7 +581,7 @@ mod test { let op_1 = op(plan_1, "c1 in (1, 2, 3)")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = 
op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!( cb_1_c1, @@ -618,7 +618,7 @@ mod test { let op_1 = op(plan_1, "c1 not in (1, 2, 3)")?.unwrap(); - let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); assert_eq!( cb_1_c1, diff --git a/src/planner/operator/aggregate.rs b/src/planner/operator/aggregate.rs index 0ca6a44a..04f30744 100644 --- a/src/planner/operator/aggregate.rs +++ b/src/planner/operator/aggregate.rs @@ -8,6 +8,7 @@ use std::fmt::Formatter; pub struct AggregateOperator { pub groupby_exprs: Vec<ScalarExpression>, pub agg_calls: Vec<ScalarExpression>, + pub is_distinct: bool, } impl AggregateOperator { @@ -15,11 +16,13 @@ impl AggregateOperator { children: LogicalPlan, agg_calls: Vec<ScalarExpression>, groupby_exprs: Vec<ScalarExpression>, + is_distinct: bool, ) -> LogicalPlan { LogicalPlan::new( Operator::Aggregate(Self { groupby_exprs, agg_calls, + is_distinct, }), vec![children], ) diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 99ccf76b..03855fde 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -19,6 +19,7 @@ pub mod update; pub mod values; use crate::catalog::ColumnRef; +use crate::expression::ScalarExpression; use crate::planner::operator::alter_table::drop_column::DropColumnOperator; use crate::planner::operator::analyze::AnalyzeOperator; use crate::planner::operator::copy_from_file::CopyFromFileOperator; @@ -103,6 +104,50 @@ pub enum PhysicalOption { } impl Operator { + pub fn output_exprs(&self) -> Option<Vec<ScalarExpression>> { + match self { + Operator::Dummy => None, + Operator::Aggregate(op) => Some( + op.agg_calls + .iter() + .chain(op.groupby_exprs.iter()) + .cloned() + .collect_vec(), + ), + Operator::Filter(_) | Operator::Join(_) => None, + Operator::Project(op) => Some(op.exprs.clone()), + Operator::Scan(op) => Some( + op.columns + .iter() + .cloned() + .map(|(_, column)| ScalarExpression::ColumnRef(column)) 
.collect_vec(), + ), + Operator::Sort(_) | Operator::Limit(_) => None, + Operator::Values(op) => Some( + op.schema_ref + .iter() + .cloned() + .map(ScalarExpression::ColumnRef) + .collect_vec(), + ), + Operator::Show + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::DropTable(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => None, + } + } + pub fn referenced_columns(&self, only_column_ref: bool) -> Vec<ColumnRef> { match self { Operator::Aggregate(op) => op diff --git a/src/storage/kip.rs b/src/storage/kip.rs index dec0acf1..fb099b56 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -624,10 +624,8 @@ mod test { .await?; let transaction = fnck_sql.storage.transaction().await?; - let table = transaction - .table(Arc::new("t1".to_string())) - .unwrap() - .clone(); + let table_name = Arc::new("t1".to_string()); + let table = transaction.table(table_name.clone()).unwrap().clone(); let tuple_ids = vec![ Arc::new(DataValue::Int32(Some(0))), Arc::new(DataValue::Int32(Some(2))), @@ -641,6 +639,7 @@ mod test { index_meta: Arc::new(IndexMeta { id: 0, column_ids: vec![0], + table_name, name: "pk_a".to_string(), is_unique: false, is_primary: true, diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index aef2d4d3..0e377f41 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -353,6 +353,7 @@ mod tests { let index_meta = IndexMeta { id: 0, column_ids: vec![0], + table_name: Arc::new("T1".to_string()), name: "index_1".to_string(), is_unique: false, is_primary: false, @@ -447,6 +448,7 @@ mod tests { let index_meta = IndexMeta { id: index_id as u32, column_ids: vec![], + table_name: Arc::new(table_name.to_string()), name: "".to_string(), is_unique: false, is_primary: false, diff --git a/src/types/index.rs 
b/src/types/index.rs index 29dc70c8..db87b598 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -1,3 +1,4 @@ +use crate::catalog::TableName; use crate::expression::simplify::ConstantBinary; use crate::types::value::ValueRef; use crate::types::ColumnId; @@ -20,6 +21,7 @@ pub struct IndexInfo { pub struct IndexMeta { pub id: IndexId, pub column_ids: Vec<ColumnId>, + pub table_name: TableName, pub name: String, pub is_unique: bool, pub is_primary: bool,