Skip to content

Commit

Permalink
fix: add alias for subquery
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Feb 11, 2024
1 parent d96b2c3 commit de8ca71
Show file tree
Hide file tree
Showing 15 changed files with 274 additions and 144 deletions.
19 changes: 14 additions & 5 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::catalog::ColumnCatalog;
use crate::errors::DatabaseError;
use crate::expression;
use crate::expression::agg::AggKind;
Expand All @@ -9,7 +10,7 @@ use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder};
use crate::expression::ScalarExpression;
use crate::expression::{AliasType, ScalarExpression};
use crate::storage::Transaction;
use crate::types::value::DataValue;
use crate::types::LogicalType;
Expand All @@ -19,7 +20,7 @@ macro_rules! try_alias {
if let Some(expr) = $context.expr_aliases.get(&$column_name) {
return Ok(ScalarExpression::Alias {
expr: Box::new(expr.clone()),
alias: $column_name,
alias: AliasType::Name($column_name),
});
}
};
Expand Down Expand Up @@ -91,18 +92,26 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
Expr::Subquery(query) => {
let mut sub_query = self.bind_query(query)?;
let sub_query_schema = sub_query.out_schmea();
let sub_query_schema = sub_query.output_schema();

if sub_query_schema.len() > 1 {
return Err(DatabaseError::MisMatch(
"expects only one expression to be returned".to_string(),
"the expression returned by the subquery".to_string(),
));
}
let expr = ScalarExpression::ColumnRef(sub_query_schema[0].clone());
let column = sub_query_schema[0].clone();
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

self.context.sub_query(sub_query);

Ok(expr)
Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
}
_ => {
todo!()
Expand Down
11 changes: 10 additions & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod update;

use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
use std::collections::HashMap;
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
use crate::errors::DatabaseError;
Expand Down Expand Up @@ -57,6 +58,8 @@ pub struct BinderContext<'a, T: Transaction> {

bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<LogicalPlan>>,

temp_table_id: usize,
}

impl<'a, T: Transaction> BinderContext<'a, T> {
Expand All @@ -70,9 +73,15 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
agg_calls: Default::default(),
bind_step: QueryBindStep::From,
sub_queries: Default::default(),
temp_table_id: 0,
}
}

/// Hands out the next placeholder table name for this binder context.
///
/// The internal `temp_table_id` counter is advanced first, and the
/// post-increment value is then embedded in a `_temp_table_<id>_` string,
/// so each call yields a name that differs from the previous one.
///
/// Returns the freshly built name wrapped in an `Arc`.
pub fn temp_table(&mut self) -> TableName {
    // Bump the counter before formatting so the name reflects the new id.
    let next_id = self.temp_table_id + 1;
    self.temp_table_id = next_id;

    let name = format!("_temp_table_{}_", next_id);
    Arc::new(name)
}

/// Records `bind_step` as the binding step this context is currently in.
///
/// Overwrites the previously stored step; nothing else is touched here.
pub fn step(&mut self, bind_step: QueryBindStep) {
self.bind_step = bind_step;
}
Expand All @@ -84,7 +93,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
.push(sub_query)
}

