Skip to content

Commit

Permalink
feat: support Subquery on WHERE with IN/Not IN
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Feb 28, 2024
1 parent a3c7301 commit 0f77a71
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
return Ok(());
}
if matches!(expr, ScalarExpression::Alias { .. }) {
return self.validate_having_orderby(expr.unpack_alias());
return self.validate_having_orderby(expr.unpack_alias_ref());
}

Err(DatabaseError::AggMiss(
Expand Down
86 changes: 62 additions & 24 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use crate::catalog::ColumnCatalog;
use crate::catalog::{ColumnCatalog, ColumnRef};
use crate::errors::DatabaseError;
use crate::expression;
use crate::expression::agg::AggKind;
use itertools::Itertools;
use sqlparser::ast::{
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator,
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Query,
UnaryOperator,
};
use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, QueryBindStep};
use super::{lower_ident, Binder, QueryBindStep, SubQueryType};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::types::value::DataValue;
use crate::types::LogicalType;
Expand Down Expand Up @@ -99,33 +101,40 @@ impl<'a, T: Transaction> Binder<'a, T> {
from_expr,
})
}
Expr::Subquery(query) => {
let mut sub_query = self.bind_query(query)?;
let sub_query_schema = sub_query.output_schema();

if sub_query_schema.len() != 1 {
return Err(DatabaseError::MisMatch(
"expects only one expression to be returned",
"the expression returned by the subquery",
));
}
let column = sub_query_schema[0].clone();
self.context.sub_query(sub_query);
Expr::Subquery(subquery) => {
let (sub_query, column) = self.bind_subquery(subquery)?;
self.context.sub_query(SubQueryType::SubQuery(sub_query));

if self.context.is_step(&QueryBindStep::Where) {
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
Ok(self.bind_temp_column(column))
} else {
Ok(ScalarExpression::ColumnRef(column))
}
}
Expr::InSubquery {
expr,
subquery,
negated,
} => {
let (sub_query, column) = self.bind_subquery(subquery)?;
self.context
.sub_query(SubQueryType::InSubQuery(*negated, sub_query));

if !self.context.is_step(&QueryBindStep::Where) {
return Err(DatabaseError::UnsupportedStmt(
"`in subquery` can only appear in `Where`".to_string(),
));
}

let alias_expr = self.bind_temp_column(column);

Ok(ScalarExpression::Binary {
op: expression::BinaryOperator::Eq,
left_expr: Box::new(self.bind_expr(expr)?),
right_expr: Box::new(alias_expr),
ty: LogicalType::Boolean,
})
}
Expr::Tuple(exprs) => {
let mut bond_exprs = Vec::with_capacity(exprs.len());

Expand Down Expand Up @@ -187,6 +196,35 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
}

/// Wraps `column` in an alias that points at a copy of the same column
/// re-homed onto a temporary table.
///
/// The inner expression is a reference to the original `column`; the alias is
/// a reference to a cloned catalog entry whose table name has been replaced by
/// the value returned from `self.context.temp_table()`.
fn bind_temp_column(&mut self, column: ColumnRef) -> ScalarExpression {
    // Clone the catalog entry itself (not the shared pointer) so the original
    // column stays untouched when the table name is rewritten.
    let mut renamed = (*column).clone();
    let temp_table = self.context.temp_table();
    renamed.set_table_name(temp_table);

    let original = ScalarExpression::ColumnRef(column);
    let renamed_ref = ScalarExpression::ColumnRef(Arc::new(renamed));

    ScalarExpression::Alias {
        expr: Box::new(original),
        alias: AliasType::Expr(Box::new(renamed_ref)),
    }
}

/// Binds `subquery` and returns its plan together with its only output column.
///
/// # Errors
///
/// Propagates any error from `bind_query`, and returns
/// `DatabaseError::MisMatch` when the bound subquery does not produce exactly
/// one output column.
fn bind_subquery(
    &mut self,
    subquery: &Query,
) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
    let mut plan = self.bind_query(subquery)?;
    let schema = plan.output_schema();

    if schema.len() == 1 {
        // Take the column first so the schema borrow ends before `plan` is
        // moved into the result.
        let only_column = schema[0].clone();
        return Ok((plan, only_column));
    }

    Err(DatabaseError::MisMatch(
        "expects only one expression to be returned",
        "the expression returned by the subquery",
    ))
}

