Skip to content

Commit

Permalink
Support Default, constraints of CreateTable and Fix (#103)
Browse files Browse the repository at this point in the history
* refactor(column): remove field: `table_name`

* feat(create_table): add `default` on `CreateTable`

* feat(cast): support `CAST` Function

* feat: Constraints for CreateTable and Fix `Or` BinaryOperator in scope_aggregation cause index scan error

* version up

* feat(create_table): support `if not exists`

* version up

* code fmt
  • Loading branch information
KKould authored Nov 27, 2023
1 parent f333011 commit e7dbac1
Show file tree
Hide file tree
Showing 54 changed files with 1,093 additions and 632 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[package]
name = "kip-sql"
version = "0.0.1-alpha.3"
version = "0.0.1-alpha.5"
edition = "2021"
authors = ["Kould <[email protected]>", "Xwg <[email protected]>"]
description = "build the SQL layer of KipDB database"
Expand Down Expand Up @@ -36,7 +36,7 @@ ahash = "0.8.3"
lazy_static = "1.4.0"
comfy-table = "7.0.1"
bytes = "1.5.0"
kip_db = "0.1.2-alpha.17"
kip_db = "0.1.2-alpha.18"
rust_decimal = "1"
csv = "1"
regex = "1.10.2"
Expand Down
11 changes: 5 additions & 6 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
agg_calls: Vec<ScalarExpression>,
groupby_exprs: Vec<ScalarExpression>,
) -> LogicalPlan {
AggregateOperator::new(children, agg_calls, groupby_exprs)
AggregateOperator::build(children, agg_calls, groupby_exprs)
}

