From d96b2c3b3e3e12ccba7665e3b05ed5450db361ab Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sat, 10 Feb 2024 23:29:43 +0800 Subject: [PATCH 1/3] feat: support `Subquery` on Where --- src/bin/server.rs | 2 +- src/binder/aggregate.rs | 8 +- src/binder/alter_table.rs | 18 +- src/binder/analyze.rs | 14 +- src/binder/copy.rs | 22 +- src/binder/create_table.rs | 9 +- src/binder/delete.rs | 15 +- src/binder/describe.rs | 9 +- src/binder/distinct.rs | 4 +- src/binder/drop_table.rs | 9 +- src/binder/explain.rs | 6 +- src/binder/expr.rs | 15 ++ src/binder/insert.rs | 46 ++-- src/binder/mod.rs | 34 +++ src/binder/select.rs | 225 ++++++++++++------ src/binder/show.rs | 7 +- src/binder/truncate.rs | 10 +- src/binder/update.rs | 11 +- src/catalog/table.rs | 43 ++-- .../codegen/dql/aggregate/simple_agg.rs | 2 +- src/execution/codegen/mod.rs | 2 +- src/execution/volcano/ddl/add_column.rs | 4 +- src/execution/volcano/ddl/drop_column.rs | 6 +- src/execution/volcano/dml/analyze.rs | 6 +- src/execution/volcano/dml/copy_from_file.rs | 7 +- src/execution/volcano/dml/delete.rs | 4 +- src/execution/volcano/dml/insert.rs | 23 +- src/execution/volcano/dml/update.rs | 6 +- .../volcano/dql/aggregate/hash_agg.rs | 5 +- .../volcano/dql/aggregate/simple_agg.rs | 2 +- src/execution/volcano/dql/describe.rs | 4 +- src/execution/volcano/dql/dummy.rs | 2 +- src/execution/volcano/dql/explain.rs | 4 +- src/execution/volcano/dql/join/hash_join.rs | 22 +- src/execution/volcano/dql/projection.rs | 6 +- src/execution/volcano/dql/show_table.rs | 4 +- src/execution/volcano/dql/values.rs | 6 +- src/execution/volcano/mod.rs | 2 + src/expression/evaluator.rs | 5 +- src/expression/mod.rs | 7 +- src/expression/simplify.rs | 25 +- src/marcos/mod.rs | 4 +- src/optimizer/core/opt_expr.rs | 6 +- src/optimizer/heuristic/graph.rs | 1 + src/optimizer/heuristic/matcher.rs | 4 + src/planner/mod.rs | 72 ++++++ src/planner/operator/aggregate.rs | 9 +- src/planner/operator/copy_from_file.rs | 6 +- src/planner/operator/filter.rs | 9 +- src/planner/operator/join.rs | 27 ++- src/planner/operator/limit.rs | 9 +- src/planner/operator/mod.rs | 2 +- src/planner/operator/scan.rs | 13 +- src/planner/operator/values.rs | 6 +- src/storage/kip.rs | 23 +- src/storage/mod.rs | 8 +- src/storage/table_codec.rs | 15 +- src/types/tuple.rs | 26 +- src/types/tuple_builder.rs | 36 ++- tests/slt/join.slt | 14 ++ tests/sqllogictest/src/lib.rs | 2 +- 61 files changed, 583 insertions(+), 370 deletions(-) diff --git a/src/bin/server.rs b/src/bin/server.rs index cafef1fd..8d315ee8 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -167,7 +167,7 @@ fn encode_tuples<'a>(tuples: Vec) -> PgWireResult> { let mut results = Vec::with_capacity(tuples.len()); let schema = Arc::new( tuples[0] - .columns + .schema_ref .iter() .map(|column| { let pg_type = into_pg_type(column.datatype())?; diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 22bdec6a..1483fe40 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -11,7 +11,7 @@ use crate::{ planner::operator::{aggregate::AggregateOperator, sort::SortField}, }; -use super::Binder; +use super::{Binder, QueryBindStep}; impl<'a, T: Transaction> Binder<'a, T> { pub fn bind_aggregate( @@ -20,6 +20,8 @@ impl<'a, T: Transaction> Binder<'a, T> { agg_calls: Vec, groupby_exprs: Vec, ) -> LogicalPlan { + self.context.step(QueryBindStep::Agg); + AggregateOperator::build(children, agg_calls, groupby_exprs) } @@ -133,7 +135,8 @@ impl<'a, T: Transaction> Binder<'a, T> { 
self.visit_column_agg_expr(expr)?; } } - ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => {} + ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (), + ScalarExpression::Empty => unreachable!(), } Ok(()) @@ -306,6 +309,7 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(()) } ScalarExpression::Constant(_) => Ok(()), + ScalarExpression::Empty => unreachable!(), } } } diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs index 40e6cf5e..27edd6f7 100644 --- a/src/binder/alter_table.rs +++ b/src/binder/alter_table.rs @@ -35,15 +35,14 @@ impl<'a, T: Transaction> Binder<'a, T> { "illegal column naming".to_string(), )); } - LogicalPlan { - operator: Operator::AddColumn(AddColumnOperator { + LogicalPlan::new( + Operator::AddColumn(AddColumnOperator { table_name, if_not_exists: *if_not_exists, column, }), - childrens: vec![plan], - physical_option: None, - } + vec![plan], + ) } AlterTableOperation::DropColumn { column_name, @@ -53,15 +52,14 @@ impl<'a, T: Transaction> Binder<'a, T> { let plan = ScanOperator::build(table_name.clone(), table); let column_name = column_name.value.clone(); - LogicalPlan { - operator: Operator::DropColumn(DropColumnOperator { + LogicalPlan::new( + Operator::DropColumn(DropColumnOperator { table_name, if_exists: *if_exists, column_name, }), - childrens: vec![plan], - physical_option: None, - } + vec![plan], + ) } AlterTableOperation::DropPrimaryKey => todo!(), AlterTableOperation::RenameColumn { diff --git a/src/binder/analyze.rs b/src/binder/analyze.rs index e7022a23..5bf5e4ed 100644 --- a/src/binder/analyze.rs +++ b/src/binder/analyze.rs @@ -15,19 +15,17 @@ impl<'a, T: Transaction> Binder<'a, T> { let table_catalog = self.context.table_and_bind(table_name.clone(), None)?; let columns = table_catalog - .columns_with_id() - .filter_map(|(_, column)| column.desc.is_index().then_some(column.clone())) + .columns() + .filter_map(|column| column.desc.is_index().then_some(column.clone())) .collect_vec(); let scan_op = ScanOperator::build(table_name.clone(), table_catalog); - let plan = LogicalPlan { - operator: Operator::Analyze(AnalyzeOperator { + Ok(LogicalPlan::new( + Operator::Analyze(AnalyzeOperator { table_name, columns, }), - childrens: vec![scan_op], - physical_option: None, - }; - Ok(plan) + vec![scan_op], + )) } } diff --git a/src/binder/copy.rs b/src/binder/copy.rs index fcab372f..8a85b08f 100644 --- a/src/binder/copy.rs +++ b/src/binder/copy.rs @@ -72,7 +72,7 @@ impl<'a, T: Transaction> Binder<'a, T> { }; if let Some(table) = self.context.table(Arc::new(table_name.to_string())) { - let columns = table.clone_columns(); + let schema_ref = table.schema_ref().clone(); let ext_source = ExtSource { path: match target { CopyTarget::File { filename } => filename.into(), @@ -83,22 +83,20 @@ impl<'a, T: Transaction> Binder<'a, T> { if to { // COPY TO - Ok(LogicalPlan { - operator: Operator::CopyToFile(CopyToFileOperator { source: ext_source }), - childrens: vec![], - physical_option: None, - }) + Ok(LogicalPlan::new( + Operator::CopyToFile(CopyToFileOperator { source: ext_source }), + vec![], + )) } else { // COPY FROM - Ok(LogicalPlan { - operator: Operator::CopyFromFile(CopyFromFileOperator { + Ok(LogicalPlan::new( + Operator::CopyFromFile(CopyFromFileOperator { source: ext_source, - columns, + schema_ref, table: table_name.to_string(), }), - childrens: vec![], - physical_option: None, - }) + vec![], + )) } } else { Err(DatabaseError::InvalidTable(format!( diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs 
index 09c34531..83e9fae6 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -80,15 +80,14 @@ impl<'a, T: Transaction> Binder<'a, T> { )); } - let plan = LogicalPlan { - operator: Operator::CreateTable(CreateTableOperator { + let plan = LogicalPlan::new( + Operator::CreateTable(CreateTableOperator { table_name, columns, if_not_exists, }), - childrens: vec![], - physical_option: None, - }; + vec![], + ); Ok(plan) } diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 0b1d7d56..c5f81b85 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -19,9 +19,9 @@ impl<'a, T: Transaction> Binder<'a, T> { let table_catalog = self.context.table_and_bind(table_name.clone(), None)?; let primary_key_column = table_catalog - .columns_with_id() - .find(|(_, column)| column.desc.is_primary) - .map(|(_, column)| Arc::clone(column)) + .columns() + .find(|column| column.desc.is_primary) + .cloned() .unwrap(); let mut plan = ScanOperator::build(table_name.clone(), table_catalog); @@ -34,14 +34,13 @@ impl<'a, T: Transaction> Binder<'a, T> { plan = self.bind_where(plan, predicate)?; } - Ok(LogicalPlan { - operator: Operator::Delete(DeleteOperator { + Ok(LogicalPlan::new( + Operator::Delete(DeleteOperator { table_name, primary_key_column, }), - childrens: vec![plan], - physical_option: None, - }) + vec![plan], + )) } else { unreachable!("only table") } diff --git a/src/binder/describe.rs b/src/binder/describe.rs index 27525702..e7655c71 100644 --- a/src/binder/describe.rs +++ b/src/binder/describe.rs @@ -14,10 +14,9 @@ impl<'a, T: Transaction> Binder<'a, T> { ) -> Result { let table_name = Arc::new(lower_case_name(name)?); - Ok(LogicalPlan { - operator: Operator::Describe(DescribeOperator { table_name }), - childrens: vec![], - physical_option: None, - }) + Ok(LogicalPlan::new( + Operator::Describe(DescribeOperator { table_name }), + vec![], + )) } } diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index fa184821..fbdff1bf 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -1,4 +1,4 @@ -use crate::binder::Binder; +use crate::binder::{Binder, QueryBindStep}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; @@ -10,6 +10,8 @@ impl<'a, T: Transaction> Binder<'a, T> { children: LogicalPlan, select_list: Vec, ) -> LogicalPlan { + self.context.step(QueryBindStep::Distinct); + AggregateOperator::build(children, vec![], select_list) } } diff --git a/src/binder/drop_table.rs b/src/binder/drop_table.rs index 95321c2f..79fd9800 100644 --- a/src/binder/drop_table.rs +++ b/src/binder/drop_table.rs @@ -15,14 +15,13 @@ impl<'a, T: Transaction> Binder<'a, T> { ) -> Result { let table_name = Arc::new(lower_case_name(name)?); - let plan = LogicalPlan { - operator: Operator::DropTable(DropTableOperator { + let plan = LogicalPlan::new( + Operator::DropTable(DropTableOperator { table_name, if_exists: *if_exists, }), - childrens: vec![], - physical_option: None, - }; + vec![], + ); Ok(plan) } } diff --git a/src/binder/explain.rs b/src/binder/explain.rs index d08eaa7d..d5df278f 100644 --- a/src/binder/explain.rs +++ b/src/binder/explain.rs @@ -6,10 +6,6 @@ use crate::storage::Transaction; impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_explain(&mut self, plan: LogicalPlan) -> Result { - Ok(LogicalPlan { - operator: Operator::Explain, - childrens: vec![plan], - physical_option: None, - }) + Ok(LogicalPlan::new(Operator::Explain, vec![plan])) } } diff --git 
a/src/binder/expr.rs b/src/binder/expr.rs index 90755729..dfb4c722 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -89,6 +89,21 @@ impl<'a, T: Transaction> Binder<'a, T> { from_expr, }) } + Expr::Subquery(query) => { + let mut sub_query = self.bind_query(query)?; + let sub_query_schema = sub_query.out_schema(); + + if sub_query_schema.len() > 1 { + return Err(DatabaseError::MisMatch( + "expects only one expression to be returned".to_string(), + "the expression returned by the subquery".to_string(), + )); + } + let expr = ScalarExpression::ColumnRef(sub_query_schema[0].clone()); + self.context.sub_query(sub_query); + + Ok(expr) + } _ => { todo!() } diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 363f05b9..2944875f 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -1,5 +1,4 @@ use crate::binder::{lower_case_name, Binder}; -use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::value_compute::unary_op; use crate::expression::ScalarExpression; @@ -8,6 +7,7 @@ use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; use crate::storage::Transaction; +use crate::types::tuple::SchemaRef; use crate::types::value::{DataValue, ValueRef}; use sqlparser::ast::{Expr, Ident, ObjectName}; use std::slice; @@ -24,15 +24,20 @@ impl<'a, T: Transaction> Binder<'a, T> { let table_name = Arc::new(lower_case_name(name)?); if let Some(table) = self.context.table(table_name.clone()) { - let mut columns = Vec::new(); + let mut _schema_ref = None; let values_len = expr_rows[0].len(); if idents.is_empty() { - columns = table.clone_columns(); - if values_len > columns.len() { - return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); + let temp_schema_ref = table.schema_ref().clone(); + if values_len > temp_schema_ref.len() { + return Err(DatabaseError::ValuesLenMismatch( + temp_schema_ref.len(), + values_len, + )); } + _schema_ref = Some(temp_schema_ref); } else { + let mut columns = Vec::with_capacity(idents.len()); for ident in idents { match self.bind_column_ref_from_identifiers( slice::from_ref(ident), @@ -45,7 +50,9 @@ impl<'a, T: Transaction> Binder<'a, T> { if values_len != columns.len() { return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); } + _schema_ref = Some(Arc::new(columns)); } + let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?; let mut rows = Vec::with_capacity(expr_rows.len()); for expr_row in expr_rows { if expr_row.len() != values_len { @@ -57,14 +64,15 @@ impl<'a, T: Transaction> Binder<'a, T> { match &self.bind_expr(expr)? { ScalarExpression::Constant(value) => { // Check if the value length is too long - value.check_len(columns[i].datatype())?; - let cast_value = DataValue::clone(value).cast(columns[i].datatype())?; + value.check_len(schema_ref[i].datatype())?; + let cast_value = + DataValue::clone(value).cast(schema_ref[i].datatype())?; row.push(Arc::new(cast_value)) } ScalarExpression::Unary { expr, op, ..
} => { if let ScalarExpression::Constant(value) = expr.as_ref() { row.push(Arc::new( - unary_op(value, op)?.cast(columns[i].datatype())?, + unary_op(value, op)?.cast(schema_ref[i].datatype())?, )) } else { unreachable!() @@ -76,16 +84,15 @@ impl<'a, T: Transaction> Binder<'a, T> { rows.push(row); } - let values_plan = self.bind_values(rows, columns); + let values_plan = self.bind_values(rows, schema_ref); - Ok(LogicalPlan { - operator: Operator::Insert(InsertOperator { + Ok(LogicalPlan::new( + Operator::Insert(InsertOperator { table_name, is_overwrite, }), - childrens: vec![values_plan], - physical_option: None, - }) + vec![values_plan], + )) } else { Err(DatabaseError::InvalidTable(format!( "not found table {}", @@ -97,12 +104,11 @@ impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_values( &mut self, rows: Vec<Vec<ValueRef>>, - columns: Vec<ColumnRef>, + schema_ref: SchemaRef, ) -> LogicalPlan { - LogicalPlan { - operator: Operator::Values(ValuesOperator { rows, columns }), - childrens: vec![], - physical_option: None, - } + LogicalPlan::new( + Operator::Values(ValuesOperator { rows, schema_ref }), + vec![], + ) } } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index ebee9b89..88c0487f 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -30,6 +30,20 @@ pub enum InputRefType { GroupBy, } +// Tip: currently only used for queries +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub enum QueryBindStep { + From, + Join, + Where, + Agg, + Having, + Distinct, + Sort, + Project, + Limit, +} + #[derive(Clone)] pub struct BinderContext<'a, T: Transaction> { transaction: &'a T, @@ -40,6 +54,9 @@ pub struct BinderContext<'a, T: Transaction> { // agg group_by_exprs: Vec<ScalarExpression>, pub(crate) agg_calls: Vec<ScalarExpression>, + + bind_step: QueryBindStep, + sub_queries: HashMap<QueryBindStep, Vec<LogicalPlan>>, } impl<'a, T: Transaction> BinderContext<'a, T> { @@ -51,9 +68,26 @@ impl<'a, T: Transaction> BinderContext<'a, T> { table_aliases: Default::default(), group_by_exprs: vec![], agg_calls: Default::default(), + bind_step: QueryBindStep::From, + sub_queries: Default::default(), } } + pub fn step(&mut self, bind_step: QueryBindStep) { + self.bind_step = bind_step; + } + + pub fn sub_query(&mut self, sub_query: LogicalPlan) { + self.sub_queries + .entry(self.bind_step) + .or_default() + .push(sub_query) + } + + pub fn sub_query_for_now(&mut self) -> Option<Vec<LogicalPlan>> { + self.sub_queries.remove(&self.bind_step) + } + pub fn table(&self, table_name: TableName) -> Option<&TableCatalog> { if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { self.transaction.table(real_name.clone()) diff --git a/src/binder/select.rs b/src/binder/select.rs index 56ee0738..8578c6a8 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -14,9 +14,9 @@ use crate::{ types::value::DataValue, }; -use super::{lower_case_name, lower_ident, Binder}; +use super::{lower_case_name, lower_ident, Binder, QueryBindStep}; -use crate::catalog::{ColumnCatalog, TableCatalog, TableName}; +use crate::catalog::{ColumnCatalog, TableName}; use crate::errors::DatabaseError; use crate::execution::volcano::dql::join::joins_nullable; use crate::expression::BinaryOperator; @@ -24,9 +24,9 @@ use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::LogicalPlan; use crate::storage::Transaction; +use crate::types::tuple::Schema; use crate::types::LogicalType; use itertools::Itertools; -use sqlparser::ast; use sqlparser::ast::{ Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, SelectItem,
SetExpr, TableAlias, TableFactor, TableWithJoins, }; @@ -66,12 +66,10 @@ impl<'a, T: Transaction> Binder<'a, T> { let mut select_list = self.normalize_select_item(&select.projection)?; - self.extract_select_join(&mut select_list); - if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; } - + self.extract_select_join(&mut select_list); self.extract_select_aggregate(&mut select_list)?; if !select.group_by.is_empty() { @@ -113,13 +111,11 @@ impl<'a, T: Transaction> Binder<'a, T> { &mut self, from: &[TableWithJoins], ) -> Result<LogicalPlan, DatabaseError> { + self.context.step(QueryBindStep::From); + assert!(from.len() < 2, "not support yet."); if from.is_empty() { - return Ok(LogicalPlan { - operator: Operator::Dummy, - childrens: vec![], - physical_option: None, - }); + return Ok(LogicalPlan::new(Operator::Dummy, vec![])); } let TableWithJoins { relation, joins } = &from[0]; @@ -190,20 +186,24 @@ impl<'a, T: Transaction> Binder<'a, T> { table_name: TableName, ) -> Result<(), DatabaseError> { if !alias_column.is_empty() { - let aliases = alias_column.iter().map(lower_ident).collect_vec(); let table = self .context .table(table_name.clone()) .ok_or(DatabaseError::TableNotFound)?; - if aliases.len() != table.columns_len() { + if alias_column.len() != table.columns_len() { return Err(DatabaseError::MisMatch( "Alias".to_string(), "Columns".to_string(), )); } - let columns = table.clone_columns(); - for (alias, column) in aliases.into_iter().zip(columns.into_iter()) { + let aliases_with_columns = alias_column + .iter() + .map(lower_ident) + .zip(table.columns().cloned()) + .collect_vec(); + + for (alias, column) in aliases_with_columns { self.context .add_alias(alias, ScalarExpression::ColumnRef(column)); } @@ -295,8 +295,8 @@ impl<'a, T: Transaction> Binder<'a, T> { }) .map(|(alias, expr)| (expr, alias)) .collect(); - for (_, col) in table.columns_with_id() { - let mut expr = ScalarExpression::ColumnRef(col.clone()); + for column in table.columns() { + let mut expr = ScalarExpression::ColumnRef(column.clone()); if let Some(alias_expr) = alias_map.get(&expr) { expr = ScalarExpression::Alias { @@ -333,12 +333,14 @@ impl<'a, T: Transaction> Binder<'a, T> { let left_table = self .context - .table(left_table.clone()) + .table(left_table) + .map(|table| table.schema_ref()) .cloned() .ok_or(DatabaseError::TableNotFound)?; let right_table = self .context - .table(right_table.clone()) + .table(right_table) + .map(|table| table.schema_ref()) .cloned() .ok_or(DatabaseError::TableNotFound)?; @@ -352,14 +354,50 @@ impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_where( &mut self, - children: LogicalPlan, + mut children: LogicalPlan, predicate: &Expr, ) -> Result<LogicalPlan, DatabaseError> { - Ok(FilterOperator::build( - self.bind_expr(predicate)?, - children, - false, - )) + self.context.step(QueryBindStep::Where); + + let predicate = self.bind_expr(predicate)?; + + if let Some(sub_queries) = self.context.sub_query_for_now() { + for mut sub_query in sub_queries { + let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; + let mut filter = vec![]; + + Self::extract_join_keys( + predicate.clone(), + &mut on_keys, + &mut filter, + children.out_schema(), + sub_query.out_schema(), + )?; + + // combine multiple filter exprs into one BinaryExpr + let join_filter = + filter + .into_iter() + .reduce(|acc, expr| ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(acc), + right_expr: Box::new(expr), + ty: LogicalType::Boolean, + }); + + children = 
LJoinOperator::build( + children, + sub_query, + JoinCondition::On { + on: on_keys, + filter: join_filter, + }, + JoinType::Inner, + ); + } + return Ok(children); + } + Ok(FilterOperator::build(predicate, children, false)) } fn bind_having( @@ -367,6 +406,8 @@ impl<'a, T: Transaction> Binder<'a, T> { children: LogicalPlan, having: ScalarExpression, ) -> Result<LogicalPlan, DatabaseError> { + self.context.step(QueryBindStep::Having); + self.validate_having_orderby(&having)?; Ok(FilterOperator::build(having, children, true)) } @@ -376,22 +417,24 @@ impl<'a, T: Transaction> Binder<'a, T> { children: LogicalPlan, select_list: Vec<ScalarExpression>, ) -> Result<LogicalPlan, DatabaseError> { - Ok(LogicalPlan { - operator: Operator::Project(ProjectOperator { exprs: select_list }), - childrens: vec![children], - physical_option: None, - }) + self.context.step(QueryBindStep::Project); + + Ok(LogicalPlan::new( + Operator::Project(ProjectOperator { exprs: select_list }), + vec![children], + )) } fn bind_sort(&mut self, children: LogicalPlan, sort_fields: Vec<SortField>) -> LogicalPlan { - LogicalPlan { - operator: Operator::Sort(SortOperator { + self.context.step(QueryBindStep::Sort); + + LogicalPlan::new( + Operator::Sort(SortOperator { sort_fields, limit: None, }), - childrens: vec![children], - physical_option: None, - } + vec![children], + ) } fn bind_limit( @@ -400,6 +443,8 @@ impl<'a, T: Transaction> Binder<'a, T> { limit_expr: &Option<Expr>, offset_expr: &Option<Offset>, ) -> Result<LogicalPlan, DatabaseError> { + self.context.step(QueryBindStep::Limit); + let mut limit = None; let mut offset = None; if let Some(expr) = limit_expr { @@ -480,8 +525,8 @@ impl<'a, T: Transaction> Binder<'a, T> { fn bind_join_constraint( &mut self, - left_table: &TableCatalog, - right_table: &TableCatalog, + left_schema: &Schema, + right_schema: &Schema, constraint: &JoinConstraint, ) -> Result<JoinCondition, DatabaseError> { match constraint { @@ -490,8 +535,15 @@ impl<'a, T: Transaction> Binder<'a, T> { let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; // expression that didn't match equi-join pattern let mut filter = vec![]; + let expr = self.bind_expr(expr)?; - self.extract_join_keys(expr, &mut on_keys, &mut filter, left_table, right_table)?; + Self::extract_join_keys( + expr, + &mut on_keys, + &mut filter, + left_schema, + right_schema, + )?; // combine multiple filter exprs into one BinaryExpr let join_filter = filter @@ -524,70 +576,89 @@ impl<'a, T: Transaction> Binder<'a, T> { /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] /// ``` fn extract_join_keys( - &mut self, - expr: &Expr, + expr: ScalarExpression, accum: &mut Vec<(ScalarExpression, ScalarExpression)>, accum_filter: &mut Vec<ScalarExpression>, - left_schema: &TableCatalog, - right_schema: &TableCatalog, + left_schema: &Schema, + right_schema: &Schema, ) -> Result<(), DatabaseError> { match expr { - Expr::BinaryOp { left, op, right } => match op { - ast::BinaryOperator::Eq => { - let left = self.bind_expr(left)?; - let right = self.bind_expr(right)?; - - match (&left, &right) { + ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + } => match op { + BinaryOperator::Eq => { + let fn_contains = |schema: &Schema, name: &str| { + schema.iter().any(|column| column.name() == name) + }; + + match (left_expr.as_ref(), right_expr.as_ref()) { // example: foo = bar (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { // reorder left and right joins keys to pattern: (left, right) - if left_schema.contains_column(l.name()) - && right_schema.contains_column(r.name()) + if fn_contains(left_schema, l.name()) + && fn_contains(right_schema, r.name()) { - 
accum.push((left, right)); - } else if left_schema.contains_column(r.name()) - && right_schema.contains_column(l.name()) + accum.push((*left_expr, *right_expr)); + } else if fn_contains(left_schema, r.name()) + && fn_contains(right_schema, l.name()) { - accum.push((right, left)); + accum.push((*right_expr, *left_expr)); } else { - accum_filter.push(self.bind_expr(expr)?); + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); } } // example: baz = 1 _other => { - accum_filter.push(self.bind_expr(expr)?); + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); } } } - ast::BinaryOperator::And => { + BinaryOperator::And => { // example: foo = bar AND baz > 1 - if let Expr::BinaryOp { left, op: _, right } = expr { - self.extract_join_keys( - left, - accum, - accum_filter, - left_schema, - right_schema, - )?; - self.extract_join_keys( - right, - accum, - accum_filter, - left_schema, - right_schema, - )?; - } + Self::extract_join_keys( + *left_expr, + accum, + accum_filter, + left_schema, + right_schema, + )?; + Self::extract_join_keys( + *right_expr, + accum, + accum_filter, + left_schema, + right_schema, + )?; } - _other => { + _ => { // example: baz > 1 - accum_filter.push(self.bind_expr(expr)?); + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); } }, - _other => { + _ => { // example: baz in (xxx), something else will convert to filter logic - accum_filter.push(self.bind_expr(expr)?); + accum_filter.push(expr); } } + Ok(()) } } diff --git a/src/binder/show.rs b/src/binder/show.rs index d13de6c9..fe35421c 100644 --- a/src/binder/show.rs +++ b/src/binder/show.rs @@ -6,11 +6,6 @@ use crate::storage::Transaction; impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_show_tables(&mut self) -> Result { - let plan = LogicalPlan { - operator: Operator::Show, - childrens: vec![], - physical_option: None, - }; - Ok(plan) + Ok(LogicalPlan::new(Operator::Show, vec![])) } } diff --git a/src/binder/truncate.rs b/src/binder/truncate.rs index 787351ef..312b6199 100644 --- a/src/binder/truncate.rs +++ b/src/binder/truncate.rs @@ -14,11 +14,9 @@ impl<'a, T: Transaction> Binder<'a, T> { ) -> Result { let table_name = Arc::new(lower_case_name(name)?); - let plan = LogicalPlan { - operator: Operator::Truncate(TruncateOperator { table_name }), - childrens: vec![], - physical_option: None, - }; - Ok(plan) + Ok(LogicalPlan::new( + Operator::Truncate(TruncateOperator { table_name }), + vec![], + )) } } diff --git a/src/binder/update.rs b/src/binder/update.rs index 373fa1b1..3ff925ca 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -50,13 +50,12 @@ impl<'a, T: Transaction> Binder<'a, T> { } } - let values_plan = self.bind_values(vec![row], columns); + let values_plan = self.bind_values(vec![row], Arc::new(columns)); - Ok(LogicalPlan { - operator: Operator::Update(UpdateOperator { table_name }), - childrens: vec![plan, values_plan], - physical_option: None, - }) + Ok(LogicalPlan::new( + Operator::Update(UpdateOperator { table_name }), + vec![plan, values_plan], + )) } else { unreachable!("only table") } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index d2b7fc74..19bd7d68 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -1,12 +1,13 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; -use std::collections::btree_map::Iter; use std::collections::BTreeMap; use std::sync::Arc; +use std::{slice, vec}; use crate::catalog::{ColumnCatalog, ColumnRef}; 
use crate::errors::DatabaseError; use crate::types::index::{IndexMeta, IndexMetaRef}; +use crate::types::tuple::SchemaRef; use crate::types::{ColumnId, LogicalType}; pub type TableName = Arc<String>; @@ -15,9 +16,11 @@ pub type TableName = Arc<String>; pub struct TableCatalog { pub(crate) name: TableName, /// Mapping from column names to column ids - column_idxs: BTreeMap<String, ColumnId>, - pub(crate) columns: BTreeMap<ColumnId, ColumnRef>, + column_idxs: BTreeMap<String, (ColumnId, usize)>, + columns: BTreeMap<ColumnId, usize>, pub(crate) indexes: Vec<IndexMetaRef>, + + schema_ref: SchemaRef, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -35,29 +38,30 @@ impl TableCatalog { #[allow(dead_code)] pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option<&ColumnRef> { - self.columns.get(id) + self.columns.get(id).map(|i| &self.schema_ref[*i]) } #[allow(dead_code)] pub(crate) fn get_column_id_by_name(&self, name: &str) -> Option<ColumnId> { - self.column_idxs.get(name).cloned() + self.column_idxs.get(name).map(|(id, _)| id).cloned() } pub(crate) fn get_column_by_name(&self, name: &str) -> Option<&ColumnRef> { - let id = self.column_idxs.get(name)?; - self.columns.get(id) + self.column_idxs + .get(name) + .map(|(_, i)| &self.schema_ref[*i]) } pub(crate) fn contains_column(&self, name: &str) -> bool { self.column_idxs.contains_key(name) } - pub(crate) fn clone_columns(&self) -> Vec<ColumnRef> { - self.columns.values().cloned().collect() + pub(crate) fn columns(&self) -> slice::Iter<'_, ColumnRef> { + self.schema_ref.iter() } - pub(crate) fn columns_with_id(&self) -> Iter<'_, ColumnId, ColumnRef> { - self.columns.iter() + pub(crate) fn schema_ref(&self) -> &SchemaRef { + &self.schema_ref } pub(crate) fn columns_len(&self) -> usize { @@ -65,16 +69,15 @@ impl TableCatalog { } pub(crate) fn primary_key(&self) -> Result<(usize, &ColumnRef), DatabaseError> { - self.columns - .values() + self.schema_ref + .iter() .enumerate() .find(|(_, column)| column.desc.is_primary) .ok_or(DatabaseError::PrimaryKeyNotFound) } pub(crate) fn types(&self) -> Vec<LogicalType> { - self.columns - .values() + self.columns() .map(|column| *column.datatype()) .collect_vec() } @@ -95,8 +98,13 @@ impl TableCatalog { col.summary.table_name = Some(self.name.clone()); col.summary.id = Some(col_id); - self.column_idxs.insert(col.name().to_string(), col_id); - self.columns.insert(col_id, Arc::new(col)); + self.column_idxs + .insert(col.name().to_string(), (col_id, self.schema_ref.len())); + self.columns.insert(col_id, self.schema_ref.len()); + + let mut schema = Vec::clone(&self.schema_ref); + schema.push(Arc::new(col)); + self.schema_ref = Arc::new(schema); Ok(col_id) } @@ -132,6 +140,7 @@ impl TableCatalog { column_idxs: BTreeMap::new(), columns: BTreeMap::new(), indexes: vec![], + schema_ref: Arc::new(vec![]), }; for col_catalog in columns.into_iter() { let _ = table_catalog.add_column(col_catalog)?; diff --git a/src/execution/codegen/dql/aggregate/simple_agg.rs b/src/execution/codegen/dql/aggregate/simple_agg.rs index 05bc074e..bf23c52e 100644 --- a/src/execution/codegen/dql/aggregate/simple_agg.rs +++ b/src/execution/codegen/dql/aggregate/simple_agg.rs @@ -85,7 +85,7 @@ impl UserData for AggAccumulators { Ok(Tuple { id: None, - columns, + schema_ref: columns, values, }) }); diff --git a/src/execution/codegen/mod.rs b/src/execution/codegen/mod.rs index 2784293c..cf94a3ac 100644 --- a/src/execution/codegen/mod.rs +++ b/src/execution/codegen/mod.rs @@ -46,7 +46,7 @@ impl UserData for Tuple { columns.push(expr.output_column()); } - tuple.columns = columns; + tuple.schema_ref = columns; tuple.values = values; Ok(()) diff --git 
a/src/execution/volcano/ddl/add_column.rs b/src/execution/volcano/ddl/add_column.rs index 524411ca..342fc6a9 100644 --- a/src/execution/volcano/ddl/add_column.rs +++ b/src/execution/volcano/ddl/add_column.rs @@ -44,7 +44,7 @@ impl AddColumn { let mut tuple: Tuple = tuple?; let tuples_columns = tuple_columns.get_or_insert_with(|| { - let mut columns = Vec::clone(&tuple.columns); + let mut columns = Vec::clone(&tuple.schema_ref); columns.push(Arc::new(column.clone())); Arc::new(columns) @@ -57,7 +57,7 @@ impl AddColumn { } else { tuple.values.push(Arc::new(DataValue::Null)); } - tuple.columns = tuples_columns.clone(); + tuple.schema_ref = tuples_columns.clone(); tuples.push(tuple); } for tuple in tuples { diff --git a/src/execution/volcano/ddl/drop_column.rs b/src/execution/volcano/ddl/drop_column.rs index 5b8cb21b..09778674 100644 --- a/src/execution/volcano/ddl/drop_column.rs +++ b/src/execution/volcano/ddl/drop_column.rs @@ -42,7 +42,7 @@ impl DropColumn { if tuple_columns.is_none() { if let Some((column_index, is_primary)) = tuple - .columns + .schema_ref .iter() .enumerate() .find(|(_, column)| column.name() == column_name) @@ -53,7 +53,7 @@ impl DropColumn { "drop of primary key column is not allowed.".to_owned(), ))?; } - let mut columns = Vec::clone(&tuple.columns); + let mut columns = Vec::clone(&tuple.schema_ref); let _ = columns.remove(column_index); tuple_columns = Some((column_index, Arc::new(columns))); @@ -66,7 +66,7 @@ impl DropColumn { .clone() .ok_or_else(|| DatabaseError::InvalidColumn("not found column".to_string()))?; - tuple.columns = columns; + tuple.schema_ref = columns; let _ = tuple.values.remove(column_i); tuples.push(tuple); diff --git a/src/execution/volcano/dml/analyze.rs b/src/execution/volcano/dml/analyze.rs index fabe88ec..b5bf6031 100644 --- a/src/execution/volcano/dml/analyze.rs +++ b/src/execution/volcano/dml/analyze.rs @@ -67,10 +67,10 @@ impl Analyze { #[for_await] for tuple in build_read(input, transaction) { let Tuple { - columns, values, .. + schema_ref, values, .. 
} = tuple?; - for (i, column) in columns.iter().enumerate() { + for (i, column) in schema_ref.iter().enumerate() { if !column.desc.is_index() { continue; } @@ -114,7 +114,7 @@ impl Analyze { yield Tuple { id: None, - columns: Arc::new(columns), + schema_ref: Arc::new(columns), values, }; } diff --git a/src/execution/volcano/dml/copy_from_file.rs b/src/execution/volcano/dml/copy_from_file.rs index 0b866038..752ff7bb 100644 --- a/src/execution/volcano/dml/copy_from_file.rs +++ b/src/execution/volcano/dml/copy_from_file.rs @@ -70,8 +70,8 @@ impl CopyFromFile { .from_reader(&mut buf_reader), }; - let column_count = self.op.columns.len(); - let tuple_builder = TupleBuilder::new(self.op.columns.clone()); + let column_count = self.op.schema_ref.len(); + let tuple_builder = TupleBuilder::new(&self.op.schema_ref); for record in reader.records() { // read records and push raw str rows into data chunk builder @@ -168,8 +168,7 @@ mod tests { header: false, }, }, - - columns, + schema_ref: Arc::new(columns), }; let executor = CopyFromFile { op: op.clone(), diff --git a/src/execution/volcano/dml/delete.rs b/src/execution/volcano/dml/delete.rs index ecb687f3..275e8b7f 100644 --- a/src/execution/volcano/dml/delete.rs +++ b/src/execution/volcano/dml/delete.rs @@ -32,9 +32,9 @@ impl Delete { let Delete { table_name, input } = self; let option_index_metas = transaction.table(table_name.clone()).map(|table_catalog| { table_catalog - .columns_with_id() + .columns() .enumerate() - .filter_map(|(i, (_, col))| { + .filter_map(|(i, col)| { col.desc .is_unique .then(|| { diff --git a/src/execution/volcano/dml/insert.rs b/src/execution/volcano/dml/insert.rs index 88c9f724..4f25a62e 100644 --- a/src/execution/volcano/dml/insert.rs +++ b/src/execution/volcano/dml/insert.rs @@ -53,24 +53,25 @@ impl Insert { let mut primary_key_index = None; let mut unique_values = HashMap::new(); let mut tuple_values = Vec::new(); - let mut tuple_columns = Vec::new(); + let mut tuple_schema_ref = None; if let Some(table_catalog) = transaction.table(table_name.clone()).cloned() { + let _ = tuple_schema_ref.get_or_insert_with(|| table_catalog.schema_ref().clone()); #[for_await] for tuple in build_read(input, transaction) { let Tuple { - columns, values, .. + schema_ref, values, .. 
} = tuple?; let mut tuple_map = HashMap::new(); for (i, value) in values.into_iter().enumerate() { - let col = &columns[i]; + let column = &schema_ref[i]; - if let Some(col_id) = col.id() { - tuple_map.insert(col_id, value); + if let Some(column_id) = column.id() { + tuple_map.insert(column_id, value); } } let primary_col_id = primary_key_index.get_or_insert_with(|| { - columns + schema_ref .iter() .find(|col| col.desc.is_primary) .map(|col| col.id().unwrap()) @@ -79,12 +80,9 @@ impl Insert { let tuple_id = tuple_map.get(primary_col_id).cloned().unwrap(); let mut values = Vec::with_capacity(table_catalog.columns_len()); - for (col_id, col) in table_catalog.columns_with_id() { - if table_catalog.columns_len() > tuple_columns.len() { - tuple_columns.push(col.clone()); - } + for col in table_catalog.columns() { let value = tuple_map - .remove(col_id) + .remove(&col.id().unwrap()) .or_else(|| col.default_value()) .unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))); @@ -101,7 +99,8 @@ impl Insert { } tuple_values.push((tuple_id, values)); } - let tuple_builder = TupleBuilder::new(tuple_columns); + let tuple_schema_ref = tuple_schema_ref.ok_or(DatabaseError::ColumnsEmpty)?; + let tuple_builder = TupleBuilder::new(&tuple_schema_ref); // Unique Index for (col_id, values) in unique_values { diff --git a/src/execution/volcano/dml/update.rs b/src/execution/volcano/dml/update.rs index 8888047c..72677cad 100644 --- a/src/execution/volcano/dml/update.rs +++ b/src/execution/volcano/dml/update.rs @@ -50,7 +50,9 @@ impl Update { #[for_await] for tuple in build_read(values, transaction) { let Tuple { - columns, values, .. + schema_ref: columns, + values, + .. } = tuple?; for i in 0..columns.len() { value_map.insert(columns[i].id(), values[i].clone()); @@ -66,7 +68,7 @@ impl Update { for mut tuple in tuples { let mut is_overwrite = true; - for (i, column) in tuple.columns.iter().enumerate() { + for (i, column) in tuple.schema_ref.iter().enumerate() { if let Some(value) = value_map.get(&column.id()) { if column.desc.is_primary { let old_key = tuple.id.replace(value.clone()).unwrap(); diff --git a/src/execution/volcano/dql/aggregate/hash_agg.rs b/src/execution/volcano/dql/aggregate/hash_agg.rs index 61bcedbd..800c83ea 100644 --- a/src/execution/volcano/dql/aggregate/hash_agg.rs +++ b/src/execution/volcano/dql/aggregate/hash_agg.rs @@ -124,7 +124,7 @@ impl HashAggStatus { Ok::(Tuple { id: None, - columns: group_columns.clone(), + schema_ref: group_columns.clone(), values, }) }) @@ -223,10 +223,11 @@ mod test { Arc::new(DataValue::Int32(Some(3))), ], ], - columns: t1_columns, + schema_ref: Arc::new(t1_columns), }), childrens: vec![], physical_option: None, + _out_columns: None, }; let tuples = diff --git a/src/execution/volcano/dql/aggregate/simple_agg.rs b/src/execution/volcano/dql/aggregate/simple_agg.rs index a610cac9..795e7871 100644 --- a/src/execution/volcano/dql/aggregate/simple_agg.rs +++ b/src/execution/volcano/dql/aggregate/simple_agg.rs @@ -66,7 +66,7 @@ impl SimpleAggExecutor { yield Tuple { id: None, - columns: Arc::new(columns), + schema_ref: Arc::new(columns), values, }; } diff --git a/src/execution/volcano/dql/describe.rs b/src/execution/volcano/dql/describe.rs index a6498b8a..32bdf871 100644 --- a/src/execution/volcano/dql/describe.rs +++ b/src/execution/volcano/dql/describe.rs @@ -57,7 +57,7 @@ impl Describe { } }; - for (_, column) in table.columns_with_id() { + for column in table.columns() { let values = vec![ Arc::new(DataValue::Utf8(Some(column.name().to_string()))), 
Arc::new(DataValue::Utf8(Some(column.datatype().to_string()))), @@ -69,7 +69,7 @@ impl Describe { ]; yield Tuple { id: None, - columns: columns.clone(), + schema_ref: columns.clone(), values, }; } diff --git a/src/execution/volcano/dql/dummy.rs b/src/execution/volcano/dql/dummy.rs index fe9c0eab..55670137 100644 --- a/src/execution/volcano/dql/dummy.rs +++ b/src/execution/volcano/dql/dummy.rs @@ -18,7 +18,7 @@ impl Dummy { pub async fn _execute(self) { yield Tuple { id: None, - columns: Arc::new(vec![]), + schema_ref: Arc::new(vec![]), values: vec![], } } diff --git a/src/execution/volcano/dql/explain.rs b/src/execution/volcano/dql/explain.rs index 46191d0d..4e736783 100644 --- a/src/execution/volcano/dql/explain.rs +++ b/src/execution/volcano/dql/explain.rs @@ -27,12 +27,12 @@ impl ReadExecutor for Explain { impl Explain { #[try_stream(boxed, ok = Tuple, error = DatabaseError)] pub async fn _execute(self) { - let columns = Arc::new(vec![Arc::new(ColumnCatalog::new_dummy("PLAN".to_string()))]); + let schema_ref = Arc::new(vec![Arc::new(ColumnCatalog::new_dummy("PLAN".to_string()))]); let values = vec![Arc::new(DataValue::Utf8(Some(self.plan.explain(0))))]; yield Tuple { id: None, - columns, + schema_ref, values, }; } diff --git a/src/execution/volcano/dql/join/hash_join.rs b/src/execution/volcano/dql/join/hash_join.rs index ab3c445e..37ec7ec2 100644 --- a/src/execution/volcano/dql/join/hash_join.rs +++ b/src/execution/volcano/dql/join/hash_join.rs @@ -130,7 +130,7 @@ impl HashJoinStatus { .. } = self; - let right_cols_len = tuple.columns.len(); + let right_cols_len = tuple.schema_ref.len(); let hash = Self::hash_row(on_right_keys, hash_random_state, &tuple)?; if !*right_init_flag { @@ -153,7 +153,7 @@ impl HashJoinStatus { Tuple { id: None, - columns: join_columns.clone(), + schema_ref: join_columns.clone(), values: full_values, } }) @@ -168,7 +168,7 @@ impl HashJoinStatus { vec![Tuple { id: None, - columns: join_columns.clone(), + schema_ref: join_columns.clone(), values, }] } else { @@ -185,13 +185,13 @@ impl HashJoinStatus { for mut tuple in join_tuples { if let DataValue::Boolean(option) = expr.eval(&tuple)?.as_ref() { if let Some(false) | None = option { - let full_cols_len = tuple.columns.len(); + let full_cols_len = tuple.schema_ref.len(); let left_cols_len = full_cols_len - right_cols_len; match ty { JoinType::Left => { for i in left_cols_len..full_cols_len { - let value_type = tuple.columns[i].datatype(); + let value_type = tuple.schema_ref[i].datatype(); tuple.values[i] = Arc::new(DataValue::none(value_type)) } @@ -199,7 +199,7 @@ impl HashJoinStatus { } JoinType::Right => { for i in 0..left_cols_len { - let value_type = tuple.columns[i].datatype(); + let value_type = tuple.schema_ref[i].datatype(); tuple.values[i] = Arc::new(DataValue::none(value_type)) } @@ -240,7 +240,7 @@ impl HashJoinStatus { .flat_map(|(_, mut tuples)| { for Tuple { values, - columns, + schema_ref: columns, id, } in tuples.iter_mut() { @@ -273,7 +273,7 @@ impl HashJoinStatus { fn columns_filling(tuple: &Tuple, join_columns: &mut Vec, force_nullable: bool) { let mut new_columns = tuple - .columns + .schema_ref .iter() .cloned() .map(|col| { @@ -403,10 +403,11 @@ mod test { Arc::new(DataValue::Int32(Some(7))), ], ], - columns: t1_columns, + schema_ref: Arc::new(t1_columns), }), childrens: vec![], physical_option: None, + _out_columns: None, }; let values_t2 = LogicalPlan { @@ -433,10 +434,11 @@ mod test { Arc::new(DataValue::Int32(Some(1))), ], ], - columns: t2_columns, + schema_ref: Arc::new(t2_columns), }), 
childrens: vec![], physical_option: None, + _out_columns: None, }; (on_keys, values_t1, values_t2) diff --git a/src/execution/volcano/dql/projection.rs b/src/execution/volcano/dql/projection.rs index d9791374..5fceb24a 100644 --- a/src/execution/volcano/dql/projection.rs +++ b/src/execution/volcano/dql/projection.rs @@ -29,13 +29,13 @@ impl Projection { #[try_stream(boxed, ok = Tuple, error = DatabaseError)] pub async fn _execute(self, transaction: &T) { let Projection { exprs, input } = self; - let mut columns = None; + let mut schema_ref = None; #[for_await] for tuple in build_read(input, transaction) { let mut tuple = tuple?; let mut values = Vec::with_capacity(exprs.len()); - let columns = columns.get_or_insert_with(|| { + let schema_ref = schema_ref.get_or_insert_with(|| { let mut columns = Vec::with_capacity(exprs.len()); for expr in exprs.iter() { @@ -47,7 +47,7 @@ impl Projection { for expr in exprs.iter() { values.push(expr.eval(&tuple)?); } - tuple.columns = columns.clone(); + tuple.schema_ref = schema_ref.clone(); tuple.values = values; yield tuple; diff --git a/src/execution/volcano/dql/show_table.rs b/src/execution/volcano/dql/show_table.rs index b3f1c21e..afd18ce8 100644 --- a/src/execution/volcano/dql/show_table.rs +++ b/src/execution/volcano/dql/show_table.rs @@ -25,7 +25,7 @@ impl ShowTables { colum_meta_paths: histogram_paths, } in metas { - let columns = Arc::new(vec![ + let schema_ref = Arc::new(vec![ Arc::new(ColumnCatalog::new_dummy("TABLE".to_string())), Arc::new(ColumnCatalog::new_dummy("COLUMN_METAS_LEN".to_string())), ]); @@ -36,7 +36,7 @@ impl ShowTables { yield Tuple { id: None, - columns, + schema_ref, values, }; } diff --git a/src/execution/volcano/dql/values.rs b/src/execution/volcano/dql/values.rs index 65ecb316..34c8a7fe 100644 --- a/src/execution/volcano/dql/values.rs +++ b/src/execution/volcano/dql/values.rs @@ -4,7 +4,6 @@ use crate::planner::operator::values::ValuesOperator; use crate::storage::Transaction; use crate::types::tuple::Tuple; use futures_async_stream::try_stream; -use std::sync::Arc; pub struct Values { op: ValuesOperator, @@ -25,13 +24,12 @@ impl ReadExecutor for Values { impl Values { #[try_stream(boxed, ok = Tuple, error = DatabaseError)] pub async fn _execute(self) { - let ValuesOperator { columns, rows } = self.op; - let columns = Arc::new(columns); + let ValuesOperator { schema_ref, rows } = self.op; for values in rows { yield Tuple { id: None, - columns: columns.clone(), + schema_ref: schema_ref.clone(), values, }; } diff --git a/src/execution/volcano/mod.rs b/src/execution/volcano/mod.rs index 403a8550..5d2e3855 100644 --- a/src/execution/volcano/mod.rs +++ b/src/execution/volcano/mod.rs @@ -118,6 +118,7 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo operator, mut childrens, physical_option, + _out_columns, } = plan; match operator { @@ -163,6 +164,7 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo operator, childrens, physical_option, + _out_columns, }, transaction, ), diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 9980b9b2..eb15a33e 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -44,7 +44,7 @@ impl ScalarExpression { } ScalarExpression::Alias { expr, alias } => { if let Some(value) = tuple - .columns + .schema_ref .iter() .find_position(|tul_col| tul_col.name() == alias) .map(|(i, _)| &tuple.values[i]) @@ -159,12 +159,13 @@ impl ScalarExpression { Ok(Arc::new(DataValue::Utf8(None))) } } + ScalarExpression::Empty => unreachable!(), 
} } fn eval_with_summary<'a>(tuple: &'a Tuple, summary: &ColumnSummary) -> Option<&'a ValueRef> { tuple - .columns + .schema_ref .iter() .find_position(|tul_col| tul_col.summary() == summary) .map(|(i, _)| &tuple.values[i]) diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 7b2db0b6..6acd28e6 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,8 +1,8 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; -use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +use std::fmt; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; @@ -69,6 +69,8 @@ pub enum ScalarExpression { for_expr: Option>, from_expr: Option>, }, + // Temporary expression used for expression substitution + Empty, } impl ScalarExpression { @@ -117,6 +119,7 @@ impl ScalarExpression { } Self::SubString { .. } => LogicalType::Varchar(None), Self::Alias { expr, .. } => expr.return_type(), + ScalarExpression::Empty => unreachable!(), } } @@ -209,6 +212,7 @@ impl ScalarExpression { Some(true) ) } + ScalarExpression::Empty => unreachable!(), } } @@ -301,6 +305,7 @@ impl ScalarExpression { op("for", for_expr), ) } + ScalarExpression::Empty => unreachable!(), } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 2a7207e2..defb71b1 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -544,7 +544,7 @@ impl ScalarExpression { (Some(col), None) => { replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), - val_expr: right_expr.as_ref().clone(), + val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: *ty, is_column_left: true, @@ -553,7 +553,7 @@ impl ScalarExpression { (None, Some(col)) => { replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), - val_expr: left_expr.as_ref().clone(), + val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: *ty, is_column_left: false, @@ -568,7 +568,7 @@ impl ScalarExpression { (Some(col), None) => { replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), - val_expr: right_expr.as_ref().clone(), + val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, ty: *ty, is_column_left: true, @@ -577,7 +577,7 @@ impl ScalarExpression { (None, Some(col)) => { replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), - val_expr: left_expr.as_ref().clone(), + val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, ty: *ty, is_column_left: false, @@ -675,13 +675,13 @@ impl ScalarExpression { left_expr: Box::new(ScalarExpression::Binary { op: left_op, left_expr: expr.clone(), - right_expr: left_expr.clone(), + right_expr: mem::replace(left_expr, Box::new(ScalarExpression::Empty)), ty: LogicalType::Boolean, }), right_expr: Box::new(ScalarExpression::Binary { op: right_op, - left_expr: expr.clone(), - right_expr: right_expr.clone(), + left_expr: mem::replace(expr, Box::new(ScalarExpression::Empty)), + right_expr: mem::replace(right_expr, Box::new(ScalarExpression::Empty)), ty: LogicalType::Boolean, }), ty: LogicalType::Boolean, @@ -741,11 +741,13 @@ impl ScalarExpression { ty: fix_ty, } = replace_unary; let _ = mem::replace(col_expr, Box::new(child_expr)); + + let expr = mem::replace(val_expr, Box::new(ScalarExpression::Empty)); let _ = mem::replace( val_expr, Box::new(ScalarExpression::Unary { op: fix_op, - expr: val_expr.clone(), + expr, ty: fix_ty, }), ); @@ -802,13 
+804,14 @@ impl ScalarExpression { BinaryOperator::LtEq => BinaryOperator::GtEq, source_op => source_op, }; + let temp_expr = mem::replace(right_expr, Box::new(ScalarExpression::Empty)); let (fixed_op, fixed_left_expr, fixed_right_expr) = if is_column_left { - (op_flip(fix_op), right_expr.clone(), Box::new(val_expr)) + (op_flip(fix_op), temp_expr, Box::new(val_expr)) } else { if matches!(fix_op, BinaryOperator::Minus | BinaryOperator::Multiply) { let _ = mem::replace(op, comparison_flip(*op)); } - (fix_op, Box::new(val_expr), right_expr.clone()) + (fix_op, Box::new(val_expr), temp_expr) }; let _ = mem::replace(left_expr, Box::new(column_expr)); @@ -939,10 +942,12 @@ impl ScalarExpression { | ScalarExpression::In { .. } | ScalarExpression::Between { .. } | ScalarExpression::SubString { .. } => expr.convert_binary(col_id), + ScalarExpression::Empty => unreachable!(), }, ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) | ScalarExpression::AggCall { .. } => Ok(None), + ScalarExpression::Empty => unreachable!(), } } diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs index c205c59a..8570ed01 100644 --- a/src/marcos/mod.rs +++ b/src/marcos/mod.rs @@ -28,7 +28,7 @@ macro_rules! implement_from_tuple { fn from(tuple: Tuple) -> Self { fn try_get<T: 'static>(tuple: &Tuple, field_name: &str) -> Option<DataValue> { let ty = LogicalType::type_trans::<T>()?; - let (idx, _) = tuple.columns + let (idx, _) = tuple.schema_ref .iter() .enumerate() .find(|(_, col)| col.name() == field_name)?; @@ -81,7 +81,7 @@ mod test { Tuple { id: None, - columns, + schema_ref: columns, values, } } diff --git a/src/optimizer/core/opt_expr.rs b/src/optimizer/core/opt_expr.rs index 5b7be4ea..a1692758 100644 --- a/src/optimizer/core/opt_expr.rs +++ b/src/optimizer/core/opt_expr.rs @@ -40,10 +40,6 @@ impl OptExpr { .iter() .map(|c| c.to_plan_ref()) .collect::<Vec<_>>(); - LogicalPlan { - operator: self.root.clone(), - childrens, - physical_option: None, - } + LogicalPlan::new(self.root.clone(), childrens) } } diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 51aafe11..7a3e9ae0 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -200,6 +200,7 @@ impl HepGraph { operator, childrens, physical_option, + _out_columns: None, }) } } diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 5381b174..320b4bc9 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -103,16 +103,20 @@ mod tests { operator: Operator::Dummy, childrens: vec![], physical_option: None, + _out_columns: None, }], physical_option: None, + _out_columns: None, }, LogicalPlan { operator: Operator::Dummy, childrens: vec![], physical_option: None, + _out_columns: None, }, ], physical_option: None, + _out_columns: None, }; let graph = HepGraph::new(all_dummy_plan.clone()); diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 8cea163d..dae9c8f5 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -2,15 +2,29 @@ pub mod operator; use crate::catalog::TableName; use crate::planner::operator::{Operator, PhysicalOption}; +use crate::types::tuple::SchemaRef; +use itertools::Itertools; +use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct LogicalPlan { pub operator: Operator, pub childrens: Vec<LogicalPlan>, pub physical_option: Option<PhysicalOption>, + + pub _out_columns: Option<SchemaRef>, } impl LogicalPlan { + pub fn new(operator: Operator, childrens: Vec<LogicalPlan>) -> Self { + Self { + operator, + childrens, + physical_option: None, + _out_columns: None, + } + } + 
pub fn child(&self, index: usize) -> Option<&LogicalPlan> { self.childrens.get(index) } @@ -30,6 +44,64 @@ impl LogicalPlan { tables } + pub fn out_schema(&mut self) -> &SchemaRef { + self._out_columns + .get_or_insert_with(|| match &self.operator { + Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) => { + self.childrens[0].out_schema().clone() + } + Operator::Aggregate(op) => { + let out_columns = op + .agg_calls + .iter() + .chain(op.groupby_exprs.iter()) + .map(|expr| expr.output_column()) + .collect_vec(); + Arc::new(out_columns) + } + Operator::Join(_) => { + let out_columns = self + .childrens + .iter_mut() + .flat_map(|children| Vec::clone(children.out_schema())) + .collect_vec(); + Arc::new(out_columns) + } + Operator::Project(op) => { + let out_columns = op + .exprs + .iter() + .map(|expr| expr.output_column()) + .collect_vec(); + Arc::new(out_columns) + } + Operator::Scan(op) => { + let out_columns = op + .columns + .iter() + .map(|(_, column)| column.clone()) + .collect_vec(); + Arc::new(out_columns) + } + Operator::Values(op) => op.schema_ref.clone(), + Operator::Dummy + | Operator::Show + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::DropTable(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => Arc::new(vec![]), + }) + } + pub fn explain(&self, indentation: usize) -> String { let mut result = format!("{:indent$}{}", "", self.operator, indent = indentation); diff --git a/src/planner/operator/aggregate.rs b/src/planner/operator/aggregate.rs index ed7122b5..0ca6a44a 100644 --- a/src/planner/operator/aggregate.rs +++ b/src/planner/operator/aggregate.rs @@ -16,14 +16,13 @@ impl AggregateOperator { agg_calls: Vec<ScalarExpression>, groupby_exprs: Vec<ScalarExpression>, ) -> LogicalPlan { - LogicalPlan { - operator: Operator::Aggregate(Self { + LogicalPlan::new( + Operator::Aggregate(Self { groupby_exprs, agg_calls, }), - childrens: vec![children], - physical_option: None, - } + vec![children], + ) } } diff --git a/src/planner/operator/copy_from_file.rs b/src/planner/operator/copy_from_file.rs index 02154991..cae0ebec 100644 --- a/src/planner/operator/copy_from_file.rs +++ b/src/planner/operator/copy_from_file.rs @@ -1,5 +1,5 @@ use crate::binder::copy::ExtSource; -use crate::catalog::ColumnRef; +use crate::types::tuple::SchemaRef; use itertools::Itertools; use std::fmt; use std::fmt::Formatter; @@ -8,13 +8,13 @@ pub struct CopyFromFileOperator { pub table: String, pub source: ExtSource, - pub columns: Vec<ColumnRef>, + pub schema_ref: SchemaRef, } impl fmt::Display for CopyFromFileOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let columns = self - .columns + .schema_ref .iter() .map(|column| column.name().to_string()) .join(", "); diff --git a/src/planner/operator/filter.rs b/src/planner/operator/filter.rs index 49ccf0d0..51d0e521 100644 --- a/src/planner/operator/filter.rs +++ b/src/planner/operator/filter.rs @@ -14,11 +14,10 @@ pub struct FilterOperator { impl FilterOperator { pub fn build(predicate: ScalarExpression, children: LogicalPlan, having: bool) -> LogicalPlan { - LogicalPlan { - operator: Operator::Filter(FilterOperator { predicate, having }), - childrens: vec![children], - physical_option: None, - } + LogicalPlan::new( + Operator::Filter(FilterOperator { predicate, having }), + vec![children], + ) } } diff --git 
a/src/planner/operator/join.rs b/src/planner/operator/join.rs index e5de3525..9b23ffe2 100644 --- a/src/planner/operator/join.rs +++ b/src/planner/operator/join.rs @@ -39,17 +39,16 @@ impl JoinOperator { on: JoinCondition, join_type: JoinType, ) -> LogicalPlan { - LogicalPlan { - operator: Operator::Join(JoinOperator { on, join_type }), - childrens: vec![left, right], - physical_option: None, - } + LogicalPlan::new( + Operator::Join(JoinOperator { on, join_type }), + vec![left, right], + ) } } impl fmt::Display for JoinOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{} Join On {}", self.join_type, self.on)?; + write!(f, "{} Join{}", self.join_type, self.on)?; Ok(()) } @@ -59,18 +58,20 @@ impl fmt::Display for JoinCondition { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { JoinCondition::On { on, filter } => { - let on = on - .iter() - .map(|(v1, v2)| format!("{} = {}", v1, v2)) - .join(" AND "); + if !on.is_empty() { + let on = on + .iter() + .map(|(v1, v2)| format!("{} = {}", v1, v2)) + .join(" AND "); - write!(f, "{}", on)?; + write!(f, " On {}", on)?; + } if let Some(filter) = filter { - write!(f, "Where {}", filter)?; + write!(f, " Where {}", filter)?; } } JoinCondition::None => { - write!(f, "Nothing")?; + write!(f, " Nothing")?; } } diff --git a/src/planner/operator/limit.rs b/src/planner/operator/limit.rs index ee034d73..16aa434a 100644 --- a/src/planner/operator/limit.rs +++ b/src/planner/operator/limit.rs @@ -16,11 +16,10 @@ impl LimitOperator { limit: Option, children: LogicalPlan, ) -> LogicalPlan { - LogicalPlan { - operator: Operator::Limit(LimitOperator { offset, limit }), - childrens: vec![children], - physical_option: None, - } + LogicalPlan::new( + Operator::Limit(LimitOperator { offset, limit }), + vec![children], + ) } } diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index c78948d6..99ccf76b 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -144,7 +144,7 @@ impl Operator { .map(|field| &field.expr) .flat_map(|expr| expr.referenced_columns(only_column_ref)) .collect_vec(), - Operator::Values(op) => op.columns.clone(), + Operator::Values(op) => Vec::clone(&op.schema_ref), Operator::Analyze(op) => op.columns.clone(), Operator::Delete(op) => vec![op.primary_key_column.clone()], _ => vec![], diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index 350445c0..4c2cc587 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -26,9 +26,9 @@ impl ScanOperator { let mut primary_key_option = None; // Fill all Columns in TableCatalog by default let columns = table_catalog - .columns_with_id() + .columns() .enumerate() - .map(|(i, (_, column))| { + .map(|(i, column)| { if column.desc.is_primary { primary_key_option = column.id(); } @@ -45,17 +45,16 @@ impl ScanOperator { }) .collect_vec(); - LogicalPlan { - operator: Operator::Scan(ScanOperator { + LogicalPlan::new( + Operator::Scan(ScanOperator { index_infos, table_name, primary_key: primary_key_option.unwrap(), columns, limit: (None, None), }), - childrens: vec![], - physical_option: None, - } + vec![], + ) } } diff --git a/src/planner/operator/values.rs b/src/planner/operator/values.rs index 95510652..330c9772 100644 --- a/src/planner/operator/values.rs +++ b/src/planner/operator/values.rs @@ -1,4 +1,4 @@ -use crate::catalog::ColumnRef; +use crate::types::tuple::SchemaRef; use crate::types::value::ValueRef; use itertools::Itertools; use std::fmt; @@ -7,13 +7,13 @@ use std::fmt::Formatter; 
#[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct ValuesOperator { pub rows: Vec>, - pub columns: Vec, + pub schema_ref: SchemaRef, } impl fmt::Display for ValuesOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let columns = self - .columns + .schema_ref .iter() .map(|column| column.name().to_string()) .join(", "); diff --git a/src/storage/kip.rs b/src/storage/kip.rs index 1e61fae6..dec0acf1 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -126,7 +126,7 @@ impl Transaction for KipTransaction { Ok(IndexIter { offset, limit: limit_option, - tuple_columns: Arc::new(tuple_columns), + tuple_schema_ref: Arc::new(tuple_columns), index_meta, table, index_values: VecDeque::new(), @@ -207,10 +207,10 @@ impl Transaction for KipTransaction { return Err(DatabaseError::NeedNullAbleOrDefault); } - for (col_id, col) in table.columns_with_id() { + for col in table.columns() { if col.name() == column.name() { return if if_not_exists { - Ok(*col_id) + Ok(col.id().unwrap()) } else { Err(DatabaseError::DuplicateColumn) }; @@ -295,7 +295,7 @@ impl Transaction for KipTransaction { Self::create_index_meta_for_table(&mut self.tx, &mut table_catalog)?; - for column in table_catalog.columns.values() { + for column in table_catalog.columns() { let (key, value) = TableCodec::encode_column(&table_name, column)?; self.tx.set(key, value); } @@ -456,9 +456,9 @@ impl KipTransaction { ) -> Result<(), DatabaseError> { let table_name = table.name.clone(); let index_column = table - .columns_with_id() - .filter(|(_, col)| col.desc.is_index()) - .map(|(col_id, column)| (*col_id, column.clone())) + .columns() + .filter(|column| column.desc.is_index()) + .map(|column| (column.id().unwrap(), column.clone())) .collect_vec(); for (col_id, col) in index_column { @@ -574,7 +574,7 @@ mod test { &"test".to_string(), Tuple { id: Some(Arc::new(DataValue::Int32(Some(1)))), - columns: columns.clone(), + schema_ref: columns.clone(), values: vec![ Arc::new(DataValue::Int32(Some(1))), Arc::new(DataValue::Boolean(Some(true))), @@ -586,7 +586,7 @@ mod test { &"test".to_string(), Tuple { id: Some(Arc::new(DataValue::Int32(Some(2)))), - columns: columns.clone(), + schema_ref: columns.clone(), values: vec![ Arc::new(DataValue::Int32(Some(2))), Arc::new(DataValue::Boolean(Some(false))), @@ -628,7 +628,6 @@ mod test { .table(Arc::new("t1".to_string())) .unwrap() .clone(); - let columns = Arc::new(table.clone_columns().into_iter().collect_vec()); let tuple_ids = vec![ Arc::new(DataValue::Int32(Some(0))), Arc::new(DataValue::Int32(Some(2))), @@ -638,7 +637,7 @@ mod test { let mut iter = IndexIter { offset: 0, limit: None, - tuple_columns: columns, + tuple_schema_ref: table.schema_ref().clone(), index_meta: Arc::new(IndexMeta { id: 0, column_ids: vec![0], @@ -686,7 +685,7 @@ mod test { .table(Arc::new("t1".to_string())) .unwrap() .clone(); - let columns = table.clone_columns().into_iter().enumerate().collect_vec(); + let columns = table.columns().cloned().enumerate().collect_vec(); let mut iter = transaction .read_by_index( Arc::new("t1".to_string()), diff --git a/src/storage/mod.rs b/src/storage/mod.rs index e32b70c3..5a82cffa 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -112,7 +112,7 @@ enum IndexValue { pub struct IndexIter<'a> { offset: usize, limit: Option, - tuple_columns: Arc>, + tuple_schema_ref: Arc>, projections: Vec, index_meta: IndexMetaRef, @@ -153,7 +153,7 @@ impl IndexIter<'_> { TableCodec::decode_tuple( &self.table.types(), &self.projections, - &self.tuple_columns, + &self.tuple_schema_ref, 
&bytes, ) })) @@ -205,7 +205,7 @@ impl Iter for IndexIter<'_> { let tuple = TableCodec::decode_tuple( &self.table.types(), &self.projections, - &self.tuple_columns, + &self.tuple_schema_ref, &value, ); @@ -273,7 +273,7 @@ impl Iter for IndexIter<'_> { let tuple = TableCodec::decode_tuple( &self.table.types(), &self.projections, - &self.tuple_columns, + &self.tuple_schema_ref, &bytes, ); diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index a59e378f..aef2d4d3 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -1,11 +1,10 @@ -use crate::catalog::{ColumnCatalog, ColumnRef, TableMeta}; +use crate::catalog::{ColumnCatalog, TableMeta}; use crate::errors::DatabaseError; use crate::types::index::{Index, IndexId, IndexMeta}; -use crate::types::tuple::{Tuple, TupleId}; +use crate::types::tuple::{SchemaRef, Tuple, TupleId}; use crate::types::LogicalType; use bytes::Bytes; use lazy_static::lazy_static; -use std::sync::Arc; const BOUND_MIN_TAG: u8 = 0; const BOUND_MAX_TAG: u8 = 1; @@ -174,10 +173,10 @@ impl TableCodec { pub fn decode_tuple( table_types: &[LogicalType], projections: &[usize], - tuple_columns: &Arc>, + tuple_schema_ref: &SchemaRef, bytes: &[u8], ) -> Tuple { - Tuple::deserialize_from(table_types, projections, tuple_columns, bytes) + Tuple::deserialize_from(table_types, projections, tuple_schema_ref, bytes) } /// Key: {TableName}{INDEX_META_TAG}{BOUND_MIN_TAG}{IndexID} @@ -317,14 +316,14 @@ mod tests { let tuple = Tuple { id: Some(Arc::new(DataValue::Int32(Some(0)))), - columns: Arc::new(table_catalog.clone_columns()), + schema_ref: table_catalog.schema_ref().clone(), values: vec![ Arc::new(DataValue::Int32(Some(0))), Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), ], }; let (_, bytes) = TableCodec::encode_tuple(&table_catalog.name, &tuple)?; - let columns = Arc::new(table_catalog.clone_columns().into_iter().collect_vec()); + let columns = table_catalog.schema_ref().clone(); assert_eq!( TableCodec::decode_tuple(&table_catalog.types(), &[0, 1], &columns, &bytes), @@ -384,7 +383,7 @@ mod tests { #[test] fn test_table_codec_column() { let table_catalog = build_table_codec(); - let col = table_catalog.clone_columns()[0].clone(); + let col = table_catalog.columns().next().cloned().unwrap(); let (_, bytes) = TableCodec::encode_column(&table_catalog.name, &col).unwrap(); let decode_col = TableCodec::decode_column(&bytes).unwrap(); diff --git a/src/types/tuple.rs b/src/types/tuple.rs index b52fa91d..692270e8 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -15,7 +15,7 @@ pub type SchemaRef = Arc; #[derive(Clone, Debug, PartialEq)] pub struct Tuple { pub id: Option, - pub columns: SchemaRef, + pub schema_ref: SchemaRef, pub values: Vec, } @@ -23,17 +23,17 @@ impl Tuple { pub fn deserialize_from( table_types: &[LogicalType], projections: &[usize], - tuple_columns: &Arc>, + tuple_schema_ref: &SchemaRef, bytes: &[u8], ) -> Self { - assert!(!tuple_columns.is_empty()); - assert_eq!(projections.len(), tuple_columns.len()); + assert!(!tuple_schema_ref.is_empty()); + assert_eq!(projections.len(), tuple_schema_ref.len()); fn is_none(bits: u8, i: usize) -> bool { bits & (1 << (7 - i)) > 0 } - let values_len = tuple_columns.len(); + let values_len = tuple_schema_ref.len(); let mut tuple_values = Vec::with_capacity(values_len); let bits_len = (values_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; let mut id_option = None; @@ -42,14 +42,14 @@ impl Tuple { let mut pos = bits_len; for (i, logic_type) in table_types.iter().enumerate() { - if projection_i 
>= tuple_columns.len() { + if projection_i >= tuple_schema_ref.len() { break; } if is_none(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { if projections[projection_i] == i { tuple_values.push(Arc::new(DataValue::none(logic_type))); Self::values_push( - tuple_columns, + tuple_schema_ref, &tuple_values, &mut id_option, &mut projection_i, @@ -63,7 +63,7 @@ impl Tuple { logic_type, ))); Self::values_push( - tuple_columns, + tuple_schema_ref, &tuple_values, &mut id_option, &mut projection_i, @@ -80,7 +80,7 @@ impl Tuple { logic_type, ))); Self::values_push( - tuple_columns, + tuple_schema_ref, &tuple_values, &mut id_option, &mut projection_i, @@ -92,7 +92,7 @@ impl Tuple { Tuple { id: id_option, - columns: tuple_columns.clone(), + schema_ref: tuple_schema_ref.clone(), values: tuple_values, } } @@ -145,7 +145,7 @@ pub fn create_table(tuples: &[Tuple]) -> Table { } let mut header = Vec::new(); - for col in tuples[0].columns.iter() { + for col in tuples[0].schema_ref.iter() { header.push(Cell::new(col.name().to_string())); } table.set_header(header); @@ -246,7 +246,7 @@ mod tests { let tuples = vec![ Tuple { id: Some(Arc::new(DataValue::Int32(Some(0)))), - columns: columns.clone(), + schema_ref: columns.clone(), values: vec![ Arc::new(DataValue::Int32(Some(0))), Arc::new(DataValue::UInt32(Some(1))), @@ -265,7 +265,7 @@ mod tests { }, Tuple { id: Some(Arc::new(DataValue::Int32(Some(1)))), - columns: columns.clone(), + schema_ref: columns.clone(), values: vec![ Arc::new(DataValue::Int32(Some(1))), Arc::new(DataValue::UInt32(None)), diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index ac08968a..d3032e07 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -1,18 +1,16 @@ -use crate::catalog::{ColumnCatalog, ColumnRef}; +use crate::catalog::ColumnCatalog; use crate::errors::DatabaseError; -use crate::types::tuple::Tuple; +use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::{DataValue, ValueRef}; use std::sync::Arc; -pub struct TupleBuilder { - columns: Arc>, +pub struct TupleBuilder<'a> { + schema_ref: &'a SchemaRef, } -impl TupleBuilder { - pub fn new(columns: Vec) -> Self { - TupleBuilder { - columns: Arc::new(columns), - } +impl<'a> TupleBuilder<'a> { + pub fn new(schema_ref: &'a SchemaRef) -> Self { + TupleBuilder { schema_ref } } pub fn build_result(header: String, message: String) -> Result { @@ -21,7 +19,7 @@ impl TupleBuilder { Ok(Tuple { id: None, - columns, + schema_ref: columns, values, }) } @@ -31,7 +29,7 @@ impl TupleBuilder { id: Option, values: Vec, ) -> Result { - if values.len() != self.columns.len() { + if values.len() != self.schema_ref.len() { return Err(DatabaseError::MisMatch( "types".to_string(), "values".to_string(), @@ -40,29 +38,29 @@ impl TupleBuilder { Ok(Tuple { id, - columns: self.columns.clone(), + schema_ref: self.schema_ref.clone(), values, }) } - pub fn build_with_row<'a>( + pub fn build_with_row<'b>( &self, - row: impl IntoIterator, + row: impl IntoIterator, ) -> Result { - let mut values = Vec::with_capacity(self.columns.len()); + let mut values = Vec::with_capacity(self.schema_ref.len()); let mut primary_key = None; for (i, value) in row.into_iter().enumerate() { let data_value = Arc::new( - DataValue::Utf8(Some(value.to_string())).cast(self.columns[i].datatype())?, + DataValue::Utf8(Some(value.to_string())).cast(self.schema_ref[i].datatype())?, ); - if primary_key.is_none() && self.columns[i].desc.is_primary { + if primary_key.is_none() && self.schema_ref[i].desc.is_primary { primary_key = 
Some(data_value.clone()); } values.push(data_value); } - if values.len() != self.columns.len() { + if values.len() != self.schema_ref.len() { return Err(DatabaseError::MisMatch( "types".to_string(), "values".to_string(), @@ -71,7 +69,7 @@ impl TupleBuilder { Ok(Tuple { id: primary_key, - columns: self.columns.clone(), + schema_ref: self.schema_ref.clone(), values, }) } diff --git a/tests/slt/join.slt b/tests/slt/join.slt index 7ca5738f..9b191fbb 100644 --- a/tests/slt/join.slt +++ b/tests/slt/join.slt @@ -4,6 +4,9 @@ create table x(id int primary key, a int, b int); statement ok create table y(id int primary key, c int, d int); +statement ok +create table z(id int primary key, e int, f int); + statement ok insert into x values (0, 1, 2), (1, 1, 3); @@ -14,6 +17,9 @@ select a, b, c, d from x join y on a = c; statement ok insert into y values (0, 1, 5), (1, 1, 6), (2, 2, 7); +statement ok +insert into z values (0, 1, 5), (1, 2, 6), (2, 4, 7); + query IIII select a, b, c, d from x join y on a = c; ---- @@ -22,6 +28,14 @@ select a, b, c, d from x join y on a = c; 1 2 1 6 1 3 1 6 +query IIII +select a, b, c, d, e, f from x join y on a = c and c < 5 join z on e = a and f = 5; +---- +1 2 1 5 1 5 +1 3 1 5 1 5 +1 2 1 6 1 5 +1 3 1 6 1 5 + statement ok drop table x; diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 52d23b6a..f70a7f07 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -24,7 +24,7 @@ impl AsyncDB for KipSQL { return Ok(DBOutput::StatementComplete(0)); } - let types = vec![DefaultColumnType::Any; tuples[0].columns.len()]; + let types = vec![DefaultColumnType::Any; tuples[0].schema_ref.len()]; let rows = tuples .into_iter() .map(|tuple| { From de8ca71903768e2e8cabf5ca5a8b9566148efd6e Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sun, 11 Feb 2024 14:31:39 +0800 Subject: [PATCH 2/3] fix: add alias for subquery --- src/binder/expr.rs | 19 +- src/binder/mod.rs | 11 +- src/binder/select.rs | 178 +++++++++++------- src/catalog/column.rs | 4 + .../volcano/dql/aggregate/hash_agg.rs | 2 +- src/execution/volcano/dql/join/hash_join.rs | 4 +- src/execution/volcano/mod.rs | 4 +- src/expression/evaluator.rs | 11 +- src/expression/mod.rs | 53 +++++- src/optimizer/heuristic/graph.rs | 2 +- src/optimizer/heuristic/matcher.rs | 8 +- .../rule/normalization/column_pruning.rs | 27 +-- src/planner/mod.rs | 18 +- tests/slt/sql_2016/E061_09.slt | 44 +++-- tests/slt/subquery.slt | 33 +++- 15 files changed, 274 insertions(+), 144 deletions(-) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index dfb4c722..7e73da67 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -1,3 +1,4 @@ +use crate::catalog::ColumnCatalog; use crate::errors::DatabaseError; use crate::expression; use crate::expression::agg::AggKind; @@ -9,7 +10,7 @@ use std::slice; use std::sync::Arc; use super::{lower_ident, Binder}; -use crate::expression::ScalarExpression; +use crate::expression::{AliasType, ScalarExpression}; use crate::storage::Transaction; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -19,7 +20,7 @@ macro_rules! 
try_alias { if let Some(expr) = $context.expr_aliases.get(&$column_name) { return Ok(ScalarExpression::Alias { expr: Box::new(expr.clone()), - alias: $column_name, + alias: AliasType::Name($column_name), }); } }; @@ -91,7 +92,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } Expr::Subquery(query) => { let mut sub_query = self.bind_query(query)?; - let sub_query_schema = sub_query.out_schmea(); + let sub_query_schema = sub_query.output_schema(); if sub_query_schema.len() > 1 { return Err(DatabaseError::MisMatch( @@ -99,10 +100,18 @@ impl<'a, T: Transaction> Binder<'a, T> { "the expression returned by the subquery".to_string(), )); } - let expr = ScalarExpression::ColumnRef(sub_query_schema[0].clone()); + let column = sub_query_schema[0].clone(); + let mut alias_column = ColumnCatalog::clone(&column); + alias_column.set_table_name(self.context.temp_table()); + self.context.sub_query(sub_query); - Ok(expr) + Ok(ScalarExpression::Alias { + expr: Box::new(ScalarExpression::ColumnRef(column)), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new( + alias_column, + )))), + }) } _ => { todo!() diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 88c0487f..33854dbc 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -17,6 +17,7 @@ mod update; use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement}; use std::collections::HashMap; +use std::sync::Arc; use crate::catalog::{TableCatalog, TableName}; use crate::errors::DatabaseError; @@ -57,6 +58,8 @@ pub struct BinderContext<'a, T: Transaction> { bind_step: QueryBindStep, sub_queries: HashMap>, + + temp_table_id: usize, } impl<'a, T: Transaction> BinderContext<'a, T> { @@ -70,9 +73,15 @@ impl<'a, T: Transaction> BinderContext<'a, T> { agg_calls: Default::default(), bind_step: QueryBindStep::From, sub_queries: Default::default(), + temp_table_id: 0, } } + pub fn temp_table(&mut self) -> TableName { + self.temp_table_id += 1; + Arc::new(format!("_temp_table_{}_", self.temp_table_id)) + } + pub fn step(&mut self, bind_step: QueryBindStep) { self.bind_step = bind_step; } @@ -84,7 +93,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> { .push(sub_query) } - pub fn sub_query_for_now(&mut self) -> Option> { + pub fn sub_queries_at_now(&mut self) -> Option> { self.sub_queries.remove(&self.bind_step) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 8578c6a8..9a751b02 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -16,10 +16,10 @@ use crate::{ use super::{lower_case_name, lower_ident, Binder, QueryBindStep}; -use crate::catalog::{ColumnCatalog, TableName}; +use crate::catalog::{ColumnCatalog, ColumnSummary, TableName}; use crate::errors::DatabaseError; use crate::execution::volcano::dql::join::joins_nullable; -use crate::expression::BinaryOperator; +use crate::expression::{AliasType, BinaryOperator}; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::LogicalPlan; @@ -254,7 +254,7 @@ impl<'a, T: Transaction> Binder<'a, T> { select_items.push(ScalarExpression::Alias { expr: Box::new(expr), - alias: alias_name, + alias: AliasType::Name(alias_name), }); } SelectItem::Wildcard(_) => { @@ -301,7 +301,7 @@ impl<'a, T: Transaction> Binder<'a, T> { if let Some(alias_expr) = alias_map.get(&expr) { expr = ScalarExpression::Alias { expr: Box::new(expr), - alias: alias_expr.to_string(), + alias: AliasType::Name(alias_expr.to_string()), } } exprs.push(expr); @@ -360,9 +360,8 @@ impl<'a, T: Transaction> Binder<'a, T> { 
self.context.step(QueryBindStep::Where); let predicate = self.bind_expr(predicate)?; - println!("{}", predicate); - if let Some(sub_queries) = self.context.sub_query_for_now() { + if let Some(sub_queries) = self.context.sub_queries_at_now() { for mut sub_query in sub_queries { let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; let mut filter = vec![]; @@ -371,20 +370,19 @@ impl<'a, T: Transaction> Binder<'a, T> { predicate.clone(), &mut on_keys, &mut filter, - children.out_schmea(), - sub_query.out_schmea(), + children.output_schema(), + sub_query.output_schema(), )?; // combine multiple filter exprs into one BinaryExpr - let join_filter = - filter - .into_iter() - .reduce(|acc, expr| ScalarExpression::Binary { - op: BinaryOperator::And, - left_expr: Box::new(acc), - right_expr: Box::new(expr), - ty: LogicalType::Boolean, - }); + let join_filter = filter + .into_iter() + .reduce(|acc, expr| ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(acc), + right_expr: Box::new(expr), + ty: LogicalType::Boolean, + }); children = LJoinOperator::build( children, @@ -582,41 +580,97 @@ impl<'a, T: Transaction> Binder<'a, T> { left_schema: &Schema, right_schema: &Schema, ) -> Result<(), DatabaseError> { + let fn_contains = |schema: &Schema, summary: &ColumnSummary| { + schema.iter().any(|column| summary == &column.summary) + }; + let fn_or_contains = + |left_schema: &Schema, right_schema: &Schema, summary: &ColumnSummary| { + fn_contains(left_schema, summary) || fn_contains(right_schema, summary) + }; + match expr { ScalarExpression::Binary { left_expr, right_expr, op, ty, - } => match op { - BinaryOperator::Eq => { - let fn_contains = |schema: &Schema, name: &str| { - schema.iter().any(|column| column.name() == name) - }; - - match (left_expr.as_ref(), right_expr.as_ref()) { - // example: foo = bar - (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { - // reorder left and right joins keys to pattern: (left, right) - if fn_contains(left_schema, l.name()) - && fn_contains(right_schema, r.name()) - { - accum.push((*left_expr, *right_expr)); - } else if fn_contains(left_schema, r.name()) - && fn_contains(right_schema, l.name()) - { - accum.push((*right_expr, *left_expr)); - } else { - accum_filter.push(ScalarExpression::Binary { - left_expr, - right_expr, - op, - ty, - }); + } => { + match op { + BinaryOperator::Eq => { + match (left_expr.as_ref(), right_expr.as_ref()) { + // example: foo = bar + (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { + // reorder left and right joins keys to pattern: (left, right) + if fn_contains(left_schema, l.summary()) + && fn_contains(right_schema, r.summary()) + { + accum.push((*left_expr, *right_expr)); + } else if fn_contains(left_schema, r.summary()) + && fn_contains(right_schema, l.summary()) + { + accum.push((*right_expr, *left_expr)); + } else if fn_or_contains(left_schema, right_schema, l.summary()) + || fn_or_contains(left_schema, right_schema, r.summary()) + { + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); + } + } + (ScalarExpression::ColumnRef(column), _) + | (_, ScalarExpression::ColumnRef(column)) => { + if fn_or_contains(left_schema, right_schema, column.summary()) { + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); + } + } + _other => { + // example: baz > 1 + if left_expr.referenced_columns(true).iter().all(|column| { + fn_or_contains(left_schema, right_schema, column.summary()) + }) && 
right_expr.referenced_columns(true).iter().all(|column| { + fn_or_contains(left_schema, right_schema, column.summary()) + }) { + accum_filter.push(ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + }); + } } } - // example: baz = 1 - _other => { + } + BinaryOperator::And => { + // example: foo = bar AND baz > 1 + Self::extract_join_keys( + *left_expr, + accum, + accum_filter, + left_schema, + right_schema, + )?; + Self::extract_join_keys( + *right_expr, + accum, + accum_filter, + left_schema, + right_schema, + )?; + } + _ => { + if left_expr.referenced_columns(true).iter().all(|column| { + fn_or_contains(left_schema, right_schema, column.summary()) + }) && right_expr.referenced_columns(true).iter().all(|column| { + fn_or_contains(left_schema, right_schema, column.summary()) + }) { accum_filter.push(ScalarExpression::Binary { left_expr, right_expr, @@ -626,36 +680,16 @@ impl<'a, T: Transaction> Binder<'a, T> { } } } - BinaryOperator::And => { - // example: foo = bar AND baz > 1 - Self::extract_join_keys( - *left_expr, - accum, - accum_filter, - left_schema, - right_schema, - )?; - Self::extract_join_keys( - *right_expr, - accum, - accum_filter, - left_schema, - right_schema, - )?; - } - _ => { + } + _ => { + if expr + .referenced_columns(true) + .iter() + .all(|column| fn_or_contains(left_schema, right_schema, column.summary())) + { // example: baz > 1 - accum_filter.push(ScalarExpression::Binary { - left_expr, - right_expr, - op, - ty, - }); + accum_filter.push(expr); } - }, - _ => { - // example: baz in (xxx), something else will convert to filter logic - accum_filter.push(expr); } } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index eb3e00ee..ee5a6e04 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -67,6 +67,10 @@ impl ColumnCatalog { self.summary.table_name.as_ref() } + pub fn set_table_name(&mut self, table_name: TableName) { + self.summary.table_name = Some(table_name); + } + pub fn datatype(&self) -> &LogicalType { &self.desc.column_datatype } diff --git a/src/execution/volcano/dql/aggregate/hash_agg.rs b/src/execution/volcano/dql/aggregate/hash_agg.rs index 800c83ea..15b351c4 100644 --- a/src/execution/volcano/dql/aggregate/hash_agg.rs +++ b/src/execution/volcano/dql/aggregate/hash_agg.rs @@ -227,7 +227,7 @@ mod test { }), childrens: vec![], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }; let tuples = diff --git a/src/execution/volcano/dql/join/hash_join.rs b/src/execution/volcano/dql/join/hash_join.rs index 37ec7ec2..d552f014 100644 --- a/src/execution/volcano/dql/join/hash_join.rs +++ b/src/execution/volcano/dql/join/hash_join.rs @@ -407,7 +407,7 @@ mod test { }), childrens: vec![], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }; let values_t2 = LogicalPlan { @@ -438,7 +438,7 @@ mod test { }), childrens: vec![], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }; (on_keys, values_t1, values_t2) diff --git a/src/execution/volcano/mod.rs b/src/execution/volcano/mod.rs index 5d2e3855..89c6e9ed 100644 --- a/src/execution/volcano/mod.rs +++ b/src/execution/volcano/mod.rs @@ -118,7 +118,7 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo operator, mut childrens, physical_option, - _out_columns, + _output_schema_ref: _out_schema_ref, } = plan; match operator { @@ -164,7 +164,7 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo operator, childrens, physical_option, - _out_columns, + _output_schema_ref: _out_schema_ref, }, 
transaction, ), diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index eb15a33e..8179c132 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -1,7 +1,7 @@ use crate::catalog::ColumnSummary; use crate::errors::DatabaseError; use crate::expression::value_compute::{binary_op, unary_op}; -use crate::expression::ScalarExpression; +use crate::expression::{AliasType, ScalarExpression}; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; @@ -46,7 +46,14 @@ impl ScalarExpression { if let Some(value) = tuple .schema_ref .iter() - .find_position(|tul_col| tul_col.name() == alias) + .find_position(|tul_col| match alias { + AliasType::Name(alias) => { + tul_col.table_name().is_none() && tul_col.name() == alias + } + AliasType::Expr(alias_expr) => { + alias_expr.output_column().summary == tul_col.summary + } + }) .map(|(i, _)| &tuple.values[i]) { return Ok(value.clone()); diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 6acd28e6..7f46b8f1 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,8 +1,8 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; +use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use std::fmt; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; @@ -16,6 +16,12 @@ mod evaluator; pub mod simplify; pub mod value_compute; +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] +pub enum AliasType { + Name(String), + Expr(Box), +} + /// ScalarExpression represnet all scalar expression in SQL. /// SELECT a+1, b FROM t1. /// a+1 -> ScalarExpression::Unary(a + 1) @@ -26,7 +32,7 @@ pub enum ScalarExpression { ColumnRef(ColumnRef), Alias { expr: Box, - alias: String, + alias: AliasType, }, TypeCast { expr: Box, @@ -164,7 +170,30 @@ impl ScalarExpression { columns_collect(arg, vec, only_column_ref) } } - _ => (), + ScalarExpression::Between { + expr, + left_expr, + right_expr, + .. + } => { + columns_collect(expr, vec, only_column_ref); + columns_collect(left_expr, vec, only_column_ref); + columns_collect(right_expr, vec, only_column_ref); + } + ScalarExpression::SubString { + expr, + for_expr, + from_expr, + } => { + columns_collect(expr, vec, only_column_ref); + if let Some(for_expr) = for_expr { + columns_collect(for_expr, vec, only_column_ref); + } + if let Some(from_expr) = from_expr { + columns_collect(from_expr, vec, only_column_ref); + } + } + ScalarExpression::Constant(_) | ScalarExpression::Empty => (), } } let mut exprs = Vec::new(); @@ -219,8 +248,18 @@ impl ScalarExpression { pub fn output_name(&self) -> String { match self { ScalarExpression::Constant(value) => format!("{}", value), - ScalarExpression::ColumnRef(col) => col.name().to_string(), - ScalarExpression::Alias { alias, .. 
} => alias.to_string(), + ScalarExpression::ColumnRef(col) => { + if let Some(table_name) = col.table_name() { + return format!("{}.{}", table_name, col.name()); + } + col.name().to_string() + } + ScalarExpression::Alias { alias, expr } => match alias { + AliasType::Name(alias) => alias.to_string(), + AliasType::Expr(alias_expr) => { + format!("({}) as ({})", expr, alias_expr.output_name()) + } + }, ScalarExpression::TypeCast { expr, ty } => { format!("CAST({} as {})", expr.output_name(), ty) } @@ -312,6 +351,10 @@ impl ScalarExpression { pub fn output_column(&self) -> ColumnRef { match self { ScalarExpression::ColumnRef(col) => col.clone(), + ScalarExpression::Alias { + alias: AliasType::Expr(expr), + .. + } => expr.output_column(), _ => Arc::new(ColumnCatalog::new( self.output_name(), true, diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 7a3e9ae0..627547c8 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -200,7 +200,7 @@ impl HepGraph { operator, childrens, physical_option, - _out_columns: None, + _output_schema_ref: None, }) } } diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 320b4bc9..4522bb4b 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -103,20 +103,20 @@ mod tests { operator: Operator::Dummy, childrens: vec![], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }, LogicalPlan { operator: Operator::Dummy, childrens: vec![], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }, ], physical_option: None, - _out_columns: None, + _output_schema_ref: None, }; let graph = HepGraph::new(all_dummy_plan.clone()); diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index 912c5043..4e0f6b46 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -26,10 +26,7 @@ lazy_static! { pub struct ColumnPruning; impl ColumnPruning { - fn clear_exprs( - column_references: &mut HashSet, - exprs: &mut Vec, - ) { + fn clear_exprs(column_references: HashSet<&ColumnSummary>, exprs: &mut Vec) { exprs.retain(|expr| { if column_references.contains(expr.output_column().summary()) { return true; @@ -41,7 +38,7 @@ impl ColumnPruning { } fn _apply( - column_references: &mut HashSet, + column_references: HashSet<&ColumnSummary>, all_referenced: bool, node_id: HepNodeId, graph: &mut HepGraph, @@ -85,7 +82,7 @@ impl ColumnPruning { // Todo: Order Project // https://github.com/duckdb/duckdb/blob/main/src/optimizer/remove_unused_columns.cpp#L174 } - for child_id in graph.children_at(node_id).collect_vec() { + if let Some(child_id) = graph.eldest_child_at(node_id) { Self::_apply(column_references, true, child_id, graph); } } @@ -96,11 +93,16 @@ impl ColumnPruning { } } Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => { - for column in operator.referenced_columns(false) { - column_references.insert(column.summary().clone()); + let temp_columns = operator.referenced_columns(false); + // why? 
+ let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column.summary()); } for child_id in graph.children_at(node_id).collect_vec() { - Self::_apply(column_references, all_referenced, child_id, graph); + let copy_references = column_references.clone(); + + Self::_apply(copy_references, all_referenced, child_id, graph); } } // Last Operator @@ -145,13 +147,12 @@ impl ColumnPruning { graph: &mut HepGraph, ) { for child_id in graph.children_at(node_id).collect_vec() { - let mut new_references: HashSet = referenced_columns + let new_references: HashSet<&ColumnSummary> = referenced_columns .iter() .map(|column| column.summary()) - .cloned() .collect(); - Self::_apply(&mut new_references, all_referenced, child_id, graph); + Self::_apply(new_references, all_referenced, child_id, graph); } } } @@ -164,7 +165,7 @@ impl MatchPattern for ColumnPruning { impl NormalizationRule for ColumnPruning { fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - Self::_apply(&mut HashSet::new(), true, node_id, graph); + Self::_apply(HashSet::new(), true, node_id, graph); // mark changed to skip this rule batch graph.version += 1; diff --git a/src/planner/mod.rs b/src/planner/mod.rs index dae9c8f5..e0da3152 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -8,11 +8,11 @@ use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct LogicalPlan { - pub operator: Operator, - pub childrens: Vec, - pub physical_option: Option, + pub(crate) operator: Operator, + pub(crate) childrens: Vec, + pub(crate) physical_option: Option, - pub _out_columns: Option, + pub(crate) _output_schema_ref: Option, } impl LogicalPlan { @@ -21,7 +21,7 @@ impl LogicalPlan { operator, childrens, physical_option: None, - _out_columns: None, + _output_schema_ref: None, } } @@ -44,11 +44,11 @@ impl LogicalPlan { tables } - pub fn out_schmea(&mut self) -> &SchemaRef { - self._out_columns + pub fn output_schema(&mut self) -> &SchemaRef { + self._output_schema_ref .get_or_insert_with(|| match &self.operator { Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) => { - self.childrens[0].out_schmea().clone() + self.childrens[0].output_schema().clone() } Operator::Aggregate(op) => { let out_columns = op @@ -63,7 +63,7 @@ impl LogicalPlan { let out_columns = self .childrens .iter_mut() - .flat_map(|children| Vec::clone(children.out_schmea())) + .flat_map(|children| Vec::clone(children.output_schema())) .collect_vec(); Arc::new(out_columns) } diff --git a/tests/slt/sql_2016/E061_09.slt b/tests/slt/sql_2016/E061_09.slt index 13ca05ca..c317ea3c 100644 --- a/tests/slt/sql_2016/E061_09.slt +++ b/tests/slt/sql_2016/E061_09.slt @@ -1,33 +1,37 @@ # E061-09: Subqueries in comparison predicate -# TODO: Support Subquery on `WHERE` +statement ok +CREATE TABLE TABLE_E061_09_01_01 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_01 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_09_01_01 WHERE A < ( SELECT 1 ) -# SELECT A FROM TABLE_E061_09_01_01 WHERE A < ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_09_01_02 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_02 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_09_01_02 WHERE A <= ( SELECT 1 ) -# SELECT A FROM TABLE_E061_09_01_02 WHERE A <= ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_09_01_03 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_03 ( ID INT 
PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_09_01_03 WHERE A <> ( SELECT 1 ) -# SELECT A FROM TABLE_E061_09_01_03 WHERE A <> ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_09_01_04 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_04 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_09_01_04 WHERE A = ( SELECT 1 ) -# SELECT A FROM TABLE_E061_09_01_04 WHERE A = ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_09_01_05 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_05 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_09_01_05 WHERE A > ( SELECT 1 ) -# SELECT A FROM TABLE_E061_09_01_05 WHERE A > ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_09_01_06 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_09_01_06 ( ID INT PRIMARY KEY, A INT ); - -# SELECT A FROM TABLE_E061_09_01_06 WHERE A >= ( SELECT 1 ) +query I +SELECT A FROM TABLE_E061_09_01_06 WHERE A >= ( SELECT 1 ) diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index 332a4805..215cc6bb 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -1,31 +1,50 @@ # Test subquery statement ok -create table t(id int primary key, a int not null, b int not null); +create table t1(id int primary key, a int not null, b int not null); statement ok -insert into t values (0, 1, 2), (1, 3, 4); +insert into t1 values (0, 1, 2), (1, 3, 4); query II -select a, b from (select a, b from t); +select a, b from (select a, b from t1); ---- 1 2 3 4 query II -select x.a, x.b from (select a, b from t) as x; +select x.a, x.b from (select a, b from t1) as x; ---- 1 2 3 4 query II -select * from (select a, b from t); +select * from (select a, b from t1); ---- 1 2 3 4 query I -select s from (select a + b as s from t); +select s from (select a + b as s from t1); ---- 3 -7 \ No newline at end of file +7 + +query II rowsort +select x.a from (select -a as a from t1) as x; +---- +-1 +-3 + +query III +select * from t1 where a <= (select 4) and a > (select 1) +---- +1 3 4 + +query III +select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1 +---- +1 3 4 + +statement ok +drop table t1; \ No newline at end of file From 38f36769eebab98c0dd29696eadf1308c3d43d90 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sun, 11 Feb 2024 14:47:14 +0800 Subject: [PATCH 3/3] fix: parameter checking of subquery in Where --- src/binder/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 7e73da67..adf62e41 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -94,7 +94,7 @@ impl<'a, T: Transaction> Binder<'a, T> { let mut sub_query = self.bind_query(query)?; let sub_query_schema = sub_query.output_schema(); - if sub_query_schema.len() > 1 { + if sub_query_schema.len() != 1 { return Err(DatabaseError::MisMatch( "expects only one expression to be returned".to_string(), "the expression returned by the subquery".to_string(),
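[Editorial aside: patch 3's `> 1` → `!= 1` change makes the scalar-subquery arity check exact. Binding `WHERE a < ( SELECT ... )` requires the subquery to expose exactly one output column, and `!= 1` additionally rejects an empty schema, which `> 1` let through. A stand-alone sketch of the guard; the real code returns `DatabaseError::MisMatch`, modeled here with a plain `String` error.]

```rust
// Sketch of the exact-arity guard for scalar subqueries; the concrete
// error type is simplified relative to the crate's `DatabaseError`.
fn check_scalar_subquery_arity(schema_len: usize) -> Result<(), String> {
    // `!= 1` rejects both multi-column and zero-column subquery schemas;
    // the previous `> 1` comparison missed the zero-column case.
    if schema_len != 1 {
        return Err(format!(
            "subquery must return exactly one expression, got {} columns",
            schema_len
        ));
    }
    Ok(())
}

fn main() {
    assert!(check_scalar_subquery_arity(1).is_ok());
    assert!(check_scalar_subquery_arity(2).is_err()); // caught before and after
    assert!(check_scalar_subquery_arity(0).is_err()); // only caught by `!= 1`
}
```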