diff --git a/Cargo.toml b/Cargo.toml index c3f5c1fb..a3750e2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "kip-sql" -version = "0.0.1-alpha.3" +version = "0.0.1-alpha.5" edition = "2021" authors = ["Kould ", "Xwg "] description = "build the SQL layer of KipDB database" @@ -36,7 +36,7 @@ ahash = "0.8.3" lazy_static = "1.4.0" comfy-table = "7.0.1" bytes = "1.5.0" -kip_db = "0.1.2-alpha.17" +kip_db = "0.1.2-alpha.18" rust_decimal = "1" csv = "1" regex = "1.10.2"
diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 1d9a0778..2ac26f0f 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -20,7 +20,7 @@ impl<'a, T: Transaction> Binder<'a, T> { agg_calls: Vec<ScalarExpression>, groupby_exprs: Vec<ScalarExpression>, ) -> LogicalPlan { - AggregateOperator::new(children, agg_calls, groupby_exprs) + AggregateOperator::build(children, agg_calls, groupby_exprs) } pub fn extract_select_aggregate( @@ -153,10 +153,9 @@ impl<'a, T: Transaction> Binder<'a, T> { HashSet::from_iter(group_raw_exprs.iter()); for expr in select_items { - if expr.has_agg_call(&self.context) { + if expr.has_agg_call() { continue; } - group_raw_set.remove(expr); if !group_raw_exprs.iter().contains(expr) { @@ -168,9 +167,9 @@ impl<'a, T: Transaction> Binder<'a, T> { } if !group_raw_set.is_empty() { - return Err(BindError::AggMiss(format!( - "In the GROUP BY clause the field must be in the select clause" - ))); + return Err(BindError::AggMiss( + "Fields in the GROUP BY clause must also appear in the SELECT clause".to_string(), + )); } Ok(())
diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs new file mode 100644 index 00000000..882695d7 --- /dev/null +++ b/src/binder/alter_table.rs @@ -0,0 +1,73 @@ +use sqlparser::ast::{AlterTableOperation, ObjectName}; + +use std::sync::Arc; + +use super::Binder; +use crate::binder::{lower_case_name, split_name, BindError}; +use crate::planner::operator::alter_table::{AddColumn, AlterTableOperator}; +use crate::planner::operator::scan::ScanOperator; +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; +use crate::storage::Transaction; + +impl<'a, T: Transaction> Binder<'a, T> { + pub(crate) fn bind_alter_table( + &mut self, + name: &ObjectName, + operation: &AlterTableOperation, + ) -> Result<LogicalPlan, BindError> { + let table_name: Arc<String> = Arc::new(split_name(&lower_case_name(name))?.1.to_string()); + + // we need to convert the ColumnDef into a ColumnCatalog + + let plan = match operation { + AlterTableOperation::AddColumn { + column_keyword: _, + if_not_exists, + column_def, + } => { + if let Some(table) = self.context.table(table_name.clone()) { + let plan = ScanOperator::build(table_name.clone(), table); + + LogicalPlan { + operator: Operator::AlterTable(AlterTableOperator::AddColumn(AddColumn { + table_name, + if_not_exists: *if_not_exists, + column: self.bind_column(column_def)?, + })), + childrens: vec![plan], + } + } else { + return Err(BindError::InvalidTable(format!( + "not found table {}", + table_name + ))); + } + } + AlterTableOperation::DropColumn { + column_name: _, + if_exists: _, + cascade: _, + } => todo!(), + AlterTableOperation::DropPrimaryKey => todo!(), + AlterTableOperation::RenameColumn { + old_column_name: _, + new_column_name: _, + } => todo!(), + AlterTableOperation::RenameTable { table_name: _ } => todo!(), + AlterTableOperation::ChangeColumn { + old_name: _, + new_name: _, + data_type: _, + options: _, + } => todo!(), + AlterTableOperation::AlterColumn { + column_name: _, + op: _, + } => todo!(), + _ => todo!(), + }; + + Ok(plan) + } +}
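Note on alter_table.rs above: only the AlterTableOperation::AddColumn arm is bound so far; the remaining variants are left as todo!(). The bound plan wraps AlterTableOperator::AddColumn around a full scan of the target table so the executor can rewrite existing tuples. A minimal statement exercising this path (table and column names are hypothetical):

    ALTER TABLE t1 ADD COLUMN IF NOT EXISTS c2 INTEGER;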
diff --git a/src/binder/copy.rs b/src/binder/copy.rs index 8a28a54a..9d23eed3 100644 --- a/src/binder/copy.rs +++ b/src/binder/copy.rs @@ -1,5 +1,6 @@ use std::path::PathBuf; use std::str::FromStr; +use std::sync::Arc; use crate::planner::operator::copy_from_file::CopyFromFileOperator; use crate::planner::operator::copy_to_file::CopyToFileOperator; @@ -69,7 +70,7 @@ impl<'a, T: Transaction> Binder<'a, T> { } }; - if let Some(table) = self.context.table(&table_name.to_string()) { + if let Some(table) = self.context.table(Arc::new(table_name.to_string())) { let cols = table.all_columns(); let ext_source = ExtSource { path: match target {
diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 65aeb8f3..abe20525 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -1,15 +1,18 @@ use itertools::Itertools; -use sqlparser::ast::{ColumnDef, ObjectName, TableConstraint}; +use sqlparser::ast::{ColumnDef, ColumnOption, ObjectName, TableConstraint}; use std::collections::HashSet; use std::sync::Arc; use super::Binder; use crate::binder::{lower_case_name, split_name, BindError}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, ColumnDesc}; +use crate::expression::ScalarExpression; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; use crate::storage::Transaction; +use crate::types::value::DataValue; +use crate::types::LogicalType; impl<'a, T: Transaction> Binder<'a, T> { // TODO: TableConstraint @@ -17,28 +20,52 @@ impl<'a, T: Transaction> Binder<'a, T> { &mut self, name: &ObjectName, columns: &[ColumnDef], - _constraints: &[TableConstraint], + constraints: &[TableConstraint], + if_not_exists: bool, ) -> Result<LogicalPlan, BindError> { - let name = lower_case_name(&name); + let name = lower_case_name(name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); - // check duplicated column names - let mut set = HashSet::new(); - for col in columns.iter() { - let col_name = &col.name.value; - if !set.insert(col_name.clone()) { - return Err(BindError::AmbiguousColumn(col_name.to_string())); + { + // check duplicated column names + let mut set = HashSet::new(); + for col in columns.iter() { + let col_name = &col.name.value; + if !set.insert(col_name.clone()) { + return Err(BindError::AmbiguousColumn(col_name.to_string())); + } } } - let columns = columns + let mut columns: Vec<ColumnCatalog> = columns .iter() - .map(|col| ColumnCatalog::from(col.clone())) - .collect_vec(); - - let primary_key_count = columns.iter().filter(|col| col.desc.is_primary).count(); + .map(|col| self.bind_column(col)) + .try_collect()?; + for constraint in constraints { + match constraint { + TableConstraint::Unique { + columns: column_names, + is_primary, + ..
+ } => { + for column_name in column_names { + if let Some(column) = columns + .iter_mut() + .find(|column| column.name() == column_name.to_string()) + { + if *is_primary { + column.desc.is_primary = true; + } else { + column.desc.is_unique = true; + } + } + } + } + _ => todo!(), + } + } - if primary_key_count != 1 { + if columns.iter().filter(|col| col.desc.is_primary).count() != 1 { return Err(BindError::InvalidTable( "The primary key must exist, and only one is allowed".to_string(), )); @@ -48,11 +75,53 @@ impl<'a, T: Transaction> Binder<'a, T> { operator: Operator::CreateTable(CreateTableOperator { table_name, columns, + if_not_exists, }), childrens: vec![], }; Ok(plan) } + + pub fn bind_column(&mut self, column_def: &ColumnDef) -> Result<ColumnCatalog, BindError> { + let column_name = column_def.name.to_string(); + let mut column_desc = ColumnDesc::new( + LogicalType::try_from(column_def.data_type.clone())?, + false, + false, + None, + ); + let mut nullable = false; + + // TODO: more settable column options could be supported here + for option_def in &column_def.options { + match &option_def.option { + ColumnOption::Null => nullable = true, + ColumnOption::NotNull => (), + ColumnOption::Unique { is_primary } => { + if *is_primary { + column_desc.is_primary = true; + nullable = false; + // Skip other options when using primary key + break; + } else { + column_desc.is_unique = true; + } + } + ColumnOption::Default(expr) => { + if let ScalarExpression::Constant(value) = self.bind_expr(expr)? { + let cast_value = + DataValue::clone(&value).cast(&column_desc.column_datatype)?; + column_desc.default = Some(Arc::new(cast_value)); + } else { + unreachable!("'default' only for constant") + } + } + _ => todo!(), + } + } + + Ok(ColumnCatalog::new(column_name, nullable, column_desc, None)) + } } #[cfg(test)] @@ -84,13 +153,13 @@ mod tests { assert_eq!(op.columns[0].nullable, false); assert_eq!( op.columns[0].desc, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false, None) ); assert_eq!(op.columns[1].name(), "name"); assert_eq!(op.columns[1].nullable, true); assert_eq!( op.columns[1].desc, - ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false) + ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false, None) ); } _ => unreachable!(),
diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index 1666040b..fa184821 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -10,6 +10,6 @@ impl<'a, T: Transaction> Binder<'a, T> { children: LogicalPlan, select_list: Vec<ScalarExpression>, ) -> LogicalPlan { - AggregateOperator::new(children, vec![], select_list) + AggregateOperator::build(children, vec![], select_list) } }
diff --git a/src/binder/drop_table.rs b/src/binder/drop_table.rs index 0f604273..e2296d2c 100644 --- a/src/binder/drop_table.rs +++ b/src/binder/drop_table.rs @@ -7,13 +7,20 @@ use sqlparser::ast::ObjectName; use std::sync::Arc; impl<'a, T: Transaction> Binder<'a, T> { - pub(crate) fn bind_drop_table(&mut self, name: &ObjectName) -> Result<LogicalPlan, BindError> { - let name = lower_case_name(&name); + pub(crate) fn bind_drop_table( + &mut self, + name: &ObjectName, + if_exists: &bool, + ) -> Result<LogicalPlan, BindError> { + let name = lower_case_name(name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); let plan = LogicalPlan { - operator: Operator::DropTable(DropTableOperator { table_name }), + operator: Operator::DropTable(DropTableOperator { + table_name, + if_exists: *if_exists, + }), childrens: vec![], }; Ok(plan)
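Note on bind_column above: only constant DEFAULT expressions are accepted, and the bound value is cast to the column's type at bind time; a non-constant default reaches the unreachable! branch. For example (hypothetical names):

    CREATE TABLE t1 (c1 INT PRIMARY KEY, c2 BIGINT DEFAULT 0);

stores the literal 0 in ColumnDesc::default already converted to Bigint.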
diff --git a/src/binder/expr.rs b/src/binder/expr.rs index eefe09d8..a215df75 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -3,7 +3,7 @@ use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator, + BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator, }; use std::slice; use std::sync::Arc; @@ -39,6 +39,7 @@ impl<'a, T: Transaction> Binder<'a, T> { list, negated, } => self.bind_is_in(expr, list, *negated), + Expr::Cast { expr, data_type } => self.bind_cast(expr, data_type), _ => { todo!() } @@ -86,15 +87,14 @@ impl<'a, T: Transaction> Binder<'a, T> { .map(|ident| ident.value.clone()) .join(".") .to_string(), - ) - .into()) + )) } }; if let Some(table) = table_name.or(bind_table_name) { let table_catalog = self .context - .table(table) + .table(Arc::new(table.clone())) .ok_or_else(|| BindError::InvalidTable(table.to_string()))?; let column_catalog = table_catalog @@ -104,10 +104,10 @@ impl<'a, T: Transaction> Binder<'a, T> { } else { // handle col syntax let mut got_column = None; - for (_, (table_catalog, _)) in &self.context.bind_table { + for (table_catalog, _) in self.context.bind_table.values() { if let Some(column_catalog) = table_catalog.get_column_by_name(column_name) { if got_column.is_some() { - return Err(BindError::InvalidColumn(column_name.to_string()).into()); + return Err(BindError::InvalidColumn(column_name.to_string())); } got_column = Some(column_catalog); } @@ -176,7 +176,7 @@ impl<'a, T: Transaction> Binder<'a, T> { }; Ok(ScalarExpression::Unary { - op: (op.clone()).into(), + op: (*op).into(), expr, ty, }) @@ -255,6 +255,13 @@ impl<'a, T: Transaction> Binder<'a, T> { }) } + fn bind_cast(&mut self, expr: &Expr, ty: &DataType) -> Result<ScalarExpression, BindError> { + Ok(ScalarExpression::TypeCast { + expr: Box::new(self.bind_expr(expr)?), + ty: LogicalType::try_from(ty.clone())?, + }) + } + fn wildcard_expr() -> ScalarExpression { ScalarExpression::Constant(Arc::new(DataValue::Utf8(Some("*".to_string())))) }
diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 596ce692..113aa76b 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -24,7 +24,7 @@ impl<'a, T: Transaction> Binder<'a, T> { let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); - if let Some(table) = self.context.table(&table_name) { + if let Some(table) = self.context.table(table_name.clone()) { let mut columns = Vec::new(); if idents.is_empty() { @@ -42,11 +42,11 @@ impl<'a, T: Transaction> Binder<'a, T> { } } let mut rows = Vec::with_capacity(expr_rows.len()); for expr_row in expr_rows { let mut row = Vec::with_capacity(expr_row.len()); - for (i, expr) in expr_row.into_iter().enumerate() { + for (i, expr) in expr_row.iter().enumerate() { match &self.bind_expr(expr)? { ScalarExpression::Constant(value) => { // Check if the value length is too long
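Note on the expr.rs change above: a query such as (hypothetical names)

    SELECT CAST(c1 AS BIGINT) FROM t1;

now binds to ScalarExpression::TypeCast, with the target type resolved through LogicalType::try_from at bind time; output_columns() in expression/mod.rs below names the result column "CAST(... as ...)".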
diff --git a/src/binder/mod.rs b/src/binder/mod.rs index fa085df0..2a649bb1 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -1,4 +1,5 @@ pub mod aggregate; +mod alter_table; pub mod copy; mod create_table; mod delete; @@ -48,9 +49,9 @@ impl<'a, T: Transaction> BinderContext<'a, T> { } } - pub fn table(&self, table_name: &String) -> Option<&TableCatalog> { - if let Some(real_name) = self.table_aliases.get(table_name) { - self.transaction.table(real_name) + pub fn table(&self, table_name: TableName) -> Option<&TableCatalog> { + if let Some(real_name) = self.table_aliases.get(table_name.as_ref()) { + self.transaction.table(real_name.clone()) } else { self.transaction.table(table_name) } } @@ -119,16 +120,21 @@ impl<'a, T: Transaction> Binder<'a, T> { pub fn bind(mut self, stmt: &Statement) -> Result<LogicalPlan, BindError> { let plan = match stmt { Statement::Query(query) => self.bind_query(query)?, + Statement::AlterTable { name, operation } => self.bind_alter_table(name, operation)?, Statement::CreateTable { name, columns, constraints, + if_not_exists, .. - } => self.bind_create_table(name, &columns, &constraints)?, + } => self.bind_create_table(name, columns, constraints, *if_not_exists)?, Statement::Drop { - object_type, names, .. + object_type, + names, + if_exists, + .. } => match object_type { - ObjectType::Table => self.bind_drop_table(&names[0])?, + ObjectType::Table => self.bind_drop_table(&names[0], if_exists)?, _ => todo!(), }, Statement::Insert { @@ -175,7 +181,7 @@ impl<'a, T: Transaction> Binder<'a, T> { target, options, .. - } => self.bind_copy(source.clone(), *to, target.clone(), &options)?, + } => self.bind_copy(source.clone(), *to, target.clone(), options)?, _ => return Err(BindError::UnsupportedStmt(stmt.to_string())), }; Ok(plan) } @@ -209,6 +215,8 @@ pub enum BindError { InvalidTable(String), #[error("invalid table name: {0:?}")] InvalidTableName(Vec<Ident>), + #[error("not found table: {0}")] + NotFoundTable(String), #[error("invalid column {0}")] InvalidColumn(String), #[error("ambiguous column {0}")] @@ -252,16 +260,17 @@ pub mod test { ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(Integer, true, false), + ColumnDesc::new(Integer, true, false, None), None, ), ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(Integer, false, true), + ColumnDesc::new(Integer, false, true, None), None, ), ], + false, )?; let _ = transaction.create_table( @@ -270,16 +279,17 @@ ColumnCatalog::new( "c3".to_string(), false, - ColumnDesc::new(Integer, true, false), + ColumnDesc::new(Integer, true, false, None), None, ), ColumnCatalog::new( "c4".to_string(), false, - ColumnDesc::new(Integer, false, false), + ColumnDesc::new(Integer, false, false, None), None, ), ], + false, )?; transaction.commit().await?;
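Note on mod.rs above: BinderContext::table now takes TableName (an Arc<String>) by value rather than &String, matching the Transaction::table signature; aliases are still resolved through table_aliases before falling back to the transaction lookup.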
diff --git a/src/binder/select.rs b/src/binder/select.rs index cb8868d8..9668da49 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1,5 +1,4 @@ use std::borrow::Borrow; -use std::collections::HashMap; use std::sync::Arc; use crate::{ @@ -132,7 +131,7 @@ impl<'a, T: Transaction> Binder<'a, T> { let left_name = Self::unpack_name(left_name, true); for join in joins { - plan = self.bind_join(&left_name, plan, join)?; + plan = self.bind_join(left_name.clone(), plan, join)?; } } Ok(plan) } @@ -140,7 +139,7 @@ fn unpack_name(table_name: Option<TableName>, is_left: bool) -> TableName { let title = if is_left { "Left" } else { "Right" }; - table_name.expect(&format!("{}: Table is not named", title)) + table_name.unwrap_or_else(|| panic!("{}: Table is not named", title)) } fn bind_single_table_ref( @@ -207,7 +206,7 @@ impl<'a, T: Transaction> Binder<'a, T> { let table_catalog = self .context - .table(&table_name) + .table(table_name.clone()) .cloned() .ok_or_else(|| BindError::InvalidTable(format!("bind table {}", table)))?; @@ -221,7 +220,7 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(( table_name.clone(), - ScanOperator::new(table_name, &table_catalog), + ScanOperator::build(table_name, &table_catalog), )) } @@ -264,10 +263,10 @@ impl<'a, T: Transaction> Binder<'a, T> { fn bind_all_column_refs(&mut self) -> Result<Vec<ScalarExpression>, BindError> { let mut exprs = vec![]; - for table_name in self.context.bind_table.keys().cloned() { + for table_name in self.context.bind_table.keys() { let table = self .context - .table(&table_name) + .table(table_name.clone()) .ok_or_else(|| BindError::InvalidTable(table_name.to_string()))?; for col in table.all_columns() { exprs.push(ScalarExpression::ColumnRef(col)); } @@ -279,7 +278,7 @@ impl<'a, T: Transaction> Binder<'a, T> { fn bind_join( &mut self, - left_table: &String, + left_table: TableName, left: LogicalPlan, join: &Join, ) -> Result<LogicalPlan, BindError> { @@ -299,21 +298,23 @@ impl<'a, T: Transaction> Binder<'a, T> { let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type))?; let right_table = Self::unpack_name(right_table, false); - let left_table = - self.context.table(left_table).cloned().ok_or_else(|| { - BindError::InvalidTable(format!("Left: {} not found", left_table)) - })?; - let right_table = - self.context.table(&right_table).cloned().ok_or_else(|| { - BindError::InvalidTable(format!("Right: {} not found", right_table)) - })?; + let left_table = self + .context + .table(left_table.clone()) + .cloned() + .ok_or_else(|| BindError::InvalidTable(format!("Left: {} not found", left_table)))?; + let right_table = self + .context + .table(right_table.clone()) + .cloned() + .ok_or_else(|| BindError::InvalidTable(format!("Right: {} not found", right_table)))?; let on = match joint_condition { Some(constraint) => self.bind_join_constraint(&left_table, &right_table, constraint)?, None => JoinCondition::None, }; - Ok(LJoinOperator::new(left, right, on, join_type)) + Ok(LJoinOperator::build(left, right, on, join_type)) } pub(crate) fn bind_where( &mut self, children: LogicalPlan, predicate: &Expr, ) -> Result<LogicalPlan, BindError> { - Ok(FilterOperator::new( + Ok(FilterOperator::build( self.bind_expr(predicate)?, children, false, @@ -334,7 +335,7 @@ having: ScalarExpression, ) -> Result<LogicalPlan, BindError> { self.validate_having_orderby(&having)?; - Ok(FilterOperator::new(having, children, true)) + Ok(FilterOperator::build(having, children, true)) } fn bind_project( @@ -400,7 +401,7 @@ // TODO: validate that limit and offset are correct using statistics.
- Ok(LimitOperator::new(offset, limit, children)) + Ok(LimitOperator::build(offset, limit, children)) } pub fn extract_select_join(&mut self, select_items: &mut [ScalarExpression]) { @@ -409,33 +410,35 @@ impl<'a, T: Transaction> Binder<'a, T> { return; } - let mut table_force_nullable = HashMap::new(); + let mut table_force_nullable = Vec::with_capacity(bind_tables.len()); let mut left_table_force_nullable = false; let mut left_table = None; - for (table_name, (_, join_option)) in bind_tables { + for (table, join_option) in bind_tables.values() { if let Some(join_type) = join_option { let (left_force_nullable, right_force_nullable) = joins_nullable(join_type); - table_force_nullable.insert(table_name.clone(), right_force_nullable); + table_force_nullable.push((table, right_force_nullable)); left_table_force_nullable = left_force_nullable; } else { - left_table = Some(table_name.clone()); + left_table = Some(table); } } - if let Some(name) = left_table { - table_force_nullable.insert(name, left_table_force_nullable); + if let Some(table) = left_table { + table_force_nullable.push((table, left_table_force_nullable)); } for column in select_items { if let ScalarExpression::ColumnRef(col) = column { - if let Some(nullable) = table_force_nullable.get(col.table_name().as_ref().unwrap()) - { - let mut new_col = ColumnCatalog::clone(col); - new_col.nullable = *nullable; + let _ = table_force_nullable + .iter() + .find(|(table, _)| table.contains_column(col.name())) + .map(|(_, nullable)| { + let mut new_col = ColumnCatalog::clone(col); + new_col.nullable = *nullable; - *col = Arc::new(new_col) - } + *col = Arc::new(new_col); + }); } } } diff --git a/src/binder/truncate.rs b/src/binder/truncate.rs index 4a34478e..17ce7818 100644 --- a/src/binder/truncate.rs +++ b/src/binder/truncate.rs @@ -8,7 +8,7 @@ use std::sync::Arc; impl<'a, T: Transaction> Binder<'a, T> { pub(crate) fn bind_truncate(&mut self, name: &ObjectName) -> Result { - let name = lower_case_name(&name); + let name = lower_case_name(name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); diff --git a/src/binder/update.rs b/src/binder/update.rs index 868d2141..784e7369 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -17,7 +17,7 @@ impl<'a, T: Transaction> Binder<'a, T> { assignments: &[Assignment], ) -> Result { if let TableFactor::Table { name, .. } = &to.relation { - let name = lower_case_name(&name); + let name = lower_case_name(name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); @@ -40,7 +40,7 @@ impl<'a, T: Transaction> Binder<'a, T> { for ident in &assignment.id { match self.bind_column_ref_from_identifiers( - slice::from_ref(&ident), + slice::from_ref(ident), bind_table_name.as_ref(), )? 
{ ScalarExpression::ColumnRef(catalog) => { diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 4e73de6c..be603484 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,10 +1,9 @@ -use crate::catalog::TableName; use crate::expression::ScalarExpression; use serde::{Deserialize, Serialize}; -use sqlparser::ast::{ColumnDef, ColumnOption}; use std::hash::Hash; use std::sync::Arc; +use crate::types::value::ValueRef; use crate::types::{ColumnId, LogicalType}; pub type ColumnRef = Arc; @@ -21,7 +20,6 @@ pub struct ColumnCatalog { pub struct ColumnSummary { pub id: Option, pub name: String, - pub table_name: Option, } impl ColumnCatalog { @@ -35,7 +33,6 @@ impl ColumnCatalog { summary: ColumnSummary { id: None, name: column_name, - table_name: None, }, nullable, desc: column_desc, @@ -48,10 +45,9 @@ impl ColumnCatalog { summary: ColumnSummary { id: Some(0), name: column_name, - table_name: None, }, nullable: false, - desc: ColumnDesc::new(LogicalType::Varchar(None), false, false), + desc: ColumnDesc::new(LogicalType::Varchar(None), false, false, None), ref_expr: None, } } @@ -64,10 +60,6 @@ impl ColumnCatalog { self.summary.id } - pub(crate) fn table_name(&self) -> Option { - self.summary.table_name.clone() - } - pub(crate) fn name(&self) -> &str { &self.summary.name } @@ -76,51 +68,23 @@ impl ColumnCatalog { &self.desc.column_datatype } + pub(crate) fn default_value(&self) -> Option { + self.desc.default.clone() + } + #[allow(dead_code)] pub(crate) fn desc(&self) -> &ColumnDesc { &self.desc } } -impl From for ColumnCatalog { - fn from(column_def: ColumnDef) -> Self { - let column_name = column_def.name.to_string(); - let mut column_desc = ColumnDesc::new( - LogicalType::try_from(column_def.data_type).unwrap(), - false, - false, - ); - let mut nullable = false; - - // TODO: 这里可以对更多字段可设置内容进行补充 - for option_def in column_def.options { - match option_def.option { - ColumnOption::Null => nullable = true, - ColumnOption::NotNull => (), - ColumnOption::Unique { is_primary } => { - if is_primary { - column_desc.is_primary = true; - nullable = false; - // Skip other options when using primary key - break; - } else { - column_desc.is_unique = true; - } - } - _ => todo!(), - } - } - - ColumnCatalog::new(column_name, nullable, column_desc, None) - } -} - /// The descriptor of a column. 
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct ColumnDesc { pub(crate) column_datatype: LogicalType, pub(crate) is_primary: bool, pub(crate) is_unique: bool, + pub(crate) default: Option, } impl ColumnDesc { @@ -128,11 +92,17 @@ impl ColumnDesc { column_datatype: LogicalType, is_primary: bool, is_unique: bool, + default: Option, ) -> ColumnDesc { ColumnDesc { column_datatype, is_primary, is_unique, + default, } } + + pub(crate) fn is_index(&self) -> bool { + self.is_unique || self.is_primary + } } diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index a5c35822..b0d077f0 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -21,4 +21,6 @@ pub enum CatalogError { NotFound(&'static str, String), #[error("duplicated {0}: {1}")] Duplicated(&'static str, String), + #[error("columns empty")] + ColumnsEmpty, } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 9112dfb3..d96f782d 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -62,13 +62,13 @@ mod tests { let col0 = ColumnCatalog::new( "a".to_string(), false, - ColumnDesc::new(LogicalType::Integer, false, false), + ColumnDesc::new(LogicalType::Integer, false, false, None), None, ); let col1 = ColumnCatalog::new( "b".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), None, ); let col_catalogs = vec![col0, col1]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 92268f61..bcdb2178 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -47,10 +47,7 @@ impl TableCatalog { } pub(crate) fn all_columns(&self) -> Vec { - self.columns - .iter() - .map(|(_, col)| Arc::clone(col)) - .collect() + self.columns.values().map(Arc::clone).collect() } /// Add a column to the table catalog. @@ -62,7 +59,6 @@ impl TableCatalog { let col_id = self.columns.len() as u32; col.summary.id = Some(col_id); - col.summary.table_name = Some(self.name.clone()); self.column_idxs.insert(col.name().to_string(), col_id); self.columns.insert(col_id, Arc::new(col)); @@ -82,13 +78,15 @@ impl TableCatalog { name: TableName, columns: Vec, ) -> Result { + if columns.is_empty() { + return Err(CatalogError::ColumnsEmpty); + } let mut table_catalog = TableCatalog { name, column_idxs: BTreeMap::new(), columns: BTreeMap::new(), indexes: vec![], }; - for col_catalog in columns.into_iter() { let _ = table_catalog.add_column(col_catalog)?; } @@ -123,13 +121,13 @@ mod tests { let col0 = ColumnCatalog::new( "a".into(), false, - ColumnDesc::new(LogicalType::Integer, false, false), + ColumnDesc::new(LogicalType::Integer, false, false, None), None, ); let col1 = ColumnCatalog::new( "b".into(), false, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), None, ); let col_catalogs = vec![col0, col1]; diff --git a/src/db.rs b/src/db.rs index 9c05b919..af03501e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -37,7 +37,6 @@ impl Database { /// Run SQL queries. 
pub async fn run(&self, sql: &str) -> Result, DatabaseError> { let transaction = self.storage.transaction().await?; - // parse let stmts = parse_sql(sql)?; @@ -157,17 +156,17 @@ mod test { ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false), + ColumnDesc::new(LogicalType::Integer, true, false, None), None, ), ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), None, ), ]; - let _ = transaction.create_table(Arc::new("t1".to_string()), columns)?; + let _ = transaction.create_table(Arc::new("t1".to_string()), columns, false)?; transaction.commit().await?; Ok(()) diff --git a/src/execution/executor/ddl/alter_table.rs b/src/execution/executor/ddl/alter_table.rs new file mode 100644 index 00000000..015f701a --- /dev/null +++ b/src/execution/executor/ddl/alter_table.rs @@ -0,0 +1,63 @@ +use crate::execution::executor::BoxedExecutor; +use crate::planner::operator::alter_table::AddColumn; +use crate::types::tuple::Tuple; +use crate::types::value::DataValue; +use crate::{execution::ExecutorError, types::tuple_builder::TupleBuilder}; +use futures_async_stream::try_stream; +use std::cell::RefCell; +use std::ops::Deref; +use std::sync::Arc; + +use crate::{ + execution::executor::Executor, planner::operator::alter_table::AlterTableOperator, + storage::Transaction, +}; + +pub struct AlterTable { + op: AlterTableOperator, + input: BoxedExecutor, +} + +impl From<(AlterTableOperator, BoxedExecutor)> for AlterTable { + fn from((op, input): (AlterTableOperator, BoxedExecutor)) -> Self { + Self { op, input } + } +} + +impl Executor for AlterTable { + fn execute(self, transaction: &RefCell) -> crate::execution::executor::BoxedExecutor { + unsafe { self._execute(transaction.as_ptr().as_mut().unwrap()) } + } +} + +impl AlterTable { + #[try_stream(boxed, ok = Tuple, error = ExecutorError)] + async fn _execute(self, transaction: &mut T) { + let _ = transaction.alter_table(&self.op)?; + + if let AlterTableOperator::AddColumn(AddColumn { + table_name, column, .. 
+ }) = &self.op + { + #[for_await] + for tuple in self.input { + let mut tuple: Tuple = tuple?; + let is_overwrite = true; + + tuple.columns.push(Arc::new(column.clone())); + if let Some(value) = column.default_value() { + tuple.values.push(Arc::new(value.deref().clone())); + } else { + tuple.values.push(Arc::new(DataValue::Null)); + } + + transaction.append(table_name, tuple, is_overwrite)?; + } + } + + let tuple_builder = TupleBuilder::new_result(); + let tuple = tuple_builder.push_result("ALTER TABLE SUCCESS", "1")?; + + yield tuple; + } +} diff --git a/src/execution/executor/ddl/create_table.rs b/src/execution/executor/ddl/create_table.rs index 6bfd8613..58ffec73 100644 --- a/src/execution/executor/ddl/create_table.rs +++ b/src/execution/executor/ddl/create_table.rs @@ -29,8 +29,9 @@ impl CreateTable { let CreateTableOperator { table_name, columns, + if_not_exists, } = self.op; - let _ = transaction.create_table(table_name.clone(), columns)?; + let _ = transaction.create_table(table_name.clone(), columns, if_not_exists)?; let tuple_builder = TupleBuilder::new_result(); let tuple = tuple_builder .push_result("CREATE TABLE SUCCESS", format!("{}", table_name).as_str())?; diff --git a/src/execution/executor/ddl/drop_table.rs b/src/execution/executor/ddl/drop_table.rs index 1b8c8cde..92366730 100644 --- a/src/execution/executor/ddl/drop_table.rs +++ b/src/execution/executor/ddl/drop_table.rs @@ -25,8 +25,11 @@ impl Executor for DropTable { impl DropTable { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, transaction: &mut T) { - let DropTableOperator { table_name } = self.op; + let DropTableOperator { + table_name, + if_exists, + } = self.op; - transaction.drop_table(&table_name)?; + transaction.drop_table(&table_name, if_exists)?; } } diff --git a/src/execution/executor/ddl/mod.rs b/src/execution/executor/ddl/mod.rs index 9c5a45a1..4ec4ceef 100644 --- a/src/execution/executor/ddl/mod.rs +++ b/src/execution/executor/ddl/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod alter_table; pub(crate) mod create_table; pub(crate) mod drop_table; pub(crate) mod truncate; diff --git a/src/execution/executor/dml/copy_from_file.rs b/src/execution/executor/dml/copy_from_file.rs index 10e68ac5..af8ae5c4 100644 --- a/src/execution/executor/dml/copy_from_file.rs +++ b/src/execution/executor/dml/copy_from_file.rs @@ -38,14 +38,14 @@ impl CopyFromFile { // `tx`, then the task will finish. 
let table_name = self.op.table.clone(); let handle = tokio::task::spawn_blocking(|| self.read_file_blocking(tx)); - let mut size = 0 as usize; + let mut size = 0_usize; while let Some(chunk) = rx.recv().await { transaction.append(&table_name, chunk, false)?; size += 1; } handle.await??; - let handle = tokio::task::spawn_blocking(move || return_result(size.clone(), tx1)); + let handle = tokio::task::spawn_blocking(move || return_result(size, tx1)); while let Some(chunk) = rx1.recv().await { yield chunk; } @@ -135,30 +135,27 @@ mod tests { summary: ColumnSummary { id: Some(0), name: "a".to_string(), - table_name: None, }, nullable: false, - desc: ColumnDesc::new(LogicalType::Integer, true, false), + desc: ColumnDesc::new(LogicalType::Integer, true, false, None), ref_expr: None, }), Arc::new(ColumnCatalog { summary: ColumnSummary { id: Some(1), name: "b".to_string(), - table_name: None, }, nullable: false, - desc: ColumnDesc::new(LogicalType::Float, false, false), + desc: ColumnDesc::new(LogicalType::Float, false, false, None), ref_expr: None, }), Arc::new(ColumnCatalog { summary: ColumnSummary { id: Some(1), name: "c".to_string(), - table_name: None, }, nullable: false, - desc: ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false), + desc: ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false, None), ref_expr: None, }), ]; diff --git a/src/execution/executor/dml/delete.rs b/src/execution/executor/dml/delete.rs index a6b60540..9eb5937d 100644 --- a/src/execution/executor/dml/delete.rs +++ b/src/execution/executor/dml/delete.rs @@ -30,7 +30,7 @@ impl Delete { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] async fn _execute(self, transaction: &mut T) { let Delete { table_name, input } = self; - let option_index_metas = transaction.table(&table_name).map(|table_catalog| { + let option_index_metas = transaction.table(table_name.clone()).map(|table_catalog| { table_catalog .all_columns() .into_iter() diff --git a/src/execution/executor/dml/insert.rs b/src/execution/executor/dml/insert.rs index 34746825..cbf5ed92 100644 --- a/src/execution/executor/dml/insert.rs +++ b/src/execution/executor/dml/insert.rs @@ -52,7 +52,7 @@ impl Insert { let mut primary_key_index = None; let mut unique_values = HashMap::new(); - if let Some(table_catalog) = transaction.table(&table_name).cloned() { + if let Some(table_catalog) = transaction.table(table_name.clone()).cloned() { #[for_await] for tuple in input { let Tuple { @@ -83,12 +83,13 @@ impl Insert { for (col_id, col) in all_columns { let value = tuple_map .remove(col_id) + .or_else(|| col.default_value()) .unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))); if col.desc.is_unique && !value.is_null() { unique_values .entry(col.id()) - .or_insert_with(|| vec![]) + .or_insert_with(Vec::new) .push((tuple_id.clone(), value.clone())) } if value.is_null() && !col.nullable { diff --git a/src/execution/executor/dml/update.rs b/src/execution/executor/dml/update.rs index 411753c6..b5a93eb6 100644 --- a/src/execution/executor/dml/update.rs +++ b/src/execution/executor/dml/update.rs @@ -46,7 +46,7 @@ impl Update { values, } = self; - if let Some(table_catalog) = transaction.table(&table_name).cloned() { + if let Some(table_catalog) = transaction.table(table_name.clone()).cloned() { let mut value_map = HashMap::new(); // only once diff --git a/src/execution/executor/dql/aggregate/hash_agg.rs b/src/execution/executor/dql/aggregate/hash_agg.rs index 47ff3aea..be59f892 100644 --- a/src/execution/executor/dql/aggregate/hash_agg.rs +++ 
b/src/execution/executor/dql/aggregate/hash_agg.rs @@ -96,7 +96,7 @@ impl HashAggExecutor { let values: Vec = accs .iter() .map(|acc| acc.evaluate()) - .chain(group_keys.into_iter().map(|key| Ok(key))) + .chain(group_keys.into_iter().map(Ok)) .try_collect()?; yield Tuple { @@ -136,7 +136,7 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = KipStorage::new(temp_dir.path()).await.unwrap(); let transaction = RefCell::new(storage.transaction().await?); - let desc = ColumnDesc::new(LogicalType::Integer, false, false); + let desc = ColumnDesc::new(LogicalType::Integer, false, false, None); let t1_columns = vec![ Arc::new(ColumnCatalog::new( diff --git a/src/execution/executor/dql/aggregate/min_max.rs b/src/execution/executor/dql/aggregate/min_max.rs index 8c45cc82..dae6f725 100644 --- a/src/execution/executor/dql/aggregate/min_max.rs +++ b/src/execution/executor/dql/aggregate/min_max.rs @@ -32,8 +32,7 @@ impl Accumulator for MinMaxAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), ExecutorError> { if !value.is_null() { if let Some(inner_value) = &self.inner { - if let DataValue::Boolean(Some(result)) = binary_op(&inner_value, value, &self.op)? - { + if let DataValue::Boolean(Some(result)) = binary_op(inner_value, value, &self.op)? { result } else { unreachable!() diff --git a/src/execution/executor/dql/aggregate/sum.rs b/src/execution/executor/dql/aggregate/sum.rs index ef3ca7b5..46d84e8c 100644 --- a/src/execution/executor/dql/aggregate/sum.rs +++ b/src/execution/executor/dql/aggregate/sum.rs @@ -17,7 +17,7 @@ impl SumAccumulator { assert!(ty.is_numeric()); Self { - result: DataValue::init(&ty), + result: DataValue::init(ty), } } } diff --git a/src/execution/executor/dql/index_scan.rs b/src/execution/executor/dql/index_scan.rs index 97d00963..4bea39d8 100644 --- a/src/execution/executor/dql/index_scan.rs +++ b/src/execution/executor/dql/index_scan.rs @@ -35,7 +35,7 @@ impl IndexScan { } = self.op; let (index_meta, binaries) = index_by.ok_or(TypeError::InvalidType)?; let mut iter = - transaction.read_by_index(&table_name, limit, columns, index_meta, binaries)?; + transaction.read_by_index(table_name, limit, columns, index_meta, binaries)?; while let Some(tuple) = iter.next_tuple()? { yield tuple; diff --git a/src/execution/executor/dql/join/hash_join.rs b/src/execution/executor/dql/join/hash_join.rs index 1676b558..8f785624 100644 --- a/src/execution/executor/dql/join/hash_join.rs +++ b/src/execution/executor/dql/join/hash_join.rs @@ -273,7 +273,7 @@ mod test { BoxedExecutor, BoxedExecutor, ) { - let desc = ColumnDesc::new(LogicalType::Integer, false, false); + let desc = ColumnDesc::new(LogicalType::Integer, false, false, None); let t1_columns = vec![ Arc::new(ColumnCatalog::new( diff --git a/src/execution/executor/dql/seq_scan.rs b/src/execution/executor/dql/seq_scan.rs index 00684d6b..27b798fb 100644 --- a/src/execution/executor/dql/seq_scan.rs +++ b/src/execution/executor/dql/seq_scan.rs @@ -31,7 +31,7 @@ impl SeqScan { limit, .. } = self.op; - let mut iter = transaction.read(&table_name, limit, columns)?; + let mut iter = transaction.read(table_name, limit, columns)?; while let Some(tuple) = iter.next_tuple()? 
{ yield tuple; diff --git a/src/execution/executor/mod.rs b/src/execution/executor/mod.rs index fdc22b65..9877e91c 100644 --- a/src/execution/executor/mod.rs +++ b/src/execution/executor/mod.rs @@ -31,6 +31,8 @@ use futures::stream::BoxStream; use futures::TryStreamExt; use std::cell::RefCell; +use self::ddl::alter_table::AlterTable; + pub type BoxedExecutor = BoxStream<'static, Result>; pub trait Executor { @@ -104,6 +106,10 @@ pub fn build(plan: LogicalPlan, transaction: &RefCell) -> Box Delete::from((op, input)).execute(transaction) } Operator::Values(op) => Values::from(op).execute(transaction), + Operator::AlterTable(op) => { + let input = build(childrens.remove(0), transaction); + AlterTable::from((op, input)).execute(transaction) + } Operator::CreateTable(op) => CreateTable::from(op).execute(transaction), Operator::DropTable(op) => DropTable::from(op).execute(transaction), Operator::Truncate(op) => Truncate::from(op).execute(transaction), diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 8b9201e8..5f87b63d 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -13,21 +13,21 @@ lazy_static! { impl ScalarExpression { pub fn eval(&self, tuple: &Tuple) -> Result { - if let Some(value) = Self::eval_with_name(&tuple, self.output_columns().name()) { + if let Some(value) = Self::eval_with_name(tuple, self.output_columns().name()) { return Ok(value.clone()); } match &self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { - let value = Self::eval_with_name(&tuple, col.name()) + let value = Self::eval_with_name(tuple, col.name()) .unwrap_or(&NULL_VALUE) .clone(); Ok(value) } ScalarExpression::Alias { expr, alias } => { - if let Some(value) = Self::eval_with_name(&tuple, alias) { + if let Some(value) = Self::eval_with_name(tuple, alias) { return Ok(value.clone()); } @@ -80,7 +80,7 @@ impl ScalarExpression { Ok(Arc::new(unary_op(&value, op)?)) } ScalarExpression::AggCall { .. } => { - let value = Self::eval_with_name(&tuple, self.output_columns().name()) + let value = Self::eval_with_name(tuple, self.output_columns().name()) .unwrap_or(&NULL_VALUE) .clone(); diff --git a/src/expression/mod.rs b/src/expression/mod.rs index b3d08678..5922af94 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -4,12 +4,10 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use crate::binder::BinderContext; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use self::agg::AggKind; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; -use crate::storage::Transaction; use crate::types::value::ValueRef; use crate::types::LogicalType; @@ -110,19 +108,19 @@ impl ScalarExpression { pub fn return_type(&self) -> LogicalType { match self { Self::Constant(v) => v.logical_type(), - Self::ColumnRef(col) => col.datatype().clone(), + Self::ColumnRef(col) => *col.datatype(), Self::Binary { ty: return_type, .. - } => return_type.clone(), + } => *return_type, Self::Unary { ty: return_type, .. - } => return_type.clone(), + } => *return_type, Self::TypeCast { ty: return_type, .. - } => return_type.clone(), + } => *return_type, Self::AggCall { ty: return_type, .. - } => return_type.clone(), + } => *return_type, Self::IsNull { .. } | Self::In { .. } => LogicalType::Boolean, Self::Alias { expr, .. 
} => expr.return_type(), } @@ -142,18 +140,14 @@ impl ScalarExpression { ScalarExpression::ColumnRef(col) => { vec.push(col.clone()); } - ScalarExpression::Alias { expr, .. } => { - columns_collect(&expr, vec, only_column_ref) - } + ScalarExpression::Alias { expr, .. } => columns_collect(expr, vec, only_column_ref), ScalarExpression::TypeCast { expr, .. } => { - columns_collect(&expr, vec, only_column_ref) + columns_collect(expr, vec, only_column_ref) } ScalarExpression::IsNull { expr, .. } => { - columns_collect(&expr, vec, only_column_ref) - } - ScalarExpression::Unary { expr, .. } => { - columns_collect(&expr, vec, only_column_ref) + columns_collect(expr, vec, only_column_ref) } + ScalarExpression::Unary { expr, .. } => columns_collect(expr, vec, only_column_ref), ScalarExpression::Binary { left_expr, right_expr, @@ -183,22 +177,22 @@ impl ScalarExpression { exprs } - pub fn has_agg_call(&self, context: &BinderContext<'_, T>) -> bool { + pub fn has_agg_call(&self) -> bool { match self { ScalarExpression::AggCall { .. } => true, ScalarExpression::Constant(_) => false, ScalarExpression::ColumnRef(_) => false, - ScalarExpression::Alias { expr, .. } => expr.has_agg_call(context), - ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(context), - ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(context), - ScalarExpression::Unary { expr, .. } => expr.has_agg_call(context), + ScalarExpression::Alias { expr, .. } => expr.has_agg_call(), + ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(), + ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(), + ScalarExpression::Unary { expr, .. } => expr.has_agg_call(), ScalarExpression::Binary { left_expr, right_expr, .. - } => left_expr.has_agg_call(context) || right_expr.has_agg_call(context), + } => left_expr.has_agg_call() || right_expr.has_agg_call(), ScalarExpression::In { expr, args, .. 
} => { - expr.has_agg_call(context) || args.iter().any(|arg| arg.has_agg_call(context)) + expr.has_agg_call() || args.iter().any(|arg| arg.has_agg_call()) } } } @@ -209,13 +203,13 @@ impl ScalarExpression { ScalarExpression::Constant(value) => Arc::new(ColumnCatalog::new( format!("{}", value), true, - ColumnDesc::new(value.logical_type(), false, false), + ColumnDesc::new(value.logical_type(), false, false, None), Some(self.clone()), )), ScalarExpression::Alias { expr, alias } => Arc::new(ColumnCatalog::new( alias.to_string(), true, - ColumnDesc::new(expr.return_type(), false, false), + ColumnDesc::new(expr.return_type(), false, false, None), Some(self.clone()), )), ScalarExpression::AggCall { @@ -245,7 +239,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false, false), + ColumnDesc::new(*ty, false, false, None), Some(self.clone()), )) } @@ -265,7 +259,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false, false), + ColumnDesc::new(*ty, false, false, None), Some(self.clone()), )) } @@ -274,7 +268,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false, false), + ColumnDesc::new(*ty, false, false, None), Some(self.clone()), )) } @@ -283,7 +277,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( format!("{} {}", expr.output_columns().name(), suffix), true, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), Some(self.clone()), )) } @@ -305,13 +299,16 @@ impl ScalarExpression { args_string ), true, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), Some(self.clone()), )) } - _ => { - todo!() - } + ScalarExpression::TypeCast { expr, ty } => Arc::new(ColumnCatalog::new( + format!("CAST({} as {})", expr.output_columns().name(), ty), + true, + ColumnDesc::new(*ty, false, false, None), + Some(self.clone()), + )), } } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 53537f01..2e045bc5 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -52,6 +52,10 @@ impl ConstantBinary { pub fn rearrange(self) -> Result, TypeError> { match self { ConstantBinary::Or(binaries) => { + if binaries.is_empty() { + return Ok(vec![]); + } + let mut condition_binaries = Vec::new(); for binary in binaries { @@ -128,12 +132,32 @@ impl ConstantBinary { pub fn scope_aggregation(&mut self) -> Result<(), TypeError> { match self { + // `Or` is allowed to contain And, `Scope`, `Eq/NotEq` + // Tips: Only single-level `And` ConstantBinary::Or(binaries) => { + let mut or_binaries = Vec::new(); for binary in binaries { - binary.scope_aggregation()? 
+ match binary { + ConstantBinary::And(and_binaries) => { + or_binaries.append(&mut Self::and_scope_aggregation(and_binaries)?); + } + ConstantBinary::Or(_) => { + unreachable!("`Or` does not allow nested `Or`") + } + cb => { + or_binaries.push(cb.clone()); + } + } } + let or_binaries = Self::or_scope_aggregation(&or_binaries); + let _ = mem::replace(self, ConstantBinary::Or(or_binaries)); } - binary => binary._scope_aggregation()?, + // `And` is allowed to contain Scope, `Eq/NotEq` + ConstantBinary::And(binaries) => { + let and_binaries = Self::and_scope_aggregation(binaries)?; + let _ = mem::replace(self, ConstantBinary::And(and_binaries)); + } + _ => (), } Ok(()) @@ -168,76 +192,186 @@ impl ConstantBinary { } // Tips: It only makes sense if the condition is and aggregation - fn _scope_aggregation(&mut self) -> Result<(), TypeError> { - if let ConstantBinary::And(binaries) = self { - let mut scope_min = Bound::Unbounded; - let mut scope_max = Bound::Unbounded; - let mut eq_set = HashSet::with_hasher(RandomState::new()); - - let sort_op = |binary: &&ConstantBinary| match binary { - ConstantBinary::Scope { .. } => 3, - ConstantBinary::NotEq(_) => 2, - ConstantBinary::Eq(_) => 1, - ConstantBinary::And(_) | ConstantBinary::Or(_) => 0, - }; + fn and_scope_aggregation( + binaries: &Vec, + ) -> Result, TypeError> { + if binaries.is_empty() { + return Ok(vec![]); + } - // Aggregate various ranges to get the minimum range - for binary in binaries.iter().sorted_by_key(sort_op) { - match binary { - ConstantBinary::Scope { min, max } => { - // Skip if eq or noteq exists - if !eq_set.is_empty() { - continue; - } + let mut scope_min = Bound::Unbounded; + let mut scope_max = Bound::Unbounded; + let mut eq_set = HashSet::with_hasher(RandomState::new()); - if let Some(order) = Self::bound_compared(&scope_min, &min, true) { - if order.is_lt() { - scope_min = min.clone(); - } - } + let sort_op = |binary: &&ConstantBinary| match binary { + ConstantBinary::Scope { .. 
} => 3, + ConstantBinary::NotEq(_) => 2, + ConstantBinary::Eq(_) => 1, + ConstantBinary::And(_) | ConstantBinary::Or(_) => 0, + }; - if let Some(order) = Self::bound_compared(&scope_max, &max, false) { - if order.is_gt() { - scope_max = max.clone(); - } - } - } - ConstantBinary::Eq(val) => { - let _ = eq_set.insert(val.clone()); + // Aggregate various ranges to get the minimum range + for binary in binaries.iter().sorted_by_key(sort_op) { + match binary { + ConstantBinary::Scope { min, max } => { + // Skip if eq or noteq exists + if !eq_set.is_empty() { + continue; } - ConstantBinary::NotEq(val) => { - let _ = eq_set.remove(val); + + if let Some(order) = Self::bound_compared(&scope_min, min, true) { + if order.is_lt() { + scope_min = min.clone(); + } } - ConstantBinary::Or(_) | ConstantBinary::And(_) => { - return Err(TypeError::InvalidType) + + if let Some(order) = Self::bound_compared(&scope_max, max, false) { + if order.is_gt() { + scope_max = max.clone(); + } } } + ConstantBinary::Eq(val) => { + let _ = eq_set.insert(val.clone()); + } + ConstantBinary::NotEq(val) => { + let _ = eq_set.remove(val); + } + ConstantBinary::Or(_) | ConstantBinary::And(_) => { + return Err(TypeError::InvalidType) + } } + } - let eq_option = eq_set - .into_iter() - .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) - .next() - .map(|val| ConstantBinary::Eq(val)); - - if let Some(eq) = eq_option { - let _ = mem::replace(self, eq); - } else if !matches!( - (&scope_min, &scope_max), - (Bound::Unbounded, Bound::Unbounded) - ) { - let scope_binary = ConstantBinary::Scope { - min: scope_min, - max: scope_max, - }; + let eq_option = eq_set + .into_iter() + .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .next() + .map(ConstantBinary::Eq); + + return if let Some(eq) = eq_option { + Ok(vec![eq]) + } else if !matches!( + (&scope_min, &scope_max), + (Bound::Unbounded, Bound::Unbounded) + ) { + let scope_binary = ConstantBinary::Scope { + min: scope_min, + max: scope_max, + }; - let _ = mem::replace(self, scope_binary); - } else { - let _ = mem::replace(self, ConstantBinary::And(vec![])); + Ok(vec![scope_binary]) + } else { + Ok(vec![]) + }; + } + + // Tips: It only makes sense if the condition is or aggregation + fn or_scope_aggregation(binaries: &Vec) -> Vec { + if binaries.is_empty() { + return vec![]; + } + let mut scopes = Vec::new(); + let mut eqs = Vec::new(); + + let mut scope_margin = None; + + for binary in binaries { + if matches!(scope_margin, Some((Bound::Unbounded, Bound::Unbounded))) { + break; } + match binary { + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded, + } => { + scope_margin = Some((Bound::Unbounded, Bound::Unbounded)); + break; + } + ConstantBinary::Scope { min, max } => { + if let Some((scope_min, scope_max)) = &mut scope_margin { + if matches!( + Self::bound_compared(scope_min, min, true).map(Ordering::is_gt), + Some(true) + ) { + let _ = mem::replace(scope_min, min.clone()); + } + if matches!( + Self::bound_compared(scope_max, max, false).map(Ordering::is_lt), + Some(true) + ) { + let _ = mem::replace(scope_max, max.clone()); + } + } else { + scope_margin = Some((min.clone(), max.clone())) + } + + scopes.push((min, max)) + } + ConstantBinary::Eq(val) => eqs.push(val), + _ => (), + } + } + if matches!( + scope_margin, + Some((Bound::Unbounded, Bound::Unbounded)) | None + ) { + return vec![]; } - Ok(()) + let mut merge_scopes: Vec<(Bound, Bound)> = Vec::new(); + + match scope_margin { + Some((Bound::Unbounded, _)) => { + if let Some((_, 
max)) = scopes.iter().max_by(|(_, max_a), (_, max_b)| { + Self::bound_compared(max_a, max_b, false).unwrap() + }) { + merge_scopes.push((Bound::Unbounded, (**max).clone())) + } + } + Some((_, Bound::Unbounded)) => { + if let Some((min, _)) = scopes.iter().min_by(|(min_a, _), (min_b, _)| { + Self::bound_compared(min_a, min_b, true).unwrap() + }) { + merge_scopes.push(((**min).clone(), Bound::Unbounded)) + } + } + _ => { + scopes.sort_by(|(min_a, _), (min_b, _)| { + Self::bound_compared(min_a, min_b, true).unwrap() + }); + + for i in 0..scopes.len() { + let (min, max) = scopes[i]; + if merge_scopes.is_empty() { + merge_scopes.push((min.clone(), max.clone())); + continue; + } + + let last_pos = merge_scopes.len() - 1; + let last_scope: &mut _ = &mut merge_scopes[last_pos]; + if Self::bound_compared(&last_scope.0, min, true) + .unwrap() + .is_gt() + { + merge_scopes.push((min.clone(), max.clone())); + } else if Self::bound_compared(&last_scope.1, max, false) + .unwrap() + .is_lt() + { + last_scope.1 = max.clone(); + } + } + } + } + merge_scopes + .into_iter() + .map(|(min, max)| ConstantBinary::Scope { + min: min.clone(), + max: max.clone(), + }) + .chain(eqs.into_iter().map(|val| ConstantBinary::Eq(val.clone()))) + .collect_vec() } } @@ -466,7 +600,7 @@ impl ScalarExpression { let new_expr = ScalarExpression::Constant(Arc::new(unary_op(&val, op)?)); let _ = mem::replace(self, new_expr); } else { - let _ = replaces.push(Replace::Unary(ReplaceUnary { + replaces.push(Replace::Unary(ReplaceUnary { child_expr: expr.as_ref().clone(), op: *op, ty: *ty, @@ -624,12 +758,29 @@ impl ScalarExpression { right_expr.convert_binary(col_id)?, ) { (Some(left_binary), Some(right_binary)) => match (left_binary, right_binary) { - (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) - | (ConstantBinary::Or(mut left), ConstantBinary::Or(mut right)) => { - left.append(&mut right); + (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) => match op + { + BinaryOperator::And => { + left.append(&mut right); - Ok(Some(ConstantBinary::And(left))) - } + Ok(Some(ConstantBinary::And(left))) + } + BinaryOperator::Or => Ok(Some(ConstantBinary::Or(vec![ + ConstantBinary::And(left), + ConstantBinary::And(right), + ]))), + BinaryOperator::Xor => todo!(), + _ => unreachable!(), + }, + (ConstantBinary::Or(mut left), ConstantBinary::Or(mut right)) => match op { + BinaryOperator::And | BinaryOperator::Or => { + left.append(&mut right); + + Ok(Some(ConstantBinary::Or(left))) + } + BinaryOperator::Xor => todo!(), + _ => unreachable!(), + }, (ConstantBinary::And(mut left), ConstantBinary::Or(mut right)) => { right.append(&mut left); @@ -671,7 +822,7 @@ impl ScalarExpression { return Ok(Self::new_binary(col_id, *op, col, val, true)); } - return Ok(None); + Ok(None) } (Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)), (None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)), @@ -689,7 +840,7 @@ impl ScalarExpression { /// this case it makes no sense to just extract c1 > 1 fn check_or( col_id: &ColumnId, - right_expr: &Box, + right_expr: &ScalarExpression, op: &BinaryOperator, binary: ConstantBinary, ) -> Option { @@ -762,13 +913,13 @@ mod test { summary: ColumnSummary { id: Some(0), name: "c1".to_string(), - table_name: None, }, nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, is_primary: false, is_unique: false, + default: None, }, ref_expr: None, }); @@ -883,7 +1034,7 @@ mod test { binary.scope_aggregation()?; - assert_eq!(binary, 
ConstantBinary::Eq(val_0)); + assert_eq!(binary, ConstantBinary::And(vec![ConstantBinary::Eq(val_0)])); Ok(()) } @@ -947,10 +1098,10 @@ mod test { assert_eq!( binary, - ConstantBinary::Scope { + ConstantBinary::And(vec![ConstantBinary::Scope { min: Bound::Excluded(val_1.clone()), max: Bound::Excluded(val_2.clone()), - } + }]) ); Ok(()) @@ -991,7 +1142,156 @@ mod test { binary.scope_aggregation()?; - assert_eq!(binary, ConstantBinary::Eq(val_0.clone())); + assert_eq!( + binary, + ConstantBinary::And(vec![ConstantBinary::Eq(val_0.clone())]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_3.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_3.clone()), + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Included(val_3.clone()), + }]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_unbounded() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_3.clone()), + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Unbounded, + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!(binary, ConstantBinary::Or(vec![])); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_lower_unbounded() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(2))); + let val_1 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_0.clone()), + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_0.clone()), + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_1.clone()), + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_1.clone()), + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_1.clone()), + }]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_upper_unbounded() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(2))); + let val_1 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: 
Bound::Included(val_0.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Unbounded, + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Unbounded, + }]) + ); Ok(()) } diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index a47d62ec..ad124318 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -120,7 +120,7 @@ pub fn binary_op( let pattern_option = unpack_utf8(right.clone().cast(&LogicalType::Varchar(None))?); let mut is_match = if let (Some(value), Some(pattern)) = (value_option, pattern_option) { - let regex_pattern = pattern.replace("%", ".*").replace("_", "."); + let regex_pattern = pattern.replace('%', ".*").replace('_', "."); Regex::new(®ex_pattern).unwrap().is_match(&value) } else { diff --git a/src/marco/mod.rs b/src/marco/mod.rs index 1f6ab659..74ff6be4 100644 --- a/src/marco/mod.rs +++ b/src/marco/mod.rs @@ -66,13 +66,13 @@ mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false), + ColumnDesc::new(LogicalType::Integer, true, false, None), None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Varchar(None), false, false), + ColumnDesc::new(LogicalType::Varchar(None), false, false, None), None, )), ]; diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 16ca4ef4..8b74e0a2 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -68,17 +68,16 @@ impl HepGraph { new_node: Operator, ) { let new_index = self.graph.add_node(new_node); - let mut order = self.graph.edges(source_id).count(); - if let Some(children_id) = children_option { + if let Some((children_id, old_edge_id)) = children_option.and_then(|children_id| { self.graph .find_edge(source_id, children_id) - .map(|old_edge_id| { - order = self.graph.remove_edge(old_edge_id).unwrap_or(0); + .map(|old_edge_id| (children_id, old_edge_id)) + }) { + order = self.graph.remove_edge(old_edge_id).unwrap_or(0); - self.graph.add_edge(new_index, children_id, 0); - }); + self.graph.add_edge(new_index, children_id, 0); } self.graph.add_edge(source_id, new_index, order); diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 2911a7c7..56637ec5 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -23,7 +23,7 @@ impl PatternMatcher for HepMatcher<'_, '_> { fn match_opt_expr(&self) -> bool { let op = self.graph.operator(self.start_id); // check the root node predicate - if !(self.pattern.predicate)(&op) { + if !(self.pattern.predicate)(op) { return false; } @@ -34,7 +34,7 @@ impl PatternMatcher for HepMatcher<'_, '_> { .graph .nodes_iter(HepMatchOrder::TopDown, Some(self.start_id)) { - if !(self.pattern.predicate)(&self.graph.operator(node_id)) { + if !(self.pattern.predicate)(self.graph.operator(node_id)) { return false; } } diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs index 86984530..c27516b4 100644 --- a/src/optimizer/rule/column_pruning.rs +++ b/src/optimizer/rule/column_pruning.rs @@ -115,7 +115,8 @@ impl ColumnPruning { | Operator::Truncate(_) | Operator::Show(_) | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => 
(), + | Operator::CopyToFile(_) + | Operator::AlterTable(_) => (), } } diff --git a/src/optimizer/rule/pushdown_limit.rs b/src/optimizer/rule/pushdown_limit.rs index 2884b84f..44a29a87 100644 --- a/src/optimizer/rule/pushdown_limit.rs +++ b/src/optimizer/rule/pushdown_limit.rs @@ -75,7 +75,7 @@ impl Rule for EliminateLimits { let child_id = graph.children_at(node_id)[0]; if let Operator::Limit(child_op) = graph.operator(child_id) { let offset = Self::binary_options(op.offset, child_op.offset, |a, b| a + b); - let limit = Self::binary_options(op.limit, child_op.limit, |a, b| cmp::min(a, b)); + let limit = Self::binary_options(op.limit, child_op.limit, cmp::min); let new_limit_op = LimitOperator { offset, limit }; diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs index ac0f1e1a..b69c23ec 100644 --- a/src/optimizer/rule/pushdown_predicates.rs +++ b/src/optimizer/rule/pushdown_predicates.rs @@ -217,16 +217,17 @@ impl Rule for PushPredicateIntoScan { binary.scope_aggregation()?; let rearrange_binaries = binary.rearrange()?; - if !rearrange_binaries.is_empty() { - let mut scan_by_index = child_op.clone(); - scan_by_index.index_by = Some((meta.clone(), rearrange_binaries)); + if rearrange_binaries.is_empty() { + continue; + } + let mut scan_by_index = child_op.clone(); + scan_by_index.index_by = Some((meta.clone(), rearrange_binaries)); - // The constant expression extracted in prewhere is used to - // reduce the data scanning range and cannot replace the role of Filter. - graph.replace_node(child_id, Operator::Scan(scan_by_index)); + // The constant expression extracted in prewhere is used to + // reduce the data scanning range and cannot replace the role of Filter. + graph.replace_node(child_id, Operator::Scan(scan_by_index)); - return Ok(()); - } + return Ok(()); } } } diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index 3f004451..68aadadd 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -249,13 +249,13 @@ mod test { summary: ColumnSummary { id: Some(0), name: "c1".to_string(), - table_name: Some(Arc::new("t1".to_string())), }, nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, is_primary: true, is_unique: false, + default: None, }, ref_expr: None, }; @@ -263,13 +263,13 @@ mod test { summary: ColumnSummary { id: Some(1), name: "c2".to_string(), - table_name: Some(Arc::new("t1".to_string())), }, nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, is_primary: false, is_unique: true, + default: None, }, ref_expr: None, }; diff --git a/src/planner/operator/aggregate.rs b/src/planner/operator/aggregate.rs index acdeeb19..8d5973bf 100644 --- a/src/planner/operator/aggregate.rs +++ b/src/planner/operator/aggregate.rs @@ -8,7 +8,7 @@ pub struct AggregateOperator { } impl AggregateOperator { - pub fn new( + pub fn build( children: LogicalPlan, agg_calls: Vec, groupby_exprs: Vec, diff --git a/src/planner/operator/alter_table.rs b/src/planner/operator/alter_table.rs new file mode 100644 index 00000000..143d3072 --- /dev/null +++ b/src/planner/operator/alter_table.rs @@ -0,0 +1,19 @@ +use crate::catalog::{ColumnCatalog, TableName}; + +#[derive(Debug, PartialEq, Clone)] +pub enum AlterTableOperator { + AddColumn(AddColumn), + DropColumn, + DropPrimaryKey, + RenameColumn, + RenameTable, + ChangeColumn, + AlterColumn, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct AddColumn { + pub table_name: TableName, + pub 
if_not_exists: bool, + pub column: ColumnCatalog, +} diff --git a/src/planner/operator/create_table.rs b/src/planner/operator/create_table.rs index f9966431..a5d07eb3 100644 --- a/src/planner/operator/create_table.rs +++ b/src/planner/operator/create_table.rs @@ -6,4 +6,5 @@ pub struct CreateTableOperator { pub table_name: TableName, /// List of columns of the table pub columns: Vec, + pub if_not_exists: bool, } diff --git a/src/planner/operator/drop_table.rs b/src/planner/operator/drop_table.rs index b343a457..5d7b022b 100644 --- a/src/planner/operator/drop_table.rs +++ b/src/planner/operator/drop_table.rs @@ -4,4 +4,5 @@ use crate::catalog::TableName; pub struct DropTableOperator { /// Table name to insert to pub table_name: TableName, + pub if_exists: bool, } diff --git a/src/planner/operator/filter.rs b/src/planner/operator/filter.rs index 6a1f5a12..fc25ad98 100644 --- a/src/planner/operator/filter.rs +++ b/src/planner/operator/filter.rs @@ -12,7 +12,7 @@ pub struct FilterOperator { } impl FilterOperator { - pub fn new(predicate: ScalarExpression, children: LogicalPlan, having: bool) -> LogicalPlan { + pub fn build(predicate: ScalarExpression, children: LogicalPlan, having: bool) -> LogicalPlan { LogicalPlan { operator: Operator::Filter(FilterOperator { predicate, having }), childrens: vec![children], diff --git a/src/planner/operator/join.rs b/src/planner/operator/join.rs index 51e9a5ac..742d978b 100644 --- a/src/planner/operator/join.rs +++ b/src/planner/operator/join.rs @@ -29,7 +29,7 @@ pub struct JoinOperator { } impl JoinOperator { - pub fn new( + pub fn build( left: LogicalPlan, right: LogicalPlan, on: JoinCondition, diff --git a/src/planner/operator/limit.rs b/src/planner/operator/limit.rs index f6e79690..12280f33 100644 --- a/src/planner/operator/limit.rs +++ b/src/planner/operator/limit.rs @@ -9,7 +9,11 @@ pub struct LimitOperator { } impl LimitOperator { - pub fn new(offset: Option, limit: Option, children: LogicalPlan) -> LogicalPlan { + pub fn build( + offset: Option, + limit: Option, + children: LogicalPlan, + ) -> LogicalPlan { LogicalPlan { operator: Operator::Limit(LimitOperator { offset, limit }), childrens: vec![children], diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index e1d11aab..0747f184 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -1,4 +1,5 @@ pub mod aggregate; +pub mod alter_table; pub mod copy_from_file; pub mod copy_to_file; pub mod create_table; @@ -31,8 +32,9 @@ use crate::planner::operator::values::ValuesOperator; use itertools::Itertools; use self::{ - aggregate::AggregateOperator, filter::FilterOperator, join::JoinOperator, limit::LimitOperator, - project::ProjectOperator, scan::ScanOperator, sort::SortOperator, + aggregate::AggregateOperator, alter_table::AlterTableOperator, filter::FilterOperator, + join::JoinOperator, limit::LimitOperator, project::ProjectOperator, scan::ScanOperator, + sort::SortOperator, }; #[derive(Debug, PartialEq, Clone)] @@ -52,6 +54,7 @@ pub enum Operator { Update(UpdateOperator), Delete(DeleteOperator), // DDL + AlterTable(AlterTableOperator), CreateTable(CreateTableOperator), DropTable(DropTableOperator), Truncate(TruncateOperator), diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index a16ddbc7..f11d9cc9 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -23,12 +23,12 @@ pub struct ScanOperator { pub index_by: Option<(IndexMetaRef, Vec)>, } impl ScanOperator { - pub fn new(table_name: TableName, table_catalog: 
&TableCatalog) -> LogicalPlan { + pub fn build(table_name: TableName, table_catalog: &TableCatalog) -> LogicalPlan { // Fill all Columns in TableCatalog by default let columns = table_catalog .all_columns() .into_iter() - .map(|col| ScalarExpression::ColumnRef(col)) + .map(ScalarExpression::ColumnRef) .collect_vec(); LogicalPlan { diff --git a/src/storage/kip.rs b/src/storage/kip.rs index f80064e3..538071c7 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -1,13 +1,12 @@ use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; use crate::expression::simplify::ConstantBinary; +use crate::planner::operator::alter_table::{AddColumn, AlterTableOperator}; use crate::storage::table_codec::TableCodec; use crate::storage::{ tuple_projection, Bounds, IndexIter, Iter, Projections, Storage, StorageError, Transaction, }; -use crate::types::errors::TypeError; use crate::types::index::{Index, IndexMeta, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; -use crate::types::value::ValueRef; use kip_db::kernel::lsm::iterator::Iter as KipDBIter; use kip_db::kernel::lsm::mvcc::TransactionIter; use kip_db::kernel::lsm::storage::Config; @@ -15,8 +14,6 @@ use kip_db::kernel::lsm::{mvcc, storage}; use kip_db::kernel::utils::lru_cache::ShardingLruCache; use std::collections::hash_map::RandomState; use std::collections::{Bound, VecDeque}; -use std::mem; -use std::ops::SubAssign; use std::path::PathBuf; use std::sync::Arc; @@ -27,8 +24,9 @@ pub struct KipStorage { impl KipStorage { pub async fn new(path: impl Into + Send) -> Result { - let config = Config::new(path); - let storage = storage::KipStorage::open_with_config(config).await?; + let storage = + storage::KipStorage::open_with_config(Config::new(path).enable_level_0_memorization()) + .await?; Ok(KipStorage { inner: Arc::new(storage), @@ -59,15 +57,15 @@ impl Transaction for KipTransaction { fn read( &self, - table_name: &String, + table_name: TableName, bounds: Bounds, projections: Projections, ) -> Result, StorageError> { let all_columns = self - .table(table_name) + .table(table_name.clone()) .ok_or(StorageError::TableNotFound)? .all_columns(); - let (min, max) = TableCodec::tuple_bound(table_name); + let (min, max) = TableCodec::tuple_bound(&table_name); let iter = self.tx.iter(Bound::Included(&min), Bound::Included(&max))?; Ok(KipIter { @@ -81,73 +79,33 @@ impl Transaction for KipTransaction { fn read_by_index( &self, - table_name: &String, - (offset_option, mut limit_option): Bounds, + table_name: TableName, + (offset_option, limit_option): Bounds, projections: Projections, index_meta: IndexMetaRef, binaries: Vec, ) -> Result, StorageError> { - let table = self.table(table_name).ok_or(StorageError::TableNotFound)?; - let mut tuple_ids = Vec::new(); - let mut offset = offset_option.unwrap_or(0); - - for binary in binaries { - if matches!(limit_option, Some(0)) { - break; - } - - match binary { - ConstantBinary::Scope { min, max } => { - let mut iter = self.scope_to_iter(table_name, &index_meta, min, max)?; - - while let Some((_, value_option)) = iter.try_next()? { - if let Some(value) = value_option { - for id in TableCodec::decode_index(&value)? 
{ - if Self::offset_move(&mut offset) { - continue; - } - - tuple_ids.push(id); - - if Self::limit_move(&mut limit_option) { - break; - } - } - } - - if matches!(limit_option, Some(0)) { - break; - } - } - } - ConstantBinary::Eq(val) => { - if Self::offset_move(&mut offset) { - continue; - } - - let key = Self::val_to_key(table_name, &index_meta, val)?; - - if let Some(bytes) = self.tx.get(&key)? { - tuple_ids.append(&mut TableCodec::decode_index(&bytes)?) - } - - let _ = Self::limit_move(&mut limit_option); - } - _ => (), - } - } + let table = self + .table(table_name.clone()) + .ok_or(StorageError::TableNotFound)?; + let offset = offset_option.unwrap_or(0); Ok(IndexIter { + offset, + limit: limit_option, projections, + index_meta, table, - tuple_ids: VecDeque::from(tuple_ids), + index_values: VecDeque::new(), + binaries: VecDeque::from(binaries), tx: &self.tx, + scope_iter: None, }) } fn add_index( &mut self, - table_name: &String, + table_name: &str, index: Index, tuple_ids: Vec, is_unique: bool, @@ -173,8 +131,8 @@ impl Transaction for KipTransaction { Ok(()) } - fn del_index(&mut self, table_name: &String, index: &Index) -> Result<(), StorageError> { - let key = TableCodec::encode_index_key(table_name, &index)?; + fn del_index(&mut self, table_name: &str, index: &Index) -> Result<(), StorageError> { + let key = TableCodec::encode_index_key(table_name, index)?; self.tx.remove(&key)?; @@ -183,7 +141,7 @@ impl Transaction for KipTransaction { fn append( &mut self, - table_name: &String, + table_name: &str, tuple: Tuple, is_overwrite: bool, ) -> Result<(), StorageError> { @@ -204,28 +162,101 @@ impl Transaction for KipTransaction { Ok(()) } + fn alter_table(&mut self, op: &AlterTableOperator) -> Result<(), StorageError> { + match op { + AlterTableOperator::AddColumn(AddColumn { + table_name, + if_not_exists, + column, + }) => { + // we need the catalog to generate the col_id && index_id + // the catalog is generally immutable, so there is no need to worry about it changing while the ALTER TABLE is in progress + if let Some(mut catalog) = self.table(table_name.clone()).cloned() { + // default values are not supported yet + if !column.nullable { + return Err(StorageError::NeedNullAble); + } + + for col in catalog.all_columns() { + if col.name() == column.name() { + if *if_not_exists { + return Ok(()); + } else { + return Err(StorageError::DuplicateColumn); + } + } + } + + let col_id = catalog.add_column(column.clone())?; + + if column.desc.is_unique { + let meta = IndexMeta { + id: 0, + column_ids: vec![col_id], + name: format!("uk_{}", column.name()), + is_unique: true, + is_primary: false, + }; + let meta_ref = catalog.add_index_meta(meta); + let (key, value) = TableCodec::encode_index_meta(table_name, meta_ref)?; + self.tx.set(key, value); + } + + let column = catalog.get_column_by_id(&col_id).unwrap(); + let (key, value) = TableCodec::encode_column(&table_name, column)?; + self.tx.set(key, value); + + Ok(()) + } else { + return Err(StorageError::TableNotFound); + } + } + AlterTableOperator::DropColumn => todo!(), + AlterTableOperator::DropPrimaryKey => todo!(), + AlterTableOperator::RenameColumn => todo!(), + AlterTableOperator::RenameTable => todo!(), + AlterTableOperator::ChangeColumn => todo!(), + AlterTableOperator::AlterColumn => todo!(), + } + } + fn create_table( &mut self, table_name: TableName, columns: Vec, + if_not_exists: bool, ) -> Result { + let (table_key, value) = TableCodec::encode_root_table(&table_name)?; + if self.tx.get(&table_key)?.is_some() { + if if_not_exists { + return Ok(table_name); + } + return
Err(StorageError::TableExists); + } + self.tx.set(table_key, value); + let mut table_catalog = TableCatalog::new(table_name.clone(), columns)?; Self::create_index_meta_for_table(&mut self.tx, &mut table_catalog)?; - for (_, column) in &table_catalog.columns { - let (key, value) = TableCodec::encode_column(column)?; + for column in table_catalog.columns.values() { + let (key, value) = TableCodec::encode_column(&table_name, column)?; self.tx.set(key, value); } - let (table_key, value) = TableCodec::encode_root_table(&table_name)?; - self.tx.set(table_key, value); - self.cache.put(table_name.to_string(), table_catalog); Ok(table_name) } - fn drop_table(&mut self, table_name: &String) -> Result<(), StorageError> { + fn drop_table(&mut self, table_name: &str, if_exists: bool) -> Result<(), StorageError> { + if self.table(Arc::new(table_name.to_string())).is_none() { + if if_exists { + return Ok(()); + } else { + return Err(StorageError::TableNotFound); + } + } + self.drop_data(table_name)?; let (min, max) = TableCodec::columns_bound(table_name); @@ -245,12 +276,12 @@ impl Transaction for KipTransaction { self.tx .remove(&TableCodec::encode_root_table_key(table_name))?; - let _ = self.cache.remove(table_name); + let _ = self.cache.remove(&table_name.to_string()); Ok(()) } - fn drop_data(&mut self, table_name: &String) -> Result<(), StorageError> { + fn drop_data(&mut self, table_name: &str) -> Result<(), StorageError> { let (tuple_min, tuple_max) = TableCodec::tuple_bound(table_name); Self::_drop_data(&mut self.tx, &tuple_min, &tuple_max)?; @@ -260,17 +291,17 @@ impl Transaction for KipTransaction { Ok(()) } - fn table(&self, table_name: &String) -> Option<&TableCatalog> { - let mut option = self.cache.get(table_name); + fn table(&self, table_name: TableName) -> Option<&TableCatalog> { + let mut option = self.cache.get(&table_name); if option.is_none() { // TODO: unify the data into a `Meta` prefix and use one iteration to collect all data - let (columns, name_option) = Self::column_collect(table_name, &self.tx).ok()?; - let indexes = Self::index_meta_collect(table_name, &self.tx)?; + let columns = Self::column_collect(table_name.clone(), &self.tx).ok()?; + let indexes = Self::index_meta_collect(&table_name, &self.tx)?; - if let Some(catalog) = name_option.and_then(|table_name| { - TableCatalog::new_with_indexes(table_name, columns, indexes).ok() - }) { + if let Ok(catalog) = + TableCatalog::new_with_indexes(table_name.clone(), columns, indexes) + { option = self .cache .get_or_insert(table_name.to_string(), |_| Ok(catalog)) @@ -305,105 +336,26 @@ impl Transaction for KipTransaction { } impl KipTransaction { - fn val_to_key( - table_name: &String, - index_meta: &IndexMetaRef, - val: ValueRef, - ) -> Result, TypeError> { - let index = Index::new(index_meta.id, vec![val]); - - TableCodec::encode_index_key(table_name, &index) - } - - fn scope_to_iter( - &self, - table_name: &String, - index_meta: &IndexMetaRef, - min: Bound, - max: Bound, - ) -> Result { - let bound_encode = |bound: Bound| -> Result<_, StorageError> { - match bound { - Bound::Included(val) => Ok(Bound::Included(Self::val_to_key( - table_name, - &index_meta, - val, - )?)), - Bound::Excluded(val) => Ok(Bound::Excluded(Self::val_to_key( - table_name, - &index_meta, - val, - )?)), - Bound::Unbounded => Ok(Bound::Unbounded), - } - }; - let check_bound = |value: &mut Bound>, bound: Vec| { - if matches!(value, Bound::Unbounded) { - let _ = mem::replace(value, Bound::Included(bound)); - } - }; - let (bound_min, bound_max) = 
TableCodec::index_bound(table_name, &index_meta.id); - - let mut encode_min = bound_encode(min)?; - check_bound(&mut encode_min, bound_min); - - let mut encode_max = bound_encode(max)?; - check_bound(&mut encode_max, bound_max); - - Ok(self.tx.iter( - encode_min.as_ref().map(Vec::as_slice), - encode_max.as_ref().map(Vec::as_slice), - )?) - } - - fn offset_move(offset: &mut usize) -> bool { - if *offset > 0 { - offset.sub_assign(1); - - true - } else { - false - } - } - - fn limit_move(limit_option: &mut Option) -> bool { - if let Some(limit) = limit_option { - limit.sub_assign(1); - - return *limit == 0; - } - - false - } - fn column_collect( - name: &String, + table_name: TableName, tx: &mvcc::Transaction, - ) -> Result<(Vec, Option), StorageError> { - let (column_min, column_max) = TableCodec::columns_bound(name); + ) -> Result, StorageError> { + let (column_min, column_max) = TableCodec::columns_bound(&table_name); let mut column_iter = tx.iter(Bound::Included(&column_min), Bound::Included(&column_max))?; let mut columns = vec![]; - let mut name_option = None; while let Some((_, value_option)) = column_iter.try_next().ok().flatten() { if let Some(value) = value_option { - let (table_name, column) = TableCodec::decode_column(&value)?; - - if name != table_name.as_str() { - return Ok((vec![], None)); - } - let _ = name_option.insert(table_name); - - columns.push(column); + columns.push(TableCodec::decode_column(&value)?); } } - Ok((columns, name_option)) + Ok(columns) } - fn index_meta_collect(name: &String, tx: &mvcc::Transaction) -> Option> { + fn index_meta_collect(name: &str, tx: &mvcc::Transaction) -> Option> { let (index_min, index_max) = TableCodec::index_meta_bound(name); let mut index_metas = vec![]; let mut index_iter = tx @@ -412,7 +364,7 @@ impl KipTransaction { while let Some((_, value_option)) = index_iter.try_next().ok().flatten() { if let Some(value) = value_option { - if let Some(index_meta) = TableCodec::decode_index_meta(&value).ok() { + if let Ok(index_meta) = TableCodec::decode_index_meta(&value) { index_metas.push(Arc::new(index_meta)); } } @@ -422,7 +374,7 @@ impl KipTransaction { } fn _drop_data(tx: &mut mvcc::Transaction, min: &[u8], max: &[u8]) -> Result<(), StorageError> { - let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max))?; + let mut iter = tx.iter(Bound::Included(min), Bound::Included(max))?; let mut data_keys = vec![]; while let Some((key, value_option)) = iter.try_next()? 
{ @@ -448,14 +400,19 @@ impl KipTransaction { for col in table .all_columns() .into_iter() - .filter(|col| col.desc.is_unique) + .filter(|col| col.desc.is_index()) { + let is_primary = col.desc.is_primary; + // FIXME: composite indexes may exist in the future + let prefix = if is_primary { "pk" } else { "uk" }; + if let Some(col_id) = col.id() { let meta = IndexMeta { id: 0, column_ids: vec![col_id], - name: format!("uk_{}", col.name()), - is_unique: true, + name: format!("{}_{}", prefix, col.name()), + is_unique: col.desc.is_unique, + is_primary, }; let meta_ref = table.add_index_meta(meta); let (key, value) = TableCodec::encode_index_meta(&table_name, meta_ref)?; @@ -512,6 +469,7 @@ mod test { use crate::expression::ScalarExpression; use crate::storage::kip::KipStorage; use crate::storage::{IndexIter, Iter, Storage, StorageError, Transaction}; + use crate::types::index::IndexMeta; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -529,13 +487,13 @@ mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false), + ColumnDesc::new(LogicalType::Integer, true, false, None), None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), None, )), ]; @@ -544,9 +502,9 @@ mod test { .iter() .map(|col_ref| ColumnCatalog::clone(&col_ref)) .collect_vec(); - let _ = transaction.create_table(Arc::new("test".to_string()), source_columns)?; + let _ = transaction.create_table(Arc::new("test".to_string()), source_columns, false)?; - let table_catalog = transaction.table(&"test".to_string()); + let table_catalog = transaction.table(Arc::new("test".to_string())); assert!(table_catalog.is_some()); assert!(table_catalog .unwrap() @@ -579,7 +537,7 @@ mod test { )?; let mut iter = transaction.read( - &"test".to_string(), + Arc::new("test".to_string()), (Some(1), Some(1)), vec![ScalarExpression::ColumnRef(columns[0].clone())], )?; @@ -597,17 +555,20 @@ mod test { } #[tokio::test] - async fn test_index_iter() -> Result<(), DatabaseError> { + async fn test_index_iter_pk() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let kipsql = Database::with_kipdb(temp_dir.path()).await?; let _ = kipsql.run("create table t1 (a int primary key)").await?; let _ = kipsql - .run("insert into t1 (a) values (0), (1), (2)") + .run("insert into t1 (a) values (0), (1), (2), (3), (4)") .await?; let transaction = kipsql.storage.transaction().await?; - let table = transaction.table(&"t1".to_string()).unwrap().clone(); + let table = transaction + .table(Arc::new("t1".to_string())) + .unwrap() + .clone(); let projections = table .all_columns() .into_iter() @@ -615,14 +576,32 @@ mod test { .collect_vec(); let tuple_ids = vec![ Arc::new(DataValue::Int32(Some(0))), - Arc::new(DataValue::Int32(Some(1))), Arc::new(DataValue::Int32(Some(2))), + Arc::new(DataValue::Int32(Some(3))), + Arc::new(DataValue::Int32(Some(4))), ]; let mut iter = IndexIter { + offset: 0, + limit: None, projections, + index_meta: Arc::new(IndexMeta { + id: 0, + column_ids: vec![0], + name: "pk_a".to_string(), + is_unique: false, + is_primary: true, + }), table: &table, - tuple_ids: VecDeque::from(tuple_ids.clone()), + binaries: VecDeque::from(vec![ + ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(0)))), + ConstantBinary::Scope { + min:
Bound::Included(Arc::new(DataValue::Int32(Some(2)))), + max: Bound::Included(Arc::new(DataValue::Int32(Some(4)))), + }, + ]), + index_values: VecDeque::new(), tx: &transaction.tx, + scope_iter: None, }; let mut result = Vec::new(); @@ -647,7 +626,10 @@ mod test { .await?; let transaction = kipsql.storage.transaction().await.unwrap(); - let table = transaction.table(&"t1".to_string()).unwrap().clone(); + let table = transaction + .table(Arc::new("t1".to_string())) + .unwrap() + .clone(); let projections = table .all_columns() .into_iter() @@ -655,7 +637,7 @@ mod test { .collect_vec(); let mut iter = transaction .read_by_index( - &"t1".to_string(), + Arc::new("t1".to_string()), (Some(0), Some(1)), projections, table.indexes[0].clone(), diff --git a/src/storage/mod.rs b/src/storage/mod.rs index ac561151..9f8cf2b8 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -4,13 +4,17 @@ mod table_codec; use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog, TableName}; use crate::expression::simplify::ConstantBinary; use crate::expression::ScalarExpression; +use crate::planner::operator::alter_table::AlterTableOperator; use crate::storage::table_codec::TableCodec; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; +use crate::types::value::ValueRef; +use kip_db::kernel::lsm::iterator::Iter as DBIter; use kip_db::kernel::lsm::mvcc; use kip_db::KernelError; -use std::collections::VecDeque; +use std::collections::{Bound, VecDeque}; +use std::mem; use std::ops::SubAssign; pub trait Storage: Sync + Send + Clone + 'static { @@ -32,14 +36,14 @@ pub trait Transaction: Sync + Send + 'static { /// The projections is column indices. fn read( &self, - table_name: &String, + table_name: TableName, bounds: Bounds, projection: Projections, ) -> Result, StorageError>; fn read_by_index( &self, - table_name: &String, + table_name: TableName, bounds: Bounds, projection: Projections, index_meta: IndexMetaRef, @@ -48,32 +52,33 @@ pub trait Transaction: Sync + Send + 'static { fn add_index( &mut self, - table_name: &String, + table_name: &str, index: Index, tuple_ids: Vec, is_unique: bool, ) -> Result<(), StorageError>; - fn del_index(&mut self, table_name: &String, index: &Index) -> Result<(), StorageError>; + fn del_index(&mut self, table_name: &str, index: &Index) -> Result<(), StorageError>; fn append( &mut self, - table_name: &String, + table_name: &str, tuple: Tuple, is_overwrite: bool, ) -> Result<(), StorageError>; fn delete(&mut self, table_name: &String, tuple_id: TupleId) -> Result<(), StorageError>; - + fn alter_table(&mut self, op: &AlterTableOperator) -> Result<(), StorageError>; fn create_table( &mut self, table_name: TableName, columns: Vec, + if_not_exists: bool, ) -> Result; - fn drop_table(&mut self, table_name: &String) -> Result<(), StorageError>; - fn drop_data(&mut self, table_name: &String) -> Result<(), StorageError>; - fn table(&self, table_name: &String) -> Option<&TableCatalog>; + fn drop_table(&mut self, table_name: &str, if_exists: bool) -> Result<(), StorageError>; + fn drop_data(&mut self, table_name: &str) -> Result<(), StorageError>; + fn table(&self, table_name: TableName) -> Option<&TableCatalog>; fn show_tables(&self) -> Result, StorageError>; @@ -81,31 +86,182 @@ pub trait Transaction: Sync + Send + 'static { async fn commit(self) -> Result<(), StorageError>; } +enum IndexValue { + PrimaryKey(Tuple), + Normal(TupleId), +} + // TODO: Table return optimization pub struct IndexIter<'a> { + 
offset: usize, + limit: Option, projections: Projections, + + index_meta: IndexMetaRef, table: &'a TableCatalog, - tuple_ids: VecDeque, tx: &'a mvcc::Transaction, + + // for buffering data + index_values: VecDeque, + binaries: VecDeque, + scope_iter: Option>, +} + +impl IndexIter<'_> { + fn offset_move(offset: &mut usize) -> bool { + if *offset > 0 { + offset.sub_assign(1); + + true + } else { + false + } + } + + fn val_to_key(&self, val: ValueRef) -> Result, TypeError> { + if self.index_meta.is_unique { + let index = Index::new(self.index_meta.id, vec![val]); + + TableCodec::encode_index_key(&self.table.name, &index) + } else { + TableCodec::encode_tuple_key(&self.table.name, &val) + } + } + + fn get_tuple_by_id(&mut self, tuple_id: &TupleId) -> Result, StorageError> { + let key = TableCodec::encode_tuple_key(&self.table.name, &tuple_id)?; + + self.tx + .get(&key)? + .map(|bytes| { + let tuple = TableCodec::decode_tuple(self.table.all_columns(), &bytes); + + tuple_projection(&mut self.limit, &self.projections, tuple) + }) + .transpose() + } + + fn is_empty(&self) -> bool { + self.scope_iter.is_none() && self.index_values.is_empty() && self.binaries.is_empty() + } } impl Iter for IndexIter<'_> { fn next_tuple(&mut self) -> Result, StorageError> { - if let Some(tuple_id) = self.tuple_ids.pop_front() { - let key = TableCodec::encode_tuple_key(&self.table.name, &tuple_id)?; - - Ok(self - .tx - .get(&key)? - .map(|bytes| { - let tuple = TableCodec::decode_tuple(self.table.all_columns(), &bytes); - - tuple_projection(&mut None, &self.projections, tuple) - }) - .transpose()?) - } else { - Ok(None) + // 1. check limit + if matches!(self.limit, Some(0)) || self.is_empty() { + self.scope_iter = None; + self.binaries.clear(); + + return Ok(None); + } + // 2. try to take a tuple from index_values until it is empty + loop { + if let Some(value) = self.index_values.pop_front() { + if Self::offset_move(&mut self.offset) { + continue; + } + match value { + IndexValue::PrimaryKey(tuple) => { + let tuple = tuple_projection(&mut self.limit, &self.projections, tuple)?; + + return Ok(Some(tuple)); + } + IndexValue::Normal(tuple_id) => { + if let Some(tuple) = self.get_tuple_by_id(&tuple_id)? { + return Ok(Some(tuple)); + } + } + } + } else { + break; + } + } + assert!(self.index_values.is_empty()); + + // 3. If the current expression is a Scope, + // an iterator will be generated for reading the IndexValues of the Scope. + if let Some(iter) = &mut self.scope_iter { + let mut has_next = false; + while let Some((_, value_option)) = iter.try_next()? { + if let Some(value) = value_option { + if self.index_meta.is_primary { + let tuple = TableCodec::decode_tuple(self.table.all_columns(), &value); + + self.index_values.push_back(IndexValue::PrimaryKey(tuple)); + } else { + for tuple_id in TableCodec::decode_index(&value)? { + self.index_values.push_back(IndexValue::Normal(tuple_id)); + } + } + has_next = true; + break; + } + } + if !has_next { + self.scope_iter = None; + } + return self.next_tuple(); } + + // 4.
When neither `scope_iter` nor `index_values` has a value, evaluate the next binary expression + if let Some(binary) = self.binaries.pop_front() { + match binary { + ConstantBinary::Scope { min, max } => { + let table_name = &self.table.name; + let index_meta = &self.index_meta; + + let bound_encode = |bound: Bound| -> Result<_, StorageError> { + match bound { + Bound::Included(val) => Ok(Bound::Included(self.val_to_key(val)?)), + Bound::Excluded(val) => Ok(Bound::Excluded(self.val_to_key(val)?)), + Bound::Unbounded => Ok(Bound::Unbounded), + } + }; + let check_bound = |value: &mut Bound>, bound: Vec| { + if matches!(value, Bound::Unbounded) { + let _ = mem::replace(value, Bound::Included(bound)); + } + }; + let (bound_min, bound_max) = if index_meta.is_unique { + TableCodec::index_bound(table_name, &index_meta.id) + } else { + TableCodec::tuple_bound(table_name) + }; + + let mut encode_min = bound_encode(min)?; + check_bound(&mut encode_min, bound_min); + + let mut encode_max = bound_encode(max)?; + check_bound(&mut encode_max, bound_max); + + let iter = self.tx.iter( + encode_min.as_ref().map(Vec::as_slice), + encode_max.as_ref().map(Vec::as_slice), + )?; + self.scope_iter = Some(iter); + } + ConstantBinary::Eq(val) => { + let key = self.val_to_key(val)?; + if let Some(bytes) = self.tx.get(&key)? { + if self.index_meta.is_unique { + for tuple_id in TableCodec::decode_index(&bytes)? { + self.index_values.push_back(IndexValue::Normal(tuple_id)); + } + } else if self.index_meta.is_primary { + let tuple = TableCodec::decode_tuple(self.table.all_columns(), &bytes); + + self.index_values.push_back(IndexValue::PrimaryKey(tuple)); + } else { + todo!() + } + } + self.scope_iter = None; + } + _ => (), + } + } + self.next_tuple() } } @@ -157,6 +313,15 @@ pub enum StorageError { #[error("The table not found")] TableNotFound, + + #[error("The column already exists")] + DuplicateColumn, + + #[error("The added column must be nullable")] + NeedNullAble, + + #[error("The table already exists")] + TableExists, } impl From for StorageError { diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 483eab2e..6ce89422 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -1,4 +1,4 @@ -use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; +use crate::catalog::{ColumnCatalog, ColumnRef}; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexId, IndexMeta}; use crate::types::tuple::{Tuple, TupleId}; @@ -27,8 +27,8 @@ impl TableCodec { /// TableName + Type /// /// Tips: Root full key = key_prefix - fn key_prefix(ty: CodecType, table_name: &String) -> Vec { - let mut table_bytes = table_name.clone().into_bytes(); + fn key_prefix(ty: CodecType, table_name: &str) -> Vec { + let mut table_bytes = table_name.to_string().into_bytes(); match ty { CodecType::Column => { @@ -55,9 +55,9 @@ impl TableCodec { table_bytes } - pub fn tuple_bound(name: &String) -> (Vec, Vec) { + pub fn tuple_bound(table_name: &str) -> (Vec, Vec) { let op = |bound_id| { - let mut key_prefix = Self::key_prefix(CodecType::Tuple, name); + let mut key_prefix = Self::key_prefix(CodecType::Tuple, table_name); key_prefix.push(bound_id); key_prefix @@ -66,9 +66,9 @@ impl TableCodec { (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } - pub fn index_meta_bound(name: &String) -> (Vec, Vec) { + pub fn index_meta_bound(table_name: &str) -> (Vec, Vec) { let op = |bound_id| { - let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, name); + let mut key_prefix = Self::key_prefix(CodecType::IndexMeta,
table_name); key_prefix.push(bound_id); key_prefix @@ -77,9 +77,9 @@ impl TableCodec { (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } - pub fn index_bound(name: &String, index_id: &IndexId) -> (Vec, Vec) { + pub fn index_bound(table_name: &str, index_id: &IndexId) -> (Vec, Vec) { let op = |bound_id| { - let mut key_prefix = Self::key_prefix(CodecType::Index, name); + let mut key_prefix = Self::key_prefix(CodecType::Index, table_name); key_prefix.push(BOUND_MIN_TAG); key_prefix.append(&mut index_id.to_be_bytes().to_vec()); @@ -90,9 +90,9 @@ impl TableCodec { (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } - pub fn all_index_bound(name: &String) -> (Vec, Vec) { + pub fn all_index_bound(table_name: &str) -> (Vec, Vec) { let op = |bound_id| { - let mut key_prefix = Self::key_prefix(CodecType::Index, name); + let mut key_prefix = Self::key_prefix(CodecType::Index, table_name); key_prefix.push(bound_id); key_prefix @@ -112,9 +112,9 @@ impl TableCodec { (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } - pub fn columns_bound(name: &String) -> (Vec, Vec) { + pub fn columns_bound(table_name: &str) -> (Vec, Vec) { let op = |bound_id| { - let mut key_prefix = Self::key_prefix(CodecType::Column, name); + let mut key_prefix = Self::key_prefix(CodecType::Column, table_name); key_prefix.push(bound_id); key_prefix @@ -125,15 +125,15 @@ impl TableCodec { /// Key: TableName_Tuple_0_RowID(Sorted) /// Value: Tuple - pub fn encode_tuple(name: &String, tuple: &Tuple) -> Result<(Bytes, Bytes), TypeError> { + pub fn encode_tuple(table_name: &str, tuple: &Tuple) -> Result<(Bytes, Bytes), TypeError> { let tuple_id = tuple.id.clone().ok_or(TypeError::PrimaryKeyNotFound)?; - let key = Self::encode_tuple_key(name, &tuple_id)?; + let key = Self::encode_tuple_key(table_name, &tuple_id)?; Ok((Bytes::from(key), Bytes::from(tuple.serialize_to()))) } - pub fn encode_tuple_key(name: &String, tuple_id: &TupleId) -> Result, TypeError> { - let mut key_prefix = Self::key_prefix(CodecType::Tuple, name); + pub fn encode_tuple_key(table_name: &str, tuple_id: &TupleId) -> Result, TypeError> { + let mut key_prefix = Self::key_prefix(CodecType::Tuple, table_name); key_prefix.push(BOUND_MIN_TAG); tuple_id.to_primary_key(&mut key_prefix)?; @@ -148,10 +148,10 @@ impl TableCodec { /// Key: TableName_IndexMeta_0_IndexID /// Value: IndexMeta pub fn encode_index_meta( - name: &String, + table_name: &str, index_meta: &IndexMeta, ) -> Result<(Bytes, Bytes), TypeError> { - let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, &name); + let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, table_name); key_prefix.push(BOUND_MIN_TAG); key_prefix.append(&mut index_meta.id.to_be_bytes().to_vec()); @@ -176,7 +176,7 @@ impl TableCodec { /// Tips: The unique index has only one ColumnID and one corresponding DataValue, /// so it can be positioned directly. 
pub fn encode_index( - name: &String, + name: &str, index: &Index, tuple_ids: &[TupleId], ) -> Result<(Bytes, Bytes), TypeError> { @@ -188,7 +188,7 @@ impl TableCodec { )) } - pub fn encode_index_key(name: &String, index: &Index) -> Result, TypeError> { + pub fn encode_index_key(name: &str, index: &Index) -> Result, TypeError> { let mut key_prefix = Self::key_prefix(CodecType::Index, name); key_prefix.push(BOUND_MIN_TAG); key_prefix.append(&mut index.id.to_be_bytes().to_vec()); @@ -209,9 +209,12 @@ impl TableCodec { /// Value: ColumnCatalog /// /// Tips: the `0` for bound range - pub fn encode_column(col: &ColumnCatalog) -> Result<(Bytes, Bytes), TypeError> { + pub fn encode_column( + table_name: &str, + col: &ColumnCatalog, + ) -> Result<(Bytes, Bytes), TypeError> { let bytes = bincode::serialize(col)?; - let mut key_prefix = Self::key_prefix(CodecType::Column, &col.table_name().unwrap()); + let mut key_prefix = Self::key_prefix(CodecType::Column, table_name); key_prefix.push(BOUND_MIN_TAG); key_prefix.append(&mut col.id().unwrap().to_be_bytes().to_vec()); @@ -219,25 +222,23 @@ impl TableCodec { Ok((Bytes::from(key_prefix), Bytes::from(bytes))) } - pub fn decode_column(bytes: &[u8]) -> Result<(TableName, ColumnCatalog), TypeError> { - let column = bincode::deserialize::(bytes)?; - - Ok((column.table_name().unwrap(), column)) + pub fn decode_column(bytes: &[u8]) -> Result { + Ok(bincode::deserialize::(bytes)?) } /// Key: RootCatalog_0_TableName /// Value: TableName - pub fn encode_root_table(table_name: &String) -> Result<(Bytes, Bytes), TypeError> { + pub fn encode_root_table(table_name: &str) -> Result<(Bytes, Bytes), TypeError> { let key = Self::encode_root_table_key(table_name); Ok(( Bytes::from(key), - Bytes::from(table_name.clone().into_bytes()), + Bytes::from(table_name.to_owned().into_bytes()), )) } - pub fn encode_root_table_key(table_name: &String) -> Vec { - Self::key_prefix(CodecType::Root, &table_name) + pub fn encode_root_table_key(table_name: &str) -> Vec { + Self::key_prefix(CodecType::Root, table_name) } pub fn decode_root_table(bytes: &[u8]) -> Result { @@ -266,13 +267,13 @@ mod tests { ColumnCatalog::new( "c1".into(), false, - ColumnDesc::new(LogicalType::Integer, true, false), + ColumnDesc::new(LogicalType::Integer, true, false, None), None, ), ColumnCatalog::new( "c2".into(), false, - ColumnDesc::new(LogicalType::Decimal(None, None), false, false), + ColumnDesc::new(LogicalType::Decimal(None, None), false, false, None), None, ), ]; @@ -318,6 +319,7 @@ mod tests { column_ids: vec![0], name: "index_1".to_string(), is_unique: false, + is_primary: false, }; let (_, bytes) = TableCodec::encode_index_meta(&"T1".to_string(), &index_meta)?; @@ -346,12 +348,11 @@ mod tests { fn test_table_codec_column() { let table_catalog = build_table_codec(); let col = table_catalog.all_columns()[0].clone(); - let (_, bytes) = TableCodec::encode_column(&col).unwrap(); - let (table_name, decode_col) = TableCodec::decode_column(&bytes).unwrap(); + let (_, bytes) = TableCodec::encode_column(&table_catalog.name, &col).unwrap(); + let decode_col = TableCodec::decode_column(&bytes).unwrap(); assert_eq!(&decode_col, col.as_ref()); - assert_eq!(table_name, table_catalog.name); } #[test] @@ -365,14 +366,14 @@ mod tests { column_datatype: LogicalType::Invalid, is_primary: false, is_unique: false, + default: None, }, None, ); - col.summary.table_name = Some(Arc::new(table_name.to_string())); col.summary.id = Some(col_id as u32); - let (key, _) = TableCodec::encode_column(&col).unwrap(); + let (key, _) 
= TableCodec::encode_column(&table_name.to_string(), &col).unwrap(); key }; @@ -413,6 +414,7 @@ mod tests { column_ids: vec![], name: "".to_string(), is_unique: false, + is_primary: false, }; let (key, _) = @@ -451,7 +453,13 @@ mod tests { #[test] fn test_table_codec_index_bound() { let mut set = BTreeSet::new(); - let table_catalog = TableCatalog::new(Arc::new("T0".to_string()), vec![]).unwrap(); + let column = ColumnCatalog::new( + "".to_string(), + false, + ColumnDesc::new(LogicalType::Boolean, false, false, None), + None, + ); + let table_catalog = TableCatalog::new(Arc::new("T0".to_string()), vec![column]).unwrap(); let op = |value: DataValue, index_id: usize, table_name: &String| { let index = Index { diff --git a/src/types/index.rs b/src/types/index.rs index bd0f7e3d..c06df10d 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -12,6 +12,7 @@ pub struct IndexMeta { pub column_ids: Vec, pub name: String, pub is_unique: bool, + pub is_primary: bool, } pub struct Index { diff --git a/src/types/mod.rs b/src/types/mod.rs index de2d29f4..b5ed965c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -162,12 +162,12 @@ impl LogicalType { right: &LogicalType, ) -> Result { if left == right { - return Ok(left.clone()); + return Ok(*left); } match (left, right) { // SqlNull type can be cast to anything - (LogicalType::SqlNull, _) => return Ok(right.clone()), - (_, LogicalType::SqlNull) => return Ok(left.clone()), + (LogicalType::SqlNull, _) => return Ok(*right), + (_, LogicalType::SqlNull) => return Ok(*left), _ => {} } if left.is_numeric() && right.is_numeric() { @@ -204,7 +204,7 @@ impl LogicalType { right: &LogicalType, ) -> Result { if left == right { - return Ok(left.clone()); + return Ok(*left); } if left.is_signed_numeric() && right.is_unsigned_numeric() { // this method is symmetric @@ -214,10 +214,10 @@ impl LogicalType { } if LogicalType::can_implicit_cast(left, right) { - return Ok(right.clone()); + return Ok(*right); } if LogicalType::can_implicit_cast(right, left) { - return Ok(left.clone()); + return Ok(*left); } // we can't cast implicitly either way and types are not equal // this happens when left is signed and right is unsigned @@ -341,6 +341,6 @@ impl TryFrom for LogicalType { impl std::fmt::Display for LogicalType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_ref()) + write!(f, "{}", self.as_ref().to_uppercase()) } } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 0e75aaf3..02522463 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -124,6 +124,7 @@ mod tests { use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; + use rust_decimal::Decimal; use std::sync::Arc; #[test] @@ -132,73 +133,79 @@ mod tests { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false), + ColumnDesc::new(LogicalType::Integer, true, false, None), None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::UInteger, false, false), + ColumnDesc::new(LogicalType::UInteger, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c3".to_string(), false, - ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false), + ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false, None), None, )), Arc::new(ColumnCatalog::new( "c4".to_string(), false, - ColumnDesc::new(LogicalType::Smallint, false, false), + ColumnDesc::new(LogicalType::Smallint, false, false, None), None, )), 
Arc::new(ColumnCatalog::new( "c5".to_string(), false, - ColumnDesc::new(LogicalType::USmallint, false, false), + ColumnDesc::new(LogicalType::USmallint, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c6".to_string(), false, - ColumnDesc::new(LogicalType::Float, false, false), + ColumnDesc::new(LogicalType::Float, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c7".to_string(), false, - ColumnDesc::new(LogicalType::Double, false, false), + ColumnDesc::new(LogicalType::Double, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c8".to_string(), false, - ColumnDesc::new(LogicalType::Tinyint, false, false), + ColumnDesc::new(LogicalType::Tinyint, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c9".to_string(), false, - ColumnDesc::new(LogicalType::UTinyint, false, false), + ColumnDesc::new(LogicalType::UTinyint, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c10".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false), + ColumnDesc::new(LogicalType::Boolean, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c11".to_string(), false, - ColumnDesc::new(LogicalType::DateTime, false, false), + ColumnDesc::new(LogicalType::DateTime, false, false, None), None, )), Arc::new(ColumnCatalog::new( "c12".to_string(), false, - ColumnDesc::new(LogicalType::Date, false, false), + ColumnDesc::new(LogicalType::Date, false, false, None), + None, + )), + Arc::new(ColumnCatalog::new( + "c13".to_string(), + false, + ColumnDesc::new(LogicalType::Decimal(None, None), false, false, None), None, )), ]; @@ -220,6 +227,7 @@ mod tests { Arc::new(DataValue::Boolean(Some(true))), Arc::new(DataValue::Date64(Some(0))), Arc::new(DataValue::Date32(Some(0))), + Arc::new(DataValue::Decimal(Some(Decimal::new(0, 3)))), ], }, Tuple { @@ -238,6 +246,7 @@ mod tests { Arc::new(DataValue::Boolean(None)), Arc::new(DataValue::Date64(None)), Arc::new(DataValue::Date32(None)), + Arc::new(DataValue::Decimal(None)), ], }, ]; diff --git a/src/types/value.rs b/src/types/value.rs index af3ddd73..49e99207 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -286,10 +286,7 @@ impl DataValue { } pub fn is_variable(&self) -> bool { - match self { - DataValue::Utf8(_) => true, - _ => false, - } + matches!(self, DataValue::Utf8(_)) } pub fn is_null(&self) -> bool { @@ -374,7 +371,7 @@ impl DataValue { DataValue::Utf8(v) => v.clone().map(|v| v.into_bytes()), DataValue::Date32(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Date64(v) => v.map(|v| v.encode_fixed_vec()), - DataValue::Decimal(v) => v.clone().map(|v| v.serialize().to_vec()), + DataValue::Decimal(v) => v.map(|v| v.serialize().to_vec()), } .unwrap_or(vec![]) } @@ -383,7 +380,7 @@ impl DataValue { match ty { LogicalType::Invalid => panic!("invalid logical type"), LogicalType::SqlNull => DataValue::Null, - LogicalType::Boolean => DataValue::Boolean(bytes.get(0).map(|v| *v != 0)), + LogicalType::Boolean => DataValue::Boolean(bytes.first().map(|v| *v != 0)), LogicalType::Tinyint => { DataValue::Int8((!bytes.is_empty()).then(|| i8::decode_fixed(bytes))) } @@ -650,18 +647,16 @@ impl DataValue { DataValue::Int8(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Tinyint => Ok(DataValue::Int8(value)), - LogicalType::UTinyint => Ok(DataValue::UInt8( - value.map(|v| u8::try_from(v)).transpose()?, - )), - LogicalType::USmallint => Ok(DataValue::UInt16( - value.map(|v| u16::try_from(v)).transpose()?, - )), - LogicalType::UInteger => Ok(DataValue::UInt32( - value.map(|v| 
u32::try_from(v)).transpose()?, - )), - LogicalType::UBigint => Ok(DataValue::UInt64( - value.map(|v| u64::try_from(v)).transpose()?, - )), + LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(u8::try_from).transpose()?)), + LogicalType::USmallint => { + Ok(DataValue::UInt16(value.map(u16::try_from).transpose()?)) + } + LogicalType::UInteger => { + Ok(DataValue::UInt32(value.map(u32::try_from).transpose()?)) + } + LogicalType::UBigint => { + Ok(DataValue::UInt64(value.map(u64::try_from).transpose()?)) + } LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), @@ -678,19 +673,17 @@ impl DataValue { }, DataValue::Int16(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8( - value.map(|v| u8::try_from(v)).transpose()?, - )), - LogicalType::USmallint => Ok(DataValue::UInt16( - value.map(|v| u16::try_from(v)).transpose()?, - )), - LogicalType::UInteger => Ok(DataValue::UInt32( - value.map(|v| u32::try_from(v)).transpose()?, - )), - LogicalType::UBigint => Ok(DataValue::UInt64( - value.map(|v| u64::try_from(v)).transpose()?, - )), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), + LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(u8::try_from).transpose()?)), + LogicalType::USmallint => { + Ok(DataValue::UInt16(value.map(u16::try_from).transpose()?)) + } + LogicalType::UInteger => { + Ok(DataValue::UInt32(value.map(u32::try_from).transpose()?)) + } + LogicalType::UBigint => { + Ok(DataValue::UInt64(value.map(u64::try_from).transpose()?)) + } + LogicalType::Smallint => Ok(DataValue::Int16(value)), LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), @@ -706,19 +699,17 @@ impl DataValue { }, DataValue::Int32(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8( - value.map(|v| u8::try_from(v)).transpose()?, - )), - LogicalType::USmallint => Ok(DataValue::UInt16( - value.map(|v| u16::try_from(v)).transpose()?, - )), - LogicalType::UInteger => Ok(DataValue::UInt32( - value.map(|v| u32::try_from(v)).transpose()?, - )), - LogicalType::UBigint => Ok(DataValue::UInt64( - value.map(|v| u64::try_from(v)).transpose()?, - )), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(u8::try_from).transpose()?)), + LogicalType::USmallint => { + Ok(DataValue::UInt16(value.map(u16::try_from).transpose()?)) + } + LogicalType::UInteger => { + Ok(DataValue::UInt32(value.map(u32::try_from).transpose()?)) + } + LogicalType::UBigint => { + Ok(DataValue::UInt64(value.map(u64::try_from).transpose()?)) + } + LogicalType::Integer => Ok(DataValue::Int32(value)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), @@ -732,19 +723,17 @@ impl DataValue { }, DataValue::Int64(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8( - value.map(|v| u8::try_from(v)).transpose()?, - )), - LogicalType::USmallint => Ok(DataValue::UInt16( - value.map(|v| u16::try_from(v)).transpose()?, - )), - 
LogicalType::UInteger => Ok(DataValue::UInt32( - value.map(|v| u32::try_from(v)).transpose()?, - )), - LogicalType::UBigint => Ok(DataValue::UInt64( - value.map(|v| u64::try_from(v)).transpose()?, - )), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(u8::try_from).transpose()?)), + LogicalType::USmallint => { + Ok(DataValue::UInt16(value.map(u16::try_from).transpose()?)) + } + LogicalType::UInteger => { + Ok(DataValue::UInt32(value.map(u32::try_from).transpose()?)) + } + LogicalType::UBigint => { + Ok(DataValue::UInt64(value.map(u64::try_from).transpose()?)) + } + LogicalType::Bigint => Ok(DataValue::Int64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { let mut decimal = Decimal::from(v); @@ -776,7 +765,7 @@ impl DataValue { }, DataValue::UInt16(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), + LogicalType::USmallint => Ok(DataValue::UInt16(value)), LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), @@ -794,7 +783,7 @@ impl DataValue { }, DataValue::UInt32(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), + LogicalType::UInteger => Ok(DataValue::UInt32(value)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), @@ -809,7 +798,7 @@ impl DataValue { }, DataValue::UInt64(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::UBigint => Ok(DataValue::UInt64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { let mut decimal = Decimal::from(v); @@ -1032,13 +1021,9 @@ impl fmt::Display for DataValue { DataValue::UInt64(e) => format_option!(f, e)?, DataValue::Utf8(e) => format_option!(f, e)?, DataValue::Null => write!(f, "null")?, - DataValue::Date32(e) => format_option!(f, e.and_then(|s| DataValue::date_format(s)))?, - DataValue::Date64(e) => { - format_option!(f, e.and_then(|s| DataValue::date_time_format(s)))? - } - DataValue::Decimal(e) => { - format_option!(f, e.as_ref().map(|s| DataValue::decimal_format(s)))? 
- } + DataValue::Date32(e) => format_option!(f, e.and_then(DataValue::date_format))?, + DataValue::Date64(e) => format_option!(f, e.and_then(DataValue::date_time_format))?, + DataValue::Decimal(e) => format_option!(f, e.as_ref().map(DataValue::decimal_format))?, }; Ok(()) } diff --git a/tests/slt/alter_table.slt b/tests/slt/alter_table.slt new file mode 100644 index 00000000..ded8ff9c --- /dev/null +++ b/tests/slt/alter_table.slt @@ -0,0 +1,19 @@ +statement ok +create table alter_table(id int primary key, v1 int) + +statement ok +insert into alter_table values (1,1), (2,2), (3,3), (4,4) + +statement ok +alter table alter_table add column da int null + +query III rowsort +select * from alter_table +---- +1 1 null +2 2 null +3 3 null +4 4 null + +statement ok +drop table alter_table diff --git a/tests/slt/basic_test.slt b/tests/slt/basic_test.slt index e7473788..63d18f1c 100644 --- a/tests/slt/basic_test.slt +++ b/tests/slt/basic_test.slt @@ -80,5 +80,11 @@ select * from t 0 text1 1 text2 +statement error +select CAST(name AS BIGINT) from t + +statement ok +select CAST(id AS VARCHAR) from t + statement ok drop table t \ No newline at end of file diff --git a/tests/slt/create.slt b/tests/slt/create.slt index 741acc10..22707f63 100644 --- a/tests/slt/create.slt +++ b/tests/slt/create.slt @@ -1,2 +1,11 @@ statement ok -create table t(id int primary key, v1 int, v2 int, v3 int) \ No newline at end of file +create table t(id int primary key, v1 int, v2 int, v3 int) + +statement error +create table t(id int primary key, v1 int, v2 int, v3 int) + +statement ok +create table if not exists t(id int primary key, v1 int, v2 int, v3 int) + +statement ok +create table if not exists t(id int primary key, v1 int, v2 int, v3 int) \ No newline at end of file diff --git a/tests/slt/insert.slt b/tests/slt/insert.slt index accbc7ab..5abf9de2 100644 --- a/tests/slt/insert.slt +++ b/tests/slt/insert.slt @@ -19,7 +19,7 @@ insert into t(id, v2, v1) values (7,1,10) statement ok insert into t values (8,NULL,NULL,NULL) -query III rowsort +query IIII rowsort select * from t ---- 0 1 10 100 @@ -30,4 +30,23 @@ select * from t 5 1 10 100 6 1 10 null 7 10 1 null -8 null null null \ No newline at end of file +8 null null null + +statement ok +create table t1(id int primary key, v1 bigint default 233) + +statement ok +insert into t1 values (0) + +statement ok +insert into t1 values (1) + +statement ok +insert into t1 values (2) + +query II rowsort +select * from t1 +---- +0 233 +1 233 +2 233 \ No newline at end of file
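The reworked IndexIter in src/storage/mod.rs is easiest to review as a small state machine: drain the buffered index_values, refill them from scope_iter, and only then pop the next ConstantBinary. The sketch below is a minimal, self-contained reduction of that control flow; Binary, Iter, and the sorted vector standing in for the primary-key index are hypothetical simplifications of ConstantBinary, IndexIter, and the KipDB transaction, and offset handling plus unique-index decoding are omitted.

use std::collections::VecDeque;

// Stand-in for ConstantBinary::Eq / ConstantBinary::Scope, reduced to
// inclusive i32 bounds.
enum Binary {
    Eq(i32),
    Scope { min: i32, max: i32 },
}

// Stand-in for IndexIter: `values` mirrors index_values, `binaries` the
// remaining expressions, and `keys` (kept sorted) plays the primary-key index.
struct Iter {
    limit: Option<usize>,
    values: VecDeque<i32>,
    binaries: VecDeque<Binary>,
    keys: Vec<i32>,
}

impl Iter {
    fn next_tuple(&mut self) -> Option<i32> {
        // 1. stop when the limit is used up or nothing is left to evaluate
        if matches!(self.limit, Some(0)) || (self.values.is_empty() && self.binaries.is_empty()) {
            return None;
        }
        // 2. drain buffered values first
        if let Some(id) = self.values.pop_front() {
            if let Some(limit) = &mut self.limit {
                *limit -= 1;
            }
            return Some(id);
        }
        // 3. otherwise evaluate the next expression, buffer its matches, retry
        match self.binaries.pop_front()? {
            Binary::Eq(key) => {
                if self.keys.binary_search(&key).is_ok() {
                    self.values.push_back(key);
                }
            }
            Binary::Scope { min, max } => {
                let matched: Vec<i32> = self
                    .keys
                    .iter()
                    .copied()
                    .filter(|k| (min..=max).contains(k))
                    .collect();
                self.values.extend(matched);
            }
        }
        self.next_tuple()
    }
}

fn main() {
    // mirrors test_index_iter_pk: keys 0..=4, Eq(0) followed by Scope 2..=4
    let mut iter = Iter {
        limit: Some(3),
        values: VecDeque::new(),
        binaries: VecDeque::from(vec![Binary::Eq(0), Binary::Scope { min: 2, max: 4 }]),
        keys: vec![0, 1, 2, 3, 4],
    };
    while let Some(id) = iter.next_tuple() {
        println!("{id}"); // prints 0, 2, 3
    }
}

As in the real iterator, an expression is only evaluated once the earlier buffers are exhausted, so a LIMIT that is already satisfied never touches the remaining binaries or the storage layer.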