diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 6c51fd44..c8c1764d 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -301,7 +301,7 @@ impl<'a, T: Transaction> Binder<'a, T> { return Ok(()); } if matches!(expr, ScalarExpression::Alias { .. }) { - return self.validate_having_orderby(expr.unpack_alias()); + return self.validate_having_orderby(expr.unpack_alias_ref()); } Err(DatabaseError::AggMiss( diff --git a/src/binder/expr.rs b/src/binder/expr.rs index d50354f7..94c7ed23 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -1,17 +1,19 @@ -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, ColumnRef}; use crate::errors::DatabaseError; use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator, + BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Query, + UnaryOperator, }; use std::slice; use std::sync::Arc; -use super::{lower_ident, Binder, QueryBindStep}; +use super::{lower_ident, Binder, QueryBindStep, SubQueryType}; use crate::expression::function::{FunctionSummary, ScalarFunction}; use crate::expression::{AliasType, ScalarExpression}; +use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -99,33 +101,40 @@ impl<'a, T: Transaction> Binder<'a, T> { from_expr, }) } - Expr::Subquery(query) => { - let mut sub_query = self.bind_query(query)?; - let sub_query_schema = sub_query.output_schema(); - - if sub_query_schema.len() != 1 { - return Err(DatabaseError::MisMatch( - "expects only one expression to be returned", - "the expression returned by the subquery", - )); - } - let column = sub_query_schema[0].clone(); - self.context.sub_query(sub_query); + Expr::Subquery(subquery) => { + let (sub_query, column) = self.bind_subquery(subquery)?; + self.context.sub_query(SubQueryType::SubQuery(sub_query)); if self.context.is_step(&QueryBindStep::Where) { - let mut alias_column = ColumnCatalog::clone(&column); - alias_column.set_table_name(self.context.temp_table()); - - Ok(ScalarExpression::Alias { - expr: Box::new(ScalarExpression::ColumnRef(column)), - alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new( - alias_column, - )))), - }) + Ok(self.bind_temp_column(column)) } else { Ok(ScalarExpression::ColumnRef(column)) } } + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let (sub_query, column) = self.bind_subquery(subquery)?; + self.context + .sub_query(SubQueryType::InSubQuery(*negated, sub_query)); + + if !self.context.is_step(&QueryBindStep::Where) { + return Err(DatabaseError::UnsupportedStmt( + "`in subquery` can only appear in `Where`".to_string(), + )); + } + + let alias_expr = self.bind_temp_column(column); + + Ok(ScalarExpression::Binary { + op: expression::BinaryOperator::Eq, + left_expr: Box::new(self.bind_expr(expr)?), + right_expr: Box::new(alias_expr), + ty: LogicalType::Boolean, + }) + } Expr::Tuple(exprs) => { let mut bond_exprs = Vec::with_capacity(exprs.len()); @@ -187,6 +196,35 @@ impl<'a, T: Transaction> Binder<'a, T> { } } + fn bind_temp_column(&mut self, column: ColumnRef) -> ScalarExpression { + let mut alias_column = ColumnCatalog::clone(&column); + alias_column.set_table_name(self.context.temp_table()); + + ScalarExpression::Alias { + expr: Box::new(ScalarExpression::ColumnRef(column)), + alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new( + alias_column, + )))), + } + } + + fn bind_subquery( + &mut self, + subquery: &Query, + ) -> Result<(LogicalPlan, Arc), DatabaseError> { + let mut sub_query = self.bind_query(subquery)?; + let sub_query_schema = sub_query.output_schema(); + + if sub_query_schema.len() != 1 { + return Err(DatabaseError::MisMatch( + "expects only one expression to be returned", + "the expression returned by the subquery", + )); + } + let column = sub_query_schema[0].clone(); + Ok((sub_query, column)) + } + pub fn bind_like( &mut self, negated: bool, diff --git a/src/binder/mod.rs b/src/binder/mod.rs index ecf30b51..1b777859 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -46,6 +46,12 @@ pub enum QueryBindStep { Limit, } +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum SubQueryType { + SubQuery(LogicalPlan), + InSubQuery(bool, LogicalPlan), +} + #[derive(Clone)] pub struct BinderContext<'a, T: Transaction> { functions: &'a Functions, @@ -60,7 +66,7 @@ pub struct BinderContext<'a, T: Transaction> { pub(crate) agg_calls: Vec, bind_step: QueryBindStep, - sub_queries: HashMap>, + sub_queries: HashMap>, temp_table_id: usize, pub(crate) allow_default: bool, @@ -96,14 +102,18 @@ impl<'a, T: Transaction> BinderContext<'a, T> { &self.bind_step == bind_step } - pub fn sub_query(&mut self, sub_query: LogicalPlan) { + pub fn step_now(&self) -> QueryBindStep { + self.bind_step + } + + pub fn sub_query(&mut self, sub_query: SubQueryType) { self.sub_queries .entry(self.bind_step) .or_default() .push(sub_query) } - pub fn sub_queries_at_now(&mut self) -> Option> { + pub fn sub_queries_at_now(&mut self) -> Option> { self.sub_queries.remove(&self.bind_step) } diff --git a/src/binder/select.rs b/src/binder/select.rs index cdbf08ac..0fa37f48 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -14,7 +14,7 @@ use crate::{ types::value::DataValue, }; -use super::{lower_case_name, lower_ident, Binder, QueryBindStep}; +use super::{lower_case_name, lower_ident, Binder, QueryBindStep, SubQueryType}; use crate::catalog::{ColumnCatalog, ColumnSummary, TableName}; use crate::errors::DatabaseError; @@ -37,6 +37,8 @@ use sqlparser::ast::{ impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_query(&mut self, query: &Query) -> Result { + let origin_step = self.context.step_now(); + if let Some(_with) = &query.with { // TODO support with clause. } @@ -60,6 +62,7 @@ impl<'a, T: Transaction> Binder<'a, T> { plan = self.bind_limit(plan, limit, offset)?; } + self.context.step(origin_step); Ok(plan) } @@ -463,16 +466,28 @@ impl<'a, T: Transaction> Binder<'a, T> { let predicate = self.bind_expr(predicate)?; if let Some(sub_queries) = self.context.sub_queries_at_now() { - for mut sub_query in sub_queries { + for sub_query in sub_queries { let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; let mut filter = vec![]; + let (mut plan, join_ty) = match sub_query { + SubQueryType::SubQuery(plan) => (plan, JoinType::Inner), + SubQueryType::InSubQuery(is_not, plan) => { + let join_ty = if is_not { + JoinType::LeftAnti + } else { + JoinType::LeftSemi + }; + (plan, join_ty) + } + }; + Self::extract_join_keys( predicate.clone(), &mut on_keys, &mut filter, children.output_schema(), - sub_query.output_schema(), + plan.output_schema(), )?; // combine multiple filter exprs into one BinaryExpr @@ -487,12 +502,12 @@ impl<'a, T: Transaction> Binder<'a, T> { children = LJoinOperator::build( children, - sub_query, + plan, JoinCondition::On { on: on_keys, filter: join_filter, }, - JoinType::Inner, + join_ty, ); } return Ok(children); @@ -731,7 +746,7 @@ impl<'a, T: Transaction> Binder<'a, T> { fn_contains(left_schema, summary) || fn_contains(right_schema, summary) }; - match expr { + match expr.unpack_alias() { ScalarExpression::Binary { left_expr, right_expr, @@ -740,7 +755,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } => { match op { BinaryOperator::Eq => { - match (left_expr.as_ref(), right_expr.as_ref()) { + match (left_expr.unpack_alias_ref(), right_expr.unpack_alias_ref()) { // example: foo = bar (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { // reorder left and right joins keys to pattern: (left, right) @@ -824,7 +839,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } } } - _ => { + expr => { if expr .referenced_columns(true) .iter() diff --git a/src/execution/volcano/dql/join/hash_join.rs b/src/execution/volcano/dql/join/hash_join.rs index 358d428b..96e28272 100644 --- a/src/execution/volcano/dql/join/hash_join.rs +++ b/src/execution/volcano/dql/join/hash_join.rs @@ -123,6 +123,7 @@ impl HashJoinStatus { } #[try_stream(boxed, ok = Tuple, error = DatabaseError)] + #[allow(unused_assignments)] pub(crate) async fn right_probe(&mut self, tuple: Tuple) { let HashJoinStatus { on_right_keys, @@ -138,18 +139,19 @@ impl HashJoinStatus { let values = Self::eval_keys(on_right_keys, &tuple, &full_schema_ref[*left_schema_len..])?; if let Some((tuples, is_used, is_filtered)) = build_map.get_mut(&values) { - if *ty == JoinType::LeftAnti { - *is_used = true; - return Ok(()); - } let mut bits_option = None; + *is_used = true; - if *ty != JoinType::LeftSemi { - *is_used = true; - } else if *is_filtered { - return Ok(()); - } else { - bits_option = Some(BitVector::new(tuples.len())); + match ty { + JoinType::LeftSemi => { + if *is_filtered { + return Ok(()); + } else { + bits_option = Some(BitVector::new(tuples.len())); + } + } + JoinType::LeftAnti => return Ok(()), + _ => (), } for (i, Tuple { values, .. }) in tuples.iter().enumerate() { let full_values = values @@ -279,7 +281,12 @@ impl HashJoinStatus { join_ty: &'a JoinType, left_schema_len: usize, ) { - for (_, (left_tuples, is_used, is_filtered)) in build_map.drain() { + let is_left_semi = matches!(join_ty, JoinType::LeftSemi); + + for (_, (left_tuples, mut is_used, is_filtered)) in build_map.drain() { + if is_left_semi { + is_used = !is_used; + } if is_used { continue; } @@ -541,7 +548,7 @@ mod test { executor.ty = JoinType::LeftSemi; let mut tuples = try_collect(&mut executor.execute(&transaction)).await?; - assert_eq!(tuples.len(), 3); + assert_eq!(tuples.len(), 2); tuples.sort_by_key(|tuple| { let mut bytes = Vec::new(); tuple.values[0].memcomparable_encode(&mut bytes).unwrap(); @@ -556,10 +563,6 @@ mod test { tuples[1].values, build_integers(vec![Some(1), Some(3), Some(5)]) ); - assert_eq!( - tuples[2].values, - build_integers(vec![Some(3), Some(5), Some(7)]) - ); } // Anti { diff --git a/src/expression/mod.rs b/src/expression/mod.rs index a1823dc6..bcad26aa 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -115,7 +115,7 @@ pub enum ScalarExpression { } impl ScalarExpression { - pub fn unpack_alias(&self) -> &ScalarExpression { + pub fn unpack_alias(self) -> ScalarExpression { if let ScalarExpression::Alias { expr, .. } = self { expr.unpack_alias() } else { @@ -123,8 +123,16 @@ impl ScalarExpression { } } + pub fn unpack_alias_ref(&self) -> &ScalarExpression { + if let ScalarExpression::Alias { expr, .. } = self { + expr.unpack_alias_ref() + } else { + self + } + } + pub fn try_reference(&mut self, output_exprs: &[ScalarExpression]) { - let fn_output_column = |expr: &ScalarExpression| expr.unpack_alias().output_column(); + let fn_output_column = |expr: &ScalarExpression| expr.unpack_alias_ref().output_column(); let self_column = fn_output_column(self); if let Some((pos, _)) = output_exprs .iter() diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs index fde9c77c..9b48cee3 100644 --- a/src/marcos/mod.rs +++ b/src/marcos/mod.rs @@ -84,7 +84,7 @@ macro_rules! function { Arc::new(Self { summary: FunctionSummary { - name: function_name.to_string(), + name: function_name, arg_types } }) diff --git a/tests/slt/sql_2016/E061_11.slt b/tests/slt/sql_2016/E061_11.slt index 9d204c44..19e9b691 100644 --- a/tests/slt/sql_2016/E061_11.slt +++ b/tests/slt/sql_2016/E061_11.slt @@ -1,8 +1,7 @@ # E061-11: Subqueries in IN predicate -# TODO: Support Subquery on `WHERE` +statement ok +CREATE TABLE TABLE_E061_11_01_01 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_11_01_01 ( ID INT PRIMARY KEY, A INT ); - -# SELECT A FROM TABLE_E061_11_01_01 WHERE A IN ( SELECT 1 ) +query I +SELECT A FROM TABLE_E061_11_01_01 WHERE A IN ( SELECT 1 ); diff --git a/tests/slt/sql_2016/E061_13.slt b/tests/slt/sql_2016/E061_13.slt index 4b9534f7..5ffc02df 100644 --- a/tests/slt/sql_2016/E061_13.slt +++ b/tests/slt/sql_2016/E061_13.slt @@ -1,19 +1,19 @@ # E061-13: Correlated subqueries -# TODO: Support Subquery on `WHERE` with `IN/Not IN` +statement ok +CREATE TABLE TABLE_E061_13_01_011 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_13_01_011 ( ID INT PRIMARY KEY, A INT ); +statement ok +CREATE TABLE TABLE_E061_13_01_012 ( ID INT PRIMARY KEY, B INT ); -# statement ok -# CREATE TABLE TABLE_E061_13_01_012 ( ID INT PRIMARY KEY, B INT ); +query I +SELECT A FROM TABLE_E061_13_01_011 WHERE A IN ( SELECT B FROM TABLE_E061_13_01_012 WHERE B = A ); -# SELECT A FROM TABLE_E061_13_01_011 WHERE A IN ( SELECT B FROM TABLE_E061_13_01_012 WHERE B = A ) +statement ok +CREATE TABLE TABLE_E061_13_01_021 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_13_01_021 ( ID INT PRIMARY KEY, A INT ); +statement ok +CREATE TABLE TABLE_E061_13_01_022 ( ID INT PRIMARY KEY, B INT ); -# statement ok -# CREATE TABLE TABLE_E061_13_01_022 ( ID INT PRIMARY KEY, B INT ); - -# SELECT A FROM TABLE_E061_13_01_021 WHERE A NOT IN ( SELECT B FROM TABLE_E061_13_01_022 WHERE B = A ) +query I +SELECT A FROM TABLE_E061_13_01_021 WHERE A NOT IN ( SELECT B FROM TABLE_E061_13_01_022 WHERE B = A ); diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index 5df37dbc..a370bd54 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -44,5 +44,15 @@ select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1 ---- 1 3 4 +query III +select * from t1 where a in (select 1) +---- +0 1 2 + +query III +select * from t1 where a not in (select 1) +---- +1 3 4 + statement ok drop table t1; \ No newline at end of file