Skip to content

Commit

Permalink
Support: Explicit defaults & Using on Join & the inner table in…
Browse files Browse the repository at this point in the history
… a left or right outer join can also be used in an inner join (#146)

* feat: support `Explicit defaults`

* feat: support `Using` on `Join`, and allow the inner table of a left or right outer join to also be used in an inner join

* fix: the data inserted into t3 in test_crud_sql exceeded the length limit
  • Loading branch information
KKould authored Feb 21, 2024
1 parent 7d4e549 commit e8fc5cd
Show file tree
Hide file tree
Showing 21 changed files with 265 additions and 156 deletions.
8 changes: 5 additions & 3 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
let mut set = HashSet::new();
for col in columns.iter() {
let col_name = &col.name.value;
if !set.insert(col_name.clone()) {
return Err(DatabaseError::AmbiguousColumn(col_name.to_string()));
if !set.insert(col_name) {
return Err(DatabaseError::DuplicateColumn(col_name.clone()));
}
if !is_valid_identifier(col_name) {
return Err(DatabaseError::InvalidColumn(
Expand Down Expand Up @@ -122,7 +122,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
DataValue::clone(&value).cast(&column_desc.column_datatype)?;
column_desc.default = Some(Arc::new(cast_value));
} else {
unreachable!("'default' only for constant")
return Err(DatabaseError::UnsupportedStmt(
"'default' only for constant".to_string(),
));
}
}
_ => todo!(),
Expand Down
40 changes: 27 additions & 13 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sqlparser::ast::{
use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder};
use super::{lower_ident, Binder, QueryBindStep};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::storage::Transaction;
Expand All @@ -27,6 +27,14 @@ macro_rules! try_alias {
};
}

/// Short-circuits the enclosing function when an unqualified column
/// reference is spelled `default`.
///
/// `$qualifier` is the optional table qualifier of the reference and
/// `$ident` is the column name (anything exposing `as_str()`). When the
/// qualifier is `None` and the name is exactly `"default"`, the macro makes
/// the surrounding function `return Ok(ScalarExpression::Empty)`; in every
/// other case it does nothing and execution continues past the invocation.
macro_rules! try_default {
    ($qualifier:expr, $ident:expr) => {
        // Both conditions must hold: no table qualifier and the exact
        // (case-sensitive) name `default`.
        if matches!(($qualifier, $ident.as_str()), (None, "default")) {
            return Ok(ScalarExpression::Empty);
        }
    };
}

impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result<ScalarExpression, DatabaseError> {
match expr {
Expand Down Expand Up @@ -102,17 +110,21 @@ impl<'a, T: Transaction> Binder<'a, T> {
));
}
let column = sub_query_schema[0].clone();
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

self.context.sub_query(sub_query);

Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
if self.context.is_step(&QueryBindStep::Where) {
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
} else {
Ok(ScalarExpression::ColumnRef(column))
}
}
Expr::Tuple(exprs) => {
let mut bond_exprs = Vec::with_capacity(exprs.len());
Expand Down Expand Up @@ -215,8 +227,11 @@ impl<'a, T: Transaction> Binder<'a, T> {
))
}
};
try_alias!(self.context, column_name);
if self.context.allow_default {
try_default!(&table_name, column_name);
}
if let Some(table) = table_name.or(bind_table_name) {
try_alias!(self.context, column_name);
let table_catalog = self
.context
.table(Arc::new(table.clone()))
Expand All @@ -227,10 +242,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
Ok(ScalarExpression::ColumnRef(column_catalog.clone()))
} else {
try_alias!(self.context, column_name);
// handle col syntax
let mut got_column = None;
for (table_catalog, _) in self.context.bind_table.values() {
for table_catalog in self.context.bind_table.values() {
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
got_column = Some(column_catalog);
}
Expand Down
47 changes: 25 additions & 22 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
expr_rows: &Vec<Vec<Expr>>,
is_overwrite: bool,
) -> Result<LogicalPlan, DatabaseError> {
// FIXME: find a better way to detect the current BindStep
self.context.allow_default = true;
let table_name = Arc::new(lower_case_name(name)?);

if let Some(table) = self.context.table(table_name.clone()) {
Expand All @@ -43,7 +45,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
Some(table_name.to_string()),
)? {
ScalarExpression::ColumnRef(catalog) => columns.push(catalog),
_ => unreachable!(),
_ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())),
}
}
if values_len != columns.len() {
Expand All @@ -53,37 +55,41 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?;
let mut rows = Vec::with_capacity(expr_rows.len());

for expr_row in expr_rows {
if expr_row.len() != values_len {
return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len));
}
let mut row = Vec::with_capacity(expr_row.len());

for (i, expr) in expr_row.iter().enumerate() {
match &self.bind_expr(expr)? {
ScalarExpression::Constant(value) => {
let mut expression = self.bind_expr(expr)?;

expression.constant_calculation()?;
match expression {
ScalarExpression::Constant(mut value) => {
let ty = schema_ref[i].datatype();
// Check if the value length is too long
value.check_len(schema_ref[i].datatype())?;
let cast_value =
DataValue::clone(value).cast(schema_ref[i].datatype())?;
row.push(Arc::new(cast_value))
}
ScalarExpression::Unary { expr, op, .. } => {
if let ScalarExpression::Constant(value) = expr.as_ref() {
row.push(Arc::new(
DataValue::unary_op(value, op)?
.cast(schema_ref[i].datatype())?,
))
} else {
unreachable!()
value.check_len(ty)?;

if value.logical_type() != *ty {
value = Arc::new(DataValue::clone(&value).cast(ty)?);
}
row.push(value);
}
ScalarExpression::Empty => {
row.push(schema_ref[i].default_value().ok_or_else(|| {
DatabaseError::InvalidColumn(
"column does not exist default".to_string(),
)
})?);
}
_ => unreachable!(),
_ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())),
}
}

rows.push(row);
}
self.context.allow_default = false;
let values_plan = self.bind_values(rows, schema_ref);