pub fn extract_select_aggregate(
Expand Down Expand Up @@ -153,10 +153,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
HashSet::from_iter(group_raw_exprs.iter());

for expr in select_items {
if expr.has_agg_call(&self.context) {
if expr.has_agg_call() {
continue;
}

group_raw_set.remove(expr);

if !group_raw_exprs.iter().contains(expr) {
Expand All @@ -168,9 +167,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
}

if !group_raw_set.is_empty() {
return Err(BindError::AggMiss(format!(
"In the GROUP BY clause the field must be in the select clause"
)));
return Err(BindError::AggMiss(
"In the GROUP BY clause the field must be in the select clause".to_string(),
));
}

Ok(())
Expand Down
3 changes: 2 additions & 1 deletion src/binder/copy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use crate::planner::operator::copy_from_file::CopyFromFileOperator;
use crate::planner::operator::copy_to_file::CopyToFileOperator;
Expand Down Expand Up @@ -69,7 +70,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
};

if let Some(table) = self.context.table(&table_name.to_string()) {
if let Some(table) = self.context.table(Arc::new(table_name.to_string())) {
let cols = table.all_columns();
let ext_source = ExtSource {
path: match target {
Expand Down
105 changes: 87 additions & 18 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,71 @@
use itertools::Itertools;
use sqlparser::ast::{ColumnDef, ObjectName, TableConstraint};
use sqlparser::ast::{ColumnDef, ColumnOption, ObjectName, TableConstraint};
use std::collections::HashSet;
use std::sync::Arc;

use super::Binder;
use crate::binder::{lower_case_name, split_name, BindError};
use crate::catalog::ColumnCatalog;
use crate::catalog::{ColumnCatalog, ColumnDesc};
use crate::expression::ScalarExpression;
use crate::planner::operator::create_table::CreateTableOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::types::value::DataValue;
use crate::types::LogicalType;

impl<'a, T: Transaction> Binder<'a, T> {
// TODO: TableConstraint
pub(crate) fn bind_create_table(
&mut self,
name: &ObjectName,
columns: &[ColumnDef],
_constraints: &[TableConstraint],
constraints: &[TableConstraint],
if_not_exists: bool,
) -> Result<LogicalPlan, BindError> {
let name = lower_case_name(&name);
let name = lower_case_name(name);
let (_, name) = split_name(&name)?;
let table_name = Arc::new(name.to_string());

// check duplicated column names
let mut set = HashSet::new();
for col in columns.iter() {
let col_name = &col.name.value;
if !set.insert(col_name.clone()) {
return Err(BindError::AmbiguousColumn(col_name.to_string()));
{
// check duplicated column names
let mut set = HashSet::new();
for col in columns.iter() {
let col_name = &col.name.value;
if !set.insert(col_name.clone()) {
return Err(BindError::AmbiguousColumn(col_name.to_string()));
}
}
}
let columns = columns
let mut columns: Vec<ColumnCatalog> = columns
.iter()
.map(|col| ColumnCatalog::from(col.clone()))
.collect_vec();

let primary_key_count = columns.iter().filter(|col| col.desc.is_primary).count();
.map(|col| self.bind_column(col))
.try_collect()?;
for constraint in constraints {
match constraint {
TableConstraint::Unique {
columns: column_names,
is_primary,
..
} => {
for column_name in column_names {
if let Some(column) = columns
.iter_mut()
.find(|column| column.name() == column_name.to_string())
{
if *is_primary {
column.desc.is_primary = true;
} else {
column.desc.is_unique = true;
}
}
}
}
_ => todo!(),
}
}

if primary_key_count != 1 {
if columns.iter().filter(|col| col.desc.is_primary).count() != 1 {
return Err(BindError::InvalidTable(
"The primary key field must exist and have at least one".to_string(),
));
Expand All @@ -48,11 +75,53 @@ impl<'a, T: Transaction> Binder<'a, T> {
operator: Operator::CreateTable(CreateTableOperator {
table_name,
columns,
if_not_exists,
}),
childrens: vec![],
};
Ok(plan)
}

fn bind_column(&mut self, column_def: &ColumnDef) -> Result<ColumnCatalog, BindError> {
let column_name = column_def.name.to_string();
let mut column_desc = ColumnDesc::new(
LogicalType::try_from(column_def.data_type.clone())?,
false,
false,
None,
);
let mut nullable = false;

// TODO: 这里可以对更多字段可设置内容进行补充
for option_def in &column_def.options {
match &option_def.option {
ColumnOption::Null => nullable = true,
ColumnOption::NotNull => (),
ColumnOption::Unique { is_primary } => {
if *is_primary {
column_desc.is_primary = true;
nullable = false;
// Skip other options when using primary key
break;
} else {
column_desc.is_unique = true;
}
}
ColumnOption::Default(expr) => {
if let ScalarExpression::Constant(value) = self.bind_expr(expr)? {
let cast_value =
DataValue::clone(&value).cast(&column_desc.column_datatype)?;
column_desc.default = Some(Arc::new(cast_value));
} else {
unreachable!("'default' only for constant")
}
}
_ => todo!(),
}
}

Ok(ColumnCatalog::new(column_name, nullable, column_desc, None))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -84,13 +153,13 @@ mod tests {
assert_eq!(op.columns[0].nullable, false);
assert_eq!(
op.columns[0].desc,
ColumnDesc::new(LogicalType::Integer, true, false)
ColumnDesc::new(LogicalType::Integer, true, false, None)
);
assert_eq!(op.columns[1].name(), "name");
assert_eq!(op.columns[1].nullable, true);
assert_eq!(
op.columns[1].desc,
ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false)
ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false, None)
);
}
_ => unreachable!(),
Expand Down
2 changes: 1 addition & 1 deletion src/binder/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ impl<'a, T: Transaction> Binder<'a, T> {
children: LogicalPlan,
select_list: Vec<ScalarExpression>,
) -> LogicalPlan {
AggregateOperator::new(children, vec![], select_list)
AggregateOperator::build(children, vec![], select_list)
}
}
2 changes: 1 addition & 1 deletion src/binder/drop_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::sync::Arc;

impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_drop_table(&mut self, name: &ObjectName) -> Result<LogicalPlan, BindError> {
let name = lower_case_name(&name);
let name = lower_case_name(name);
let (_, name) = split_name(&name)?;
let table_name = Arc::new(name.to_string());

Expand Down
21 changes: 14 additions & 7 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::expression;
use crate::expression::agg::AggKind;
use itertools::Itertools;
use sqlparser::ast::{
BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator,
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator,
};
use std::slice;
use std::sync::Arc;
Expand Down Expand Up @@ -39,6 +39,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
list,
negated,
} => self.bind_is_in(expr, list, *negated),
Expr::Cast { expr, data_type } => self.bind_cast(expr, data_type),
_ => {
todo!()
}
Expand Down Expand Up @@ -86,15 +87,14 @@ impl<'a, T: Transaction> Binder<'a, T> {
.map(|ident| ident.value.clone())
.join(".")
.to_string(),
)
.into())
))
}
};

if let Some(table) = table_name.or(bind_table_name) {
let table_catalog = self
.context
.table(table)
.table(Arc::new(table.clone()))
.ok_or_else(|| BindError::InvalidTable(table.to_string()))?;

let column_catalog = table_catalog
Expand All @@ -104,10 +104,10 @@ impl<'a, T: Transaction> Binder<'a, T> {
} else {
// handle col syntax
let mut got_column = None;
for (_, (table_catalog, _)) in &self.context.bind_table {
for (table_catalog, _) in self.context.bind_table.values() {
if let Some(column_catalog) = table_catalog.get_column_by_name(column_name) {
if got_column.is_some() {
return Err(BindError::InvalidColumn(column_name.to_string()).into());
return Err(BindError::InvalidColumn(column_name.to_string()));
}
got_column = Some(column_catalog);
}
Expand Down Expand Up @@ -176,7 +176,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
};

Ok(ScalarExpression::Unary {
op: (op.clone()).into(),
op: (*op).into(),
expr,
ty,
})
Expand Down Expand Up @@ -255,6 +255,13 @@ impl<'a, T: Transaction> Binder<'a, T> {
})
}

fn bind_cast(&mut self, expr: &Expr, ty: &DataType) -> Result<ScalarExpression, BindError> {
Ok(ScalarExpression::TypeCast {
expr: Box::new(self.bind_expr(expr)?),
ty: LogicalType::try_from(ty.clone())?,
})
}

fn wildcard_expr() -> ScalarExpression {
ScalarExpression::Constant(Arc::new(DataValue::Utf8(Some("*".to_string()))))
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let (_, name) = split_name(&name)?;
let table_name = Arc::new(name.to_string());

if let Some(table) = self.context.table(&table_name) {
if let Some(table) = self.context.table(table_name.clone()) {
let mut columns = Vec::new();

if idents.is_empty() {
Expand All @@ -46,7 +46,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
for expr_row in expr_rows {
let mut row = Vec::with_capacity(expr_row.len());

for (i, expr) in expr_row.into_iter().enumerate() {
for (i, expr) in expr_row.iter().enumerate() {
match &self.bind_expr(expr)? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
Expand Down
Loading

0 comments on commit e7dbac1

Please sign in to comment.