pub fn sub_query_for_now(&mut self) -> Option<Vec<LogicalPlan>> {
pub fn sub_queries_at_now(&mut self) -> Option<Vec<LogicalPlan>> {
self.sub_queries.remove(&self.bind_step)
}

Expand Down
178 changes: 106 additions & 72 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use crate::{

use super::{lower_case_name, lower_ident, Binder, QueryBindStep};

use crate::catalog::{ColumnCatalog, TableName};
use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::errors::DatabaseError;
use crate::execution::volcano::dql::join::joins_nullable;
use crate::expression::BinaryOperator;
use crate::expression::{AliasType, BinaryOperator};
use crate::planner::operator::join::JoinCondition;
use crate::planner::operator::sort::{SortField, SortOperator};
use crate::planner::LogicalPlan;
Expand Down Expand Up @@ -254,7 +254,7 @@ impl<'a, T: Transaction> Binder<'a, T> {

select_items.push(ScalarExpression::Alias {
expr: Box::new(expr),
alias: alias_name,
alias: AliasType::Name(alias_name),
});
}
SelectItem::Wildcard(_) => {
Expand Down Expand Up @@ -301,7 +301,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
if let Some(alias_expr) = alias_map.get(&expr) {
expr = ScalarExpression::Alias {
expr: Box::new(expr),
alias: alias_expr.to_string(),
alias: AliasType::Name(alias_expr.to_string()),
}
}
exprs.push(expr);
Expand Down Expand Up @@ -360,9 +360,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
self.context.step(QueryBindStep::Where);

let predicate = self.bind_expr(predicate)?;
println!("{}", predicate);

if let Some(sub_queries) = self.context.sub_query_for_now() {
if let Some(sub_queries) = self.context.sub_queries_at_now() {
for mut sub_query in sub_queries {
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![];
let mut filter = vec![];
Expand All @@ -371,20 +370,19 @@ impl<'a, T: Transaction> Binder<'a, T> {
predicate.clone(),
&mut on_keys,
&mut filter,
children.out_schmea(),
sub_query.out_schmea(),
children.output_schema(),
sub_query.output_schema(),
)?;

// combine multiple filter exprs into one BinaryExpr
let join_filter =
filter
.into_iter()
.reduce(|acc, expr| ScalarExpression::Binary {
op: BinaryOperator::And,
left_expr: Box::new(acc),
right_expr: Box::new(expr),
ty: LogicalType::Boolean,
});
let join_filter = filter
.into_iter()
.reduce(|acc, expr| ScalarExpression::Binary {
op: BinaryOperator::And,
left_expr: Box::new(acc),
right_expr: Box::new(expr),
ty: LogicalType::Boolean,
});

children = LJoinOperator::build(
children,
Expand Down Expand Up @@ -582,41 +580,97 @@ impl<'a, T: Transaction> Binder<'a, T> {
left_schema: &Schema,
right_schema: &Schema,
) -> Result<(), DatabaseError> {
let fn_contains = |schema: &Schema, summary: &ColumnSummary| {
schema.iter().any(|column| summary == &column.summary)
};
let fn_or_contains =
|left_schema: &Schema, right_schema: &Schema, summary: &ColumnSummary| {
fn_contains(left_schema, summary) || fn_contains(right_schema, summary)
};

match expr {
ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
} => match op {
BinaryOperator::Eq => {
let fn_contains = |schema: &Schema, name: &str| {
schema.iter().any(|column| column.name() == name)
};

match (left_expr.as_ref(), right_expr.as_ref()) {
// example: foo = bar
(ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => {
// reorder left and right joins keys to pattern: (left, right)
if fn_contains(left_schema, l.name())
&& fn_contains(right_schema, r.name())
{
accum.push((*left_expr, *right_expr));
} else if fn_contains(left_schema, r.name())
&& fn_contains(right_schema, l.name())
{
accum.push((*right_expr, *left_expr));
} else {
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
});
} => {
match op {
BinaryOperator::Eq => {
match (left_expr.as_ref(), right_expr.as_ref()) {
// example: foo = bar
(ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => {
// reorder left and right joins keys to pattern: (left, right)
if fn_contains(left_schema, l.summary())
&& fn_contains(right_schema, r.summary())
{
accum.push((*left_expr, *right_expr));
} else if fn_contains(left_schema, r.summary())
&& fn_contains(right_schema, l.summary())
{
accum.push((*right_expr, *left_expr));
} else if fn_or_contains(left_schema, right_schema, l.summary())
|| fn_or_contains(left_schema, right_schema, r.summary())
{
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
});
}
}
(ScalarExpression::ColumnRef(column), _)
| (_, ScalarExpression::ColumnRef(column)) => {
if fn_or_contains(left_schema, right_schema, column.summary()) {
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
});
}
}
_other => {
// example: baz > 1
if left_expr.referenced_columns(true).iter().all(|column| {
fn_or_contains(left_schema, right_schema, column.summary())
}) && right_expr.referenced_columns(true).iter().all(|column| {
fn_or_contains(left_schema, right_schema, column.summary())
}) {
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
});
}
}
}
// example: baz = 1
_other => {
}
BinaryOperator::And => {
// example: foo = bar AND baz > 1
Self::extract_join_keys(
*left_expr,
accum,
accum_filter,
left_schema,
right_schema,
)?;
Self::extract_join_keys(
*right_expr,
accum,
accum_filter,
left_schema,
right_schema,
)?;
}
_ => {
if left_expr.referenced_columns(true).iter().all(|column| {
fn_or_contains(left_schema, right_schema, column.summary())
}) && right_expr.referenced_columns(true).iter().all(|column| {
fn_or_contains(left_schema, right_schema, column.summary())
}) {
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
Expand All @@ -626,36 +680,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
}
}
BinaryOperator::And => {
// example: foo = bar AND baz > 1
Self::extract_join_keys(
*left_expr,
accum,
accum_filter,
left_schema,
right_schema,
)?;
Self::extract_join_keys(
*right_expr,
accum,
accum_filter,
left_schema,
right_schema,
)?;
}
_ => {
}
_ => {
if expr
.referenced_columns(true)
.iter()
.all(|column| fn_or_contains(left_schema, right_schema, column.summary()))
{
// example: baz > 1
accum_filter.push(ScalarExpression::Binary {
left_expr,
right_expr,
op,
ty,
});
accum_filter.push(expr);
}
},
_ => {
// example: baz in (xxx), something else will convert to filter logic
accum_filter.push(expr);
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/catalog/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ impl ColumnCatalog {
self.summary.table_name.as_ref()
}

/// Replaces the table name stored in this column's summary.
///
/// Any previous value (including `None`) is overwritten with
/// `Some(table_name)`.
pub fn set_table_name(&mut self, table_name: TableName) {
self.summary.table_name = Some(table_name);
}

/// Returns a borrowed reference to the logical type recorded in this
/// column's descriptor (`desc.column_datatype`).
pub fn datatype(&self) -> &LogicalType {
&self.desc.column_datatype
}
Expand Down
2 changes: 1 addition & 1 deletion src/execution/volcano/dql/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ mod test {
}),
childrens: vec![],
physical_option: None,
_out_columns: None,
_output_schema_ref: None,
};

let tuples =
Expand Down
4 changes: 2 additions & 2 deletions src/execution/volcano/dql/join/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ mod test {
}),
childrens: vec![],
physical_option: None,
_out_columns: None,
_output_schema_ref: None,
};

let values_t2 = LogicalPlan {
Expand Down Expand Up @@ -438,7 +438,7 @@ mod test {
}),
childrens: vec![],
physical_option: None,
_out_columns: None,
_output_schema_ref: None,
};

(on_keys, values_t1, values_t2)
Expand Down
4 changes: 2 additions & 2 deletions src/execution/volcano/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub fn build_write<T: Transaction>(plan: LogicalPlan, transaction: &mut T) -> Bo
operator,
mut childrens,
physical_option,
_out_columns,
_output_schema_ref: _out_schema_ref,
} = plan;

match operator {
Expand Down Expand Up @@ -164,7 +164,7 @@ pub fn build_write<T: Transaction>(plan: LogicalPlan, transaction: &mut T) -> Bo
operator,
childrens,
physical_option,
_out_columns,
_output_schema_ref: _out_schema_ref,
},
transaction,
),
Expand Down
Loading

0 comments on commit de8ca71

Please sign in to comment.