Ok(LogicalPlan::new(
Expand All @@ -94,10 +100,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
vec![values_plan],
))
} else {
Err(DatabaseError::InvalidTable(format!(
"not found table {}",
table_name
)))
Err(DatabaseError::TableNotFound)
}
}

Expand Down
19 changes: 11 additions & 8 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod truncate;
mod update;

use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
Expand Down Expand Up @@ -50,7 +50,8 @@ pub enum QueryBindStep {
pub struct BinderContext<'a, T: Transaction> {
functions: &'a Functions,
pub(crate) transaction: &'a T,
pub(crate) bind_table: HashMap<TableName, (&'a TableCatalog, Option<JoinType>)>,
// Tip: when a wildcard is expanded over multiple tables, a BTreeMap keeps the order of the output tables deterministic.
pub(crate) bind_table: BTreeMap<(TableName, Option<JoinType>), &'a TableCatalog>,
// alias
expr_aliases: HashMap<String, ScalarExpression>,
table_aliases: HashMap<String, TableName>,
Expand All @@ -62,6 +63,7 @@ pub struct BinderContext<'a, T: Transaction> {
sub_queries: HashMap<QueryBindStep, Vec<LogicalPlan>>,

temp_table_id: usize,
pub(crate) allow_default: bool,
}

impl<'a, T: Transaction> BinderContext<'a, T> {
Expand All @@ -77,6 +79,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
bind_step: QueryBindStep::From,
sub_queries: Default::default(),
temp_table_id: 0,
allow_default: false,
}
}

Expand All @@ -89,6 +92,10 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
self.bind_step = bind_step;
}

/// Returns `true` when the step currently recorded in this context is
/// equal to `bind_step`.
pub fn is_step(&self, bind_step: &QueryBindStep) -> bool {
    // Compare by value rather than reference-to-reference.
    self.bind_step == *bind_step
}

pub fn sub_query(&mut self, sub_query: LogicalPlan) {
self.sub_queries
.entry(self.bind_step)
Expand Down Expand Up @@ -120,12 +127,8 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
}
.ok_or(DatabaseError::TableNotFound)?;

let old_table = self
.bind_table
.insert(table_name.clone(), (table, join_type));
if matches!(old_table, Some((_, Some(_)))) {
return Err(DatabaseError::Duplicated("table", table_name.to_string()));
}
self.bind_table
.insert((table_name.clone(), join_type), table);

Ok(table)
}
Expand Down
57 changes: 41 additions & 16 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
types::value::DataValue,
};

use super::{lower_case_name, lower_ident, Binder, QueryBindStep};
use super::{lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep};

use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::errors::DatabaseError;
Expand Down Expand Up @@ -338,7 +338,10 @@ impl<'a, T: Transaction> Binder<'a, T> {
let scan_op = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(TableAlias { name, columns }) = alias {
self.register_alias(columns, name.value.to_lowercase(), table_name.clone())?;
let alias = lower_ident(name);
self.register_alias(columns, alias.clone(), table_name.clone())?;

return Ok((Arc::new(alias), scan_op));
}

Ok((table_name, scan_op))
Expand Down Expand Up @@ -371,7 +374,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
});
}
SelectItem::Wildcard(_) => {
for table_name in self.context.bind_table.keys() {
for (table_name, _) in self.context.bind_table.keys() {
self.bind_table_column_refs(&mut select_items, table_name.clone())?;
}
}
Expand Down Expand Up @@ -443,19 +446,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
};
let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type))?;
let right_table = Self::unpack_name(right_table, false);
let fn_table = |context: &BinderContext<_>, table| {
context
.table(table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)
};

let left_table = self
.context
.table(left_table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)?;
let right_table = self
.context
.table(right_table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)?;
let left_table = fn_table(&self.context, left_table.clone())?;
let right_table = fn_table(&self.context, right_table.clone())?;

let on = match joint_condition {
Some(constraint) => self.bind_join_constraint(&left_table, &right_table, constraint)?,
Expand Down Expand Up @@ -605,7 +605,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let mut left_table_force_nullable = false;
let mut left_table = None;

for (table, join_option) in bind_tables.values() {
for ((_, join_option), table) in bind_tables {
if let Some(join_type) = join_option {
let (left_force_nullable, right_force_nullable) = joins_nullable(join_type);
table_force_nullable.push((table, right_force_nullable));
Expand Down Expand Up @@ -671,6 +671,31 @@ impl<'a, T: Transaction> Binder<'a, T> {
filter: join_filter,
})
}
JoinConstraint::Using(idents) => {
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![];
let fn_column = |schema: &Schema, ident: &Ident| {
schema
.iter()
.find(|column| column.name() == lower_ident(ident))
.map(|column| ScalarExpression::ColumnRef(column.clone()))
};

for ident in idents {
if let (Some(left_column), Some(right_column)) = (
fn_column(left_schema, ident),
fn_column(right_schema, ident),
) {
on_keys.push((left_column, right_column));
} else {
return Err(DatabaseError::InvalidColumn("not found column".to_string()))?;
}
}
Ok(JoinCondition::On {
on: on_keys,
filter: None,
})
}
JoinConstraint::None => Ok(JoinCondition::None),
_ => unimplemented!("not supported join constraint {:?}", constraint),
}
}
Expand Down
Loading

0 comments on commit e8fc5cd

Please sign in to comment.