pub fn bind_like(
&mut self,
negated: bool,
Expand Down
16 changes: 13 additions & 3 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ pub enum QueryBindStep {
Limit,
}

/// A subquery plan recorded in the binder context while an expression is being
/// bound, tagged with the syntactic form it came from.
///
/// When the `WHERE` clause is bound, the tag selects the join type used to
/// attach the plan: `SubQuery` becomes an inner join, `InSubQuery` becomes a
/// left-semi join, or a left-anti join when its flag is `true`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum SubQueryType {
/// Plan of a plain subquery expression (`Expr::Subquery`).
SubQuery(LogicalPlan),
/// Plan of an `IN (subquery)` predicate; the `bool` is the parser's
/// `negated` flag, i.e. `true` for `NOT IN`.
InSubQuery(bool, LogicalPlan),
}

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
functions: &'a Functions,
Expand All @@ -60,7 +66,7 @@ pub struct BinderContext<'a, T: Transaction> {
pub(crate) agg_calls: Vec<ScalarExpression>,

bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<LogicalPlan>>,
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,

temp_table_id: usize,
pub(crate) allow_default: bool,
Expand Down Expand Up @@ -96,14 +102,18 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
&self.bind_step == bind_step
}

pub fn sub_query(&mut self, sub_query: LogicalPlan) {
/// Returns the bind step this context is currently in, by value.
pub fn step_now(&self) -> QueryBindStep {
self.bind_step
}

/// Records `sub_query` under the bind step the context is currently in,
/// creating that step's list on first use.
pub fn sub_query(&mut self, sub_query: SubQueryType) {
    let current_step = self.bind_step;
    let recorded = self.sub_queries.entry(current_step).or_default();

    recorded.push(sub_query);
}

pub fn sub_queries_at_now(&mut self) -> Option<Vec<LogicalPlan>> {
/// Takes the subqueries recorded for the current bind step out of the context.
///
/// The entry is removed from the map, so a second call for the same step
/// returns `None` unless more subqueries were recorded in between. Returns
/// `None` when nothing was recorded for the current step.
pub fn sub_queries_at_now(&mut self) -> Option<Vec<SubQueryType>> {
self.sub_queries.remove(&self.bind_step)
}

Expand Down
31 changes: 23 additions & 8 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
types::value::DataValue,
};

use super::{lower_case_name, lower_ident, Binder, QueryBindStep};
use super::{lower_case_name, lower_ident, Binder, QueryBindStep, SubQueryType};

use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::errors::DatabaseError;
Expand All @@ -37,6 +37,8 @@ use sqlparser::ast::{

impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_query(&mut self, query: &Query) -> Result<LogicalPlan, DatabaseError> {
let origin_step = self.context.step_now();

if let Some(_with) = &query.with {
// TODO support with clause.
}
Expand All @@ -60,6 +62,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
plan = self.bind_limit(plan, limit, offset)?;
}

self.context.step(origin_step);
Ok(plan)
}

Expand Down Expand Up @@ -463,16 +466,28 @@ impl<'a, T: Transaction> Binder<'a, T> {
let predicate = self.bind_expr(predicate)?;

if let Some(sub_queries) = self.context.sub_queries_at_now() {
for mut sub_query in sub_queries {
for sub_query in sub_queries {
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![];
let mut filter = vec![];

let (mut plan, join_ty) = match sub_query {
SubQueryType::SubQuery(plan) => (plan, JoinType::Inner),
SubQueryType::InSubQuery(is_not, plan) => {
let join_ty = if is_not {
JoinType::LeftAnti
} else {
JoinType::LeftSemi
};
(plan, join_ty)
}
};

Self::extract_join_keys(
predicate.clone(),
&mut on_keys,
&mut filter,
children.output_schema(),
sub_query.output_schema(),
plan.output_schema(),
)?;

// combine multiple filter exprs into one BinaryExpr
Expand All @@ -487,12 +502,12 @@ impl<'a, T: Transaction> Binder<'a, T> {

children = LJoinOperator::build(
children,
sub_query,
plan,
JoinCondition::On {
on: on_keys,
filter: join_filter,
},
JoinType::Inner,
join_ty,
);
}
return Ok(children);
Expand Down Expand Up @@ -731,7 +746,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
fn_contains(left_schema, summary) || fn_contains(right_schema, summary)
};

match expr {
match expr.unpack_alias() {
ScalarExpression::Binary {
left_expr,
right_expr,
Expand All @@ -740,7 +755,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
} => {
match op {
BinaryOperator::Eq => {
match (left_expr.as_ref(), right_expr.as_ref()) {
match (left_expr.unpack_alias_ref(), right_expr.unpack_alias_ref()) {
// example: foo = bar
(ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => {
// reorder left and right joins keys to pattern: (left, right)
Expand Down Expand Up @@ -824,7 +839,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
}
}
_ => {
expr => {
if expr
.referenced_columns(true)
.iter()
Expand Down
35 changes: 19 additions & 16 deletions src/execution/volcano/dql/join/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ impl HashJoinStatus {
}

#[try_stream(boxed, ok = Tuple, error = DatabaseError)]
#[allow(unused_assignments)]
pub(crate) async fn right_probe(&mut self, tuple: Tuple) {
let HashJoinStatus {
on_right_keys,
Expand All @@ -138,18 +139,19 @@ impl HashJoinStatus {
let values = Self::eval_keys(on_right_keys, &tuple, &full_schema_ref[*left_schema_len..])?;

if let Some((tuples, is_used, is_filtered)) = build_map.get_mut(&values) {
if *ty == JoinType::LeftAnti {
*is_used = true;
return Ok(());
}
let mut bits_option = None;
*is_used = true;

if *ty != JoinType::LeftSemi {
*is_used = true;
} else if *is_filtered {
return Ok(());
} else {
bits_option = Some(BitVector::new(tuples.len()));
match ty {
JoinType::LeftSemi => {
if *is_filtered {
return Ok(());
} else {
bits_option = Some(BitVector::new(tuples.len()));
}
}
JoinType::LeftAnti => return Ok(()),
_ => (),
}
for (i, Tuple { values, .. }) in tuples.iter().enumerate() {
let full_values = values
Expand Down Expand Up @@ -279,7 +281,12 @@ impl HashJoinStatus {
join_ty: &'a JoinType,
left_schema_len: usize,
) {
for (_, (left_tuples, is_used, is_filtered)) in build_map.drain() {
let is_left_semi = matches!(join_ty, JoinType::LeftSemi);

for (_, (left_tuples, mut is_used, is_filtered)) in build_map.drain() {
if is_left_semi {
is_used = !is_used;
}
if is_used {
continue;
}
Expand Down Expand Up @@ -541,7 +548,7 @@ mod test {
executor.ty = JoinType::LeftSemi;
let mut tuples = try_collect(&mut executor.execute(&transaction)).await?;

assert_eq!(tuples.len(), 3);
assert_eq!(tuples.len(), 2);
tuples.sort_by_key(|tuple| {
let mut bytes = Vec::new();
tuple.values[0].memcomparable_encode(&mut bytes).unwrap();
Expand All @@ -556,10 +563,6 @@ mod test {
tuples[1].values,
build_integers(vec![Some(1), Some(3), Some(5)])
);
assert_eq!(
tuples[2].values,
build_integers(vec![Some(3), Some(5), Some(7)])
);
}
// Anti
{
Expand Down
12 changes: 10 additions & 2 deletions src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,24 @@ pub enum ScalarExpression {
}

impl ScalarExpression {
pub fn unpack_alias(&self) -> &ScalarExpression {
/// Consumes the expression and strips every enclosing `Alias` layer,
/// returning the innermost non-alias expression by value.
///
/// An expression that is not an alias is returned unchanged.
pub fn unpack_alias(self) -> ScalarExpression {
    match self {
        // Move the aliased expression out of its box and keep unwrapping.
        ScalarExpression::Alias { expr, .. } => (*expr).unpack_alias(),
        not_alias => not_alias,
    }
}

/// Borrowing counterpart of `unpack_alias`: walks through every enclosing
/// `Alias` layer and returns a reference to the innermost non-alias
/// expression.
///
/// An expression that is not an alias yields a reference to itself.
pub fn unpack_alias_ref(&self) -> &ScalarExpression {
    let mut current = self;

    // Step inward while the current node is still an alias wrapper.
    while let ScalarExpression::Alias { expr, .. } = current {
        current = &**expr;
    }

    current
}

pub fn try_reference(&mut self, output_exprs: &[ScalarExpression]) {
let fn_output_column = |expr: &ScalarExpression| expr.unpack_alias().output_column();
let fn_output_column = |expr: &ScalarExpression| expr.unpack_alias_ref().output_column();
let self_column = fn_output_column(self);
if let Some((pos, _)) = output_exprs
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/marcos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ macro_rules! function {

Arc::new(Self {
summary: FunctionSummary {
name: function_name.to_string(),
name: function_name,
arg_types
}
})
Expand Down
9 changes: 4 additions & 5 deletions tests/slt/sql_2016/E061_11.slt
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# E061-11: Subqueries in IN predicate

# TODO: Support Subquery on `WHERE`
statement ok
CREATE TABLE TABLE_E061_11_01_01 ( ID INT PRIMARY KEY, A INT );

# statement ok
# CREATE TABLE TABLE_E061_11_01_01 ( ID INT PRIMARY KEY, A INT );

# SELECT A FROM TABLE_E061_11_01_01 WHERE A IN ( SELECT 1 )
query I
SELECT A FROM TABLE_E061_11_01_01 WHERE A IN ( SELECT 1 );
Loading

0 comments on commit 0f77a71

Please sign in to comment.