diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 303e4ca3..36e51f77 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -30,7 +30,7 @@ impl Binder { select_items: &mut [ScalarExpression], ) -> Result<(), BindError> { for column in select_items { - self.visit_column_agg_expr(column); + self.visit_column_agg_expr(column, true)?; } Ok(()) } @@ -57,7 +57,8 @@ impl Binder { // Extract having expression. let return_having = if let Some(having) = having { let mut having = self.bind_expr(having).await?; - self.visit_column_agg_expr(&mut having); + self.visit_column_agg_expr(&mut having, false)?; + Some(having) } else { None @@ -73,11 +74,11 @@ impl Binder { nulls_first, } = orderby; let mut expr = self.bind_expr(expr).await?; - self.visit_column_agg_expr(&mut expr); + self.visit_column_agg_expr(&mut expr, false)?; return_orderby.push(SortField::new( expr, - asc.map_or(true, |asc| !asc), + asc.map_or(true, |asc| asc), nulls_first.map_or(false, |first| first), )); } @@ -88,50 +89,67 @@ impl Binder { Ok((return_having, return_orderby)) } - fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression) { + fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression, is_select: bool) -> Result<(), BindError> { match expr { ScalarExpression::AggCall { ty: return_type, .. } => { - let index = self.context.input_ref_index(InputRefType::AggCall); - let input_ref = ScalarExpression::InputRef { - index, - ty: return_type.clone(), - }; - match std::mem::replace(expr, input_ref) { - ScalarExpression::AggCall { - kind, - args, + let ty = return_type.clone(); + if is_select { + let index = self.context.input_ref_index(InputRefType::AggCall); + let input_ref = ScalarExpression::InputRef { + index, ty, - distinct - } => { - self.context.agg_calls.push(ScalarExpression::AggCall { - distinct, + }; + match std::mem::replace(expr, input_ref) { + ScalarExpression::AggCall { kind, args, ty, - }); + distinct + } => { + self.context.agg_calls.push(ScalarExpression::AggCall { + distinct, + kind, + args, + ty, + }); + } + _ => unreachable!(), } - _ => unreachable!(), + } else { + let (index, _) = self + .context + .agg_calls + .iter() + .find_position(|agg_expr| agg_expr == &expr) + .ok_or_else(|| BindError::AggMiss(format!("{:?}", expr)))?; + + let _ = std::mem::replace(expr, ScalarExpression::InputRef { + index, + ty, + }); } } - ScalarExpression::TypeCast { expr, .. } => self.visit_column_agg_expr(expr), - ScalarExpression::IsNull { expr } => self.visit_column_agg_expr(expr), - ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr), - ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr), + ScalarExpression::TypeCast { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, + ScalarExpression::IsNull { expr } => self.visit_column_agg_expr(expr, is_select)?, + ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, + ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, ScalarExpression::Binary { left_expr, right_expr, .. } => { - self.visit_column_agg_expr(left_expr); - self.visit_column_agg_expr(right_expr); + self.visit_column_agg_expr(left_expr, is_select)?; + self.visit_column_agg_expr(right_expr, is_select)?; } ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } | ScalarExpression::InputRef { .. } => {} } + + Ok(()) } /// Validate select exprs must appear in the GROUP BY clause or be used in @@ -173,6 +191,7 @@ impl Binder { if expr.has_agg_call(&self.context) { continue; } + group_raw_set.remove(expr); if !group_raw_exprs.iter().contains(expr) { @@ -271,6 +290,9 @@ impl Binder { if self.context.group_by_exprs.contains(expr) { return Ok(()); } + if matches!(expr, ScalarExpression::Alias { .. }) { + return self.validate_having_orderby(expr.unpack_alias()); + } Err(BindError::AggMiss( format!( diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 315b32f0..6d3bb383 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use std::sync::Arc; use itertools::Itertools; -use sqlparser::ast::{ColumnDef, ObjectName}; +use sqlparser::ast::{ColumnDef, ObjectName, TableConstraint}; use super::Binder; use crate::binder::{BindError, lower_case_name, split_name}; @@ -12,10 +12,12 @@ use crate::planner::operator::Operator; use crate::storage::Storage; impl Binder { + // TODO: TableConstraint pub(crate) fn bind_create_table( &mut self, name: &ObjectName, columns: &[ColumnDef], + constraints: &[TableConstraint] ) -> Result { let name = lower_case_name(&name); let (_, name) = split_name(&name)?; diff --git a/src/binder/expr.rs b/src/binder/expr.rs index d84c7c57..441eea74 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -85,7 +85,7 @@ impl Binder { } if got_column.is_none() { if let Some(expr) = self.context.aliases.get(column_name) { - return Ok(expr.clone()); + return Ok(ScalarExpression::Alias { expr: Box::new(expr.clone()), alias: column_name.clone() }); } } let column_catalog = @@ -167,7 +167,7 @@ impl Binder { distinct: func.distinct, kind: AggKind::Count, args, - ty: LogicalType::UInteger, + ty: LogicalType::Integer, }, "sum" => ScalarExpression::AggCall{ distinct: func.distinct, diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 6ac8e187..fd2f6a87 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -82,7 +82,9 @@ impl Binder { pub async fn bind(mut self, stmt: &Statement) -> Result { let plan = match stmt { Statement::Query(query) => self.bind_query(query).await?, - Statement::CreateTable { name, columns, .. } => self.bind_create_table(name, &columns)?, + Statement::CreateTable { name, columns, constraints, .. } => { + self.bind_create_table(name, &columns, &constraints)? + }, Statement::Drop { object_type, names, .. } => { match object_type { ObjectType::Table => { @@ -168,9 +170,9 @@ pub enum BindError { SubqueryMustHaveAlias, #[error("agg miss: {0}")] AggMiss(String), - #[error("catalog error")] + #[error("catalog error: {0}")] CatalogError(#[from] CatalogError), - #[error("type error")] + #[error("type error: {0}")] TypeError(#[from] TypeError) } @@ -193,16 +195,16 @@ pub mod test { let _ = storage.create_table( Arc::new("t1".to_string()), vec![ - ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true, false)), - ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false, true)), + ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true, false), None), + ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false, true), None), ] ).await?; let _ = storage.create_table( Arc::new("t2".to_string()), vec![ - ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true, false)), - ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false, false)), + ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true, false), None), + ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false, false), None), ] ).await?; diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 4dcea4be..95adc35d 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use sqlparser::ast::{ColumnDef, ColumnOption}; use crate::catalog::TableName; +use crate::expression::ScalarExpression; use crate::types::{ColumnId, LogicalType}; @@ -14,16 +15,23 @@ pub struct ColumnCatalog { pub table_name: Option, pub nullable: bool, pub desc: ColumnDesc, + pub ref_expr: Option, } impl ColumnCatalog { - pub(crate) fn new(column_name: String, nullable: bool, column_desc: ColumnDesc) -> ColumnCatalog { + pub(crate) fn new( + column_name: String, + nullable: bool, + column_desc: ColumnDesc, + ref_expr: Option + ) -> ColumnCatalog { ColumnCatalog { id: None, name: column_name, table_name: None, nullable, desc: column_desc, + ref_expr, } } @@ -34,6 +42,7 @@ impl ColumnCatalog { table_name: None, nullable: false, desc: ColumnDesc::new(LogicalType::Varchar(None), false, false), + ref_expr: None, } } @@ -75,7 +84,7 @@ impl From for ColumnCatalog { } } - ColumnCatalog::new(column_name, nullable, column_desc) + ColumnCatalog::new(column_name, nullable, column_desc, None) } } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index b047b0dd..d1738591 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -68,11 +68,13 @@ mod tests { "a".to_string(), false, ColumnDesc::new(LogicalType::Integer, false, false), + None ); let col1 = ColumnCatalog::new( "b".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), + None ); let col_catalogs = vec![col0, col1]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 76a19a24..391496f6 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -23,10 +23,12 @@ impl TableCatalog { .find(|meta| meta.is_unique && &meta.column_ids[0] == col_id) } + #[allow(dead_code)] pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option<&ColumnRef> { self.columns.get(id) } + #[allow(dead_code)] pub(crate) fn get_column_id_by_name(&self, name: &String) -> Option { self.column_idxs.get(name).cloned() } @@ -123,8 +125,8 @@ mod tests { // | 1 | true | // | 2 | false | fn test_table_catalog() { - let col0 = ColumnCatalog::new("a".into(), false, ColumnDesc::new(LogicalType::Integer, false, false)); - let col1 = ColumnCatalog::new("b".into(), false, ColumnDesc::new(LogicalType::Boolean, false, false)); + let col0 = ColumnCatalog::new("a".into(), false, ColumnDesc::new(LogicalType::Integer, false, false), None); + let col1 = ColumnCatalog::new("b".into(), false, ColumnDesc::new(LogicalType::Boolean, false, false), None); let col_catalogs = vec![col0, col1]; let table_catalog = TableCatalog::new(Arc::new("test".to_string()), col_catalogs).unwrap(); diff --git a/src/db.rs b/src/db.rs index ae68e98d..e31e2660 100644 --- a/src/db.rs +++ b/src/db.rs @@ -168,12 +168,14 @@ mod test { ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false), + None ), ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false) + ColumnDesc::new(LogicalType::Boolean, false, false), + None ), ]; diff --git a/src/execution/executor/dql/aggregate/count.rs b/src/execution/executor/dql/aggregate/count.rs index 2268b221..22f2f766 100644 --- a/src/execution/executor/dql/aggregate/count.rs +++ b/src/execution/executor/dql/aggregate/count.rs @@ -6,7 +6,7 @@ use crate::execution::ExecutorError; use crate::types::value::{DataValue, ValueRef}; pub struct CountAccumulator { - result: u32, + result: i32, } impl CountAccumulator { @@ -25,7 +25,7 @@ impl Accumulator for CountAccumulator { } fn evaluate(&self) -> Result { - Ok(Arc::new(DataValue::UInt32(Some(self.result)))) + Ok(Arc::new(DataValue::Int32(Some(self.result)))) } } @@ -51,6 +51,6 @@ impl Accumulator for DistinctCountAccumulator { } fn evaluate(&self) -> Result { - Ok(Arc::new(DataValue::UInt32(Some(self.distinct_values.len() as u32)))) + Ok(Arc::new(DataValue::Int32(Some(self.distinct_values.len() as i32)))) } } diff --git a/src/execution/executor/dql/aggregate/hash_agg.rs b/src/execution/executor/dql/aggregate/hash_agg.rs index bc42fa3c..08348182 100644 --- a/src/execution/executor/dql/aggregate/hash_agg.rs +++ b/src/execution/executor/dql/aggregate/hash_agg.rs @@ -125,9 +125,9 @@ mod test { let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ - Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone())), + Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone(), None)), ]; let operator = AggregateOperator { diff --git a/src/execution/executor/dql/aggregate/sum.rs b/src/execution/executor/dql/aggregate/sum.rs index c6217d7a..cca28c70 100644 --- a/src/execution/executor/dql/aggregate/sum.rs +++ b/src/execution/executor/dql/aggregate/sum.rs @@ -23,7 +23,7 @@ impl SumAccumulator { impl Accumulator for SumAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), ExecutorError> { if !value.is_null() { - self.result = binary_op( + self.result = binary_op( &self.result, value, &BinaryOperator::Plus diff --git a/src/execution/executor/dql/filter.rs b/src/execution/executor/dql/filter.rs index dbeb731c..a9ada4fa 100644 --- a/src/execution/executor/dql/filter.rs +++ b/src/execution/executor/dql/filter.rs @@ -36,7 +36,7 @@ impl Filter { for tuple in input { let tuple = tuple?; if let DataValue::Boolean(option) = predicate.eval_column(&tuple)?.as_ref() { - if let Some(true) = option{ + if let Some(true) = option { yield tuple; } else { continue diff --git a/src/execution/executor/dql/join/hash_join.rs b/src/execution/executor/dql/join/hash_join.rs index e5b401c6..0d3868e4 100644 --- a/src/execution/executor/dql/join/hash_join.rs +++ b/src/execution/executor/dql/join/hash_join.rs @@ -234,15 +234,15 @@ mod test { let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ - Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone())), + Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone(), None)), ]; let t2_columns = vec![ - Arc::new(ColumnCatalog::new("c4".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c5".to_string(), true, desc.clone())), - Arc::new(ColumnCatalog::new("c6".to_string(), true, desc.clone())), + Arc::new(ColumnCatalog::new("c4".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c5".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new("c6".to_string(), true, desc.clone(), None)), ]; let on_keys = vec![ diff --git a/src/execution/executor/dql/sort.rs b/src/execution/executor/dql/sort.rs index 3d1b64c8..22efcd2c 100644 --- a/src/execution/executor/dql/sort.rs +++ b/src/execution/executor/dql/sort.rs @@ -42,7 +42,7 @@ impl Sort { tuples.sort_by(|tuple_1, tuple_2| { let mut ordering = Ordering::Equal; - for SortField { expr, desc, nulls_first } in &sort_fields { + for SortField { expr, asc, nulls_first } in &sort_fields { let value_1 = expr.eval_column(tuple_1).unwrap(); let value_2 = expr.eval_column(tuple_2).unwrap(); @@ -53,7 +53,7 @@ impl Sort { _ => Ordering::Equal, }); - if *desc { + if !*asc { ordering = ordering.reverse(); } diff --git a/src/expression/agg.rs b/src/expression/agg.rs index 1a78ab46..c2cdd39a 100644 --- a/src/expression/agg.rs +++ b/src/expression/agg.rs @@ -1,4 +1,6 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum AggKind { Avg, Max, diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 51aa2346..457f7982 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -1,26 +1,37 @@ use std::sync::Arc; use itertools::Itertools; +use lazy_static::lazy_static; use crate::expression::value_compute::{binary_op, unary_op}; use crate::expression::ScalarExpression; use crate::types::errors::TypeError; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, ValueRef}; +lazy_static! { + static ref NULL_VALUE: ValueRef = { + Arc::new(DataValue::Null) + }; +} + impl ScalarExpression { pub fn eval_column(&self, tuple: &Tuple) -> Result { match &self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { - let (index, _) = tuple - .columns - .iter() - .find_position(|tul_col| tul_col.name == col.name) - .unwrap(); + let value = Self::eval_with_name(&tuple, &col.name) + .unwrap_or(&NULL_VALUE) + .clone(); - Ok(tuple.values[index].clone()) + Ok(value) }, ScalarExpression::InputRef{ index, .. } => Ok(tuple.values[*index].clone()), - ScalarExpression::Alias{ expr, .. } => expr.eval_column(tuple), + ScalarExpression::Alias{ expr, alias } => { + if let Some(value) = Self::eval_with_name(&tuple, alias) { + return Ok(value.clone()); + } + + expr.eval_column(tuple) + }, ScalarExpression::TypeCast{ expr, ty, .. } => { let value = expr.eval_column(tuple)?; @@ -45,4 +56,12 @@ impl ScalarExpression { ScalarExpression::AggCall{ .. } => todo!() } } + + fn eval_with_name<'a>(tuple: &'a Tuple, name: &String) -> Option<&'a ValueRef> { + tuple + .columns + .iter() + .find_position(|tul_col| &tul_col.name == name) + .map(|(i, _)| &tuple.values[i]) + } } \ No newline at end of file diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 35f55e5f..31daf75c 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -2,6 +2,7 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use crate::binder::BinderContext; @@ -22,7 +23,7 @@ pub mod simplify; /// SELECT a+1, b FROM t1. /// a+1 -> ScalarExpression::Unary(a + 1) /// b -> ScalarExpression::ColumnRef() -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum ScalarExpression { Constant(ValueRef), ColumnRef(ColumnRef), @@ -173,14 +174,16 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( format!("{}", value), true, - ColumnDesc::new(value.logical_type(), false, false) + ColumnDesc::new(value.logical_type(), false, false), + Some(self.clone()) )) } ScalarExpression::Alias { expr, alias } => { Arc::new(ColumnCatalog::new( alias.to_string(), true, - ColumnDesc::new(expr.return_type(), false, false) + ColumnDesc::new(expr.return_type(), false, false), + Some(self.clone()) )) } ScalarExpression::AggCall { kind, args, ty, distinct } => { @@ -204,7 +207,8 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false, false) + ColumnDesc::new(ty.clone(), false, false), + Some(self.clone()) )) } ScalarExpression::InputRef { index, .. } => { @@ -226,15 +230,33 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false, false) + ColumnDesc::new(ty.clone(), false, false), + Some(self.clone()) )) } + ScalarExpression::Unary { + expr, + op, + ty + } => { + let column_name = format!( + "{} {}", + op, + expr.output_columns(tuple).name, + ); + Arc::new(ColumnCatalog::new( + column_name, + true, + ColumnDesc::new(ty.clone(), false, false), + Some(self.clone()) + )) + }, _ => unreachable!() } } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum UnaryOperator { Plus, Minus, @@ -252,7 +274,7 @@ impl From for UnaryOperator { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum BinaryOperator { Plus, Minus, @@ -298,6 +320,16 @@ impl fmt::Display for BinaryOperator { } } +impl fmt::Display for UnaryOperator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + UnaryOperator::Plus => write!(f, "+"), + UnaryOperator::Minus => write!(f, "-"), + UnaryOperator::Not => write!(f, "not") + } + } +} + impl From for BinaryOperator { fn from(value: SqlBinaryOperator) -> Self { match value { diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 7bf2beab..36d1920f 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -233,11 +233,13 @@ impl ConstantBinary { } } +#[derive(Debug)] enum Replace { Binary(ReplaceBinary), Unary(ReplaceUnary), } +#[derive(Debug)] struct ReplaceBinary { column_expr: ScalarExpression, val_expr: ScalarExpression, @@ -246,6 +248,7 @@ struct ReplaceBinary { is_column_left: bool } +#[derive(Debug)] struct ReplaceUnary { child_expr: ScalarExpression, op: UnaryOperator, @@ -300,35 +303,40 @@ impl ScalarExpression { } } - fn unpack_col(&self) -> Option { + fn unpack_col(&self, is_deep: bool) -> Option { match self { ScalarExpression::ColumnRef(col) => Some(col.clone()), - ScalarExpression::Alias { expr, .. } => expr.unpack_col(), - ScalarExpression::Unary { expr, .. } => expr.unpack_col(), + ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_deep), + ScalarExpression::Unary { expr, .. } => expr.unpack_col(is_deep), + ScalarExpression::Binary { left_expr, right_expr, .. } => { + if !is_deep { + return None; + } + + left_expr.unpack_col(true) + .or_else(|| right_expr.unpack_col(true)) + } _ => None } } pub fn simplify(&mut self) -> Result<(), TypeError> { - self._simplify(&mut None) + self._simplify(&mut Vec::new()) } // Tips: Indirect expressions like `ScalarExpression::Alias` will be lost - fn _simplify(&mut self, fix_option: &mut Option) -> Result<(), TypeError> { + fn _simplify(&mut self, replaces: &mut Vec) -> Result<(), TypeError> { match self { ScalarExpression::Binary { left_expr, right_expr, op, ty } => { - Self::fix_expr(fix_option, left_expr, right_expr, op)?; + Self::fix_expr(replaces, left_expr, right_expr, op)?; // `(c1 - 1) and (c1 + 2)` cannot fix! - Self::fix_expr(fix_option, right_expr, left_expr, op)?; + Self::fix_expr(replaces, right_expr, left_expr, op)?; - if matches!(op, BinaryOperator::Plus | BinaryOperator::Divide - | BinaryOperator::Minus | BinaryOperator::Multiply) - { - match (left_expr.unpack_col(), right_expr.unpack_col()) { - (Some(_), Some(_)) => (), + if Self::is_arithmetic(op) { + match (left_expr.unpack_col(false), right_expr.unpack_col(false)) { (Some(col), None) => { - fix_option.replace(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary{ column_expr: ScalarExpression::ColumnRef(col), val_expr: right_expr.as_ref().clone(), op: *op, @@ -337,7 +345,7 @@ impl ScalarExpression { })); } (None, Some(col)) => { - fix_option.replace(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary{ column_expr: ScalarExpression::ColumnRef(col), val_expr: left_expr.as_ref().clone(), op: *op, @@ -345,11 +353,38 @@ impl ScalarExpression { is_column_left: false, })); } + (None, None) => { + if replaces.is_empty() { + return Ok(()); + } + + match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { + (Some(col), None) => { + replaces.push(Replace::Binary(ReplaceBinary{ + column_expr: ScalarExpression::ColumnRef(col), + val_expr: right_expr.as_ref().clone(), + op: *op, + ty: *ty, + is_column_left: true, + })); + } + (None, Some(col)) => { + replaces.push(Replace::Binary(ReplaceBinary{ + column_expr: ScalarExpression::ColumnRef(col), + val_expr: left_expr.as_ref().clone(), + op: *op, + ty: *ty, + is_column_left: false, + })); + } + _ => (), + } + } _ => () } } } - ScalarExpression::Alias { expr, .. } => expr._simplify(fix_option)?, + ScalarExpression::Alias { expr, .. } => expr._simplify(replaces)?, ScalarExpression::TypeCast { expr, .. } => { if let Some(val) = expr.unpack_val() { let _ = mem::replace(self, ScalarExpression::Constant(val)); @@ -369,7 +404,7 @@ impl ScalarExpression { ); let _ = mem::replace(self, new_expr); } else { - let _ = fix_option.replace(Replace::Unary( + let _ = replaces.push(Replace::Unary( ReplaceUnary { child_expr: expr.as_ref().clone(), op: *op, @@ -384,23 +419,37 @@ impl ScalarExpression { Ok(()) } + fn is_arithmetic(op: &mut BinaryOperator) -> bool { + matches!(op, BinaryOperator::Plus + | BinaryOperator::Divide + | BinaryOperator::Minus + | BinaryOperator::Multiply) + } + fn fix_expr( - fix_option: &mut Option, + replaces: &mut Vec, left_expr: &mut Box, right_expr: &mut Box, op: &mut BinaryOperator, ) -> Result<(), TypeError> { - left_expr._simplify(fix_option)?; + left_expr._simplify(replaces)?; - if let Some(replace) = fix_option.take() { + if Self::is_arithmetic(op) { + return Ok(()); + } + + while let Some(replace) = replaces.pop() { match replace { - Replace::Binary(binary) => Self::fix_binary(binary, left_expr, right_expr, op), + Replace::Binary(binary) => { + Self::fix_binary(binary, left_expr, right_expr, op) + }, Replace::Unary(unary) => { Self::fix_unary(unary, left_expr, right_expr, op); - Self::fix_expr(fix_option, left_expr, right_expr, op)?; + Self::fix_expr(replaces, left_expr, right_expr, op)?; }, } } + Ok(()) } @@ -541,12 +590,12 @@ impl ScalarExpression { }, (None, None) => { if let (Some(col), Some(val)) = - (left_expr.unpack_col(), right_expr.unpack_val()) + (left_expr.unpack_col(false), right_expr.unpack_val()) { return Ok(Self::new_binary(col_id, *op, col, val, false)); } if let (Some(val), Some(col)) = - (left_expr.unpack_val(), right_expr.unpack_col()) + (left_expr.unpack_val(), right_expr.unpack_col(false)) { return Ok(Self::new_binary(col_id, *op, col, val, true)); } @@ -654,6 +703,7 @@ mod test { is_primary: false, is_unique: false, }, + ref_expr: None, }); let val_1 = Arc::new(DataValue::Int32(Some(1))); diff --git a/src/optimizer/core/opt_expr.rs b/src/optimizer/core/opt_expr.rs index ed6cbea2..4f0fba78 100644 --- a/src/optimizer/core/opt_expr.rs +++ b/src/optimizer/core/opt_expr.rs @@ -8,6 +8,7 @@ pub type OptExprNodeId = usize; pub enum OptExprNode { /// Raw plan node with dummy children. OperatorRef(Operator), + #[allow(dead_code)] /// Existing OptExprNode in graph. OptExpr(OptExprNodeId), } @@ -31,14 +32,17 @@ pub struct OptExpr { impl OptExpr { + #[allow(dead_code)] pub fn new(root: OptExprNode, childrens: Vec) -> Self { Self { root, childrens } } + #[allow(dead_code)] pub fn new_from_op_ref(plan: &LogicalPlan) -> Self { OptExpr::build_opt_expr_internal(plan) } + #[allow(dead_code)] fn build_opt_expr_internal(input: &LogicalPlan) -> OptExpr { let root = OptExprNode::OperatorRef(input.operator.clone()); let childrens = input @@ -49,6 +53,7 @@ impl OptExpr { OptExpr { root, childrens } } + #[allow(dead_code)] pub fn to_plan_ref(&self) -> LogicalPlan { match &self.root { OptExprNode::OperatorRef(op) => { diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs index 8d02e2af..23cad49c 100644 --- a/src/optimizer/heuristic/batch.rs +++ b/src/optimizer/heuristic/batch.rs @@ -30,6 +30,7 @@ pub struct HepBatchStrategy { } impl HepBatchStrategy { + #[allow(dead_code)] pub fn once_topdown() -> Self { HepBatchStrategy { max_iteration: 1, @@ -52,5 +53,6 @@ pub enum HepMatchOrder { TopDown, /// Match from leaves up. A match attempt at a descendant precedes all match attempts at its /// ancestors. + #[allow(dead_code)] BottomUp, } \ No newline at end of file diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 135e3837..7d09e7a3 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -53,6 +53,7 @@ impl HepGraph { .next() } + #[allow(dead_code)] pub fn add_root(&mut self, new_node: OptExprNode) { let old_root_id = mem::replace( &mut self.root_index, @@ -145,6 +146,7 @@ impl HepGraph { } } + #[allow(dead_code)] pub fn node(&self, node_id: HepNodeId) -> Option<&OptExprNode> { self.graph.node_weight(node_id) } diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs index ca404ca5..df9328eb 100644 --- a/src/optimizer/rule/column_pruning.rs +++ b/src/optimizer/rule/column_pruning.rs @@ -97,8 +97,12 @@ impl Rule for PushProjectThroughChild { .project_input_refs() .iter() .filter_map(|expr| { + if agg_calls.is_empty() { + return None; + } + if let ScalarExpression::InputRef { index, .. } = expr { - Some(agg_calls[*index].clone()) + agg_calls.get(*index).cloned() } else { None } @@ -142,9 +146,13 @@ impl Rule for PushProjectThroughChild { } } _ => { - let grandson_id = graph.children_at(child_index)[0]; - let mut columns = node_operator.project_input_refs(); + let grandson_ids = graph.children_at(child_index); + if grandson_ids.is_empty() { + return Ok(()) + } + let grandson_id = grandson_ids[0]; + let mut columns = node_operator.project_input_refs(); let mut referenced_columns = node_referenced_columns .into_iter() .chain(child_referenced_columns.into_iter()) diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index 796668b7..12267d31 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -9,7 +9,10 @@ lazy_static! { static ref SIMPLIFY_FILTER_RULE: Pattern = { Pattern { predicate: |op| matches!(op, Operator::Filter(_)), - children: PatternChildrenPredicate::None, + children: PatternChildrenPredicate::Predicate(vec![Pattern { + predicate: |op| !matches!(op, Operator::Aggregate(_)), + children: PatternChildrenPredicate::Recursive, + }]), } }; } @@ -43,7 +46,7 @@ mod test { use crate::binder::test::select_sql_run; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::DatabaseError; - use crate::expression::{BinaryOperator, ScalarExpression}; + use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::expression::simplify::ConstantBinary; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; @@ -74,6 +77,12 @@ mod test { // c1 > 0 let plan_8 = select_sql_run("select * from t1 where 1 < c1 + 1").await?; + // c1 < 24 + let plan_9 = select_sql_run("select * from t1 where (-1 - c1) + 1 > 24").await?; + + // c1 < 24 + let plan_10 = select_sql_run("select * from t1 where 24 < (-1 - c1) + 1").await?; + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { let best_plan = HepOptimizer::new(plan.clone()) .batch( @@ -91,10 +100,23 @@ mod test { } }; - assert_eq!(op(plan_1, "-(c1 + 1) > 1")?, op(plan_5, "1 < -(c1 + 1)")?); - assert_eq!(op(plan_2, "-(1 - c1) > 1")?, op(plan_6, "1 < -(1 - c1)")?); - assert_eq!(op(plan_3, "-c1 > 1")?, op(plan_7, "1 < -c1")?); - assert_eq!(op(plan_4, "c1 + 1 > 1")?, op(plan_8, "1 < c1 + 1")?); + let op_1 = op(plan_1, "-(c1 + 1) > 1")?; + let op_2 = op(plan_2, "-(1 - c1) > 1")?; + let op_3 = op(plan_3, "-c1 > 1")?; + let op_4 = op(plan_4, "c1 + 1 > 1")?; + let op_5 = op(plan_9, "(-1 - c1) + 1 > 24")?; + + assert!(op_1.is_some()); + assert!(op_2.is_some()); + assert!(op_3.is_some()); + assert!(op_4.is_some()); + assert!(op_5.is_some()); + + assert_eq!(op_1, op(plan_5, "1 < -(c1 + 1)")?); + assert_eq!(op_2, op(plan_6, "1 < -(1 - c1)")?); + assert_eq!(op_3, op(plan_7, "1 < -c1")?); + assert_eq!(op_4, op(plan_8, "1 < c1 + 1")?); + assert_eq!(op_5, op(plan_10, "24 < (-1 - c1) + 1")?); Ok(()) } @@ -125,6 +147,7 @@ mod test { is_primary: true, is_unique: false, }, + ref_expr: None, }; let c2_col = ColumnCatalog { id: Some( @@ -140,17 +163,22 @@ mod test { is_primary: false, is_unique: true, }, + ref_expr: None, }; // -(c1 + 1) > c2 => c1 < -c2 - 1 assert_eq!( filter_op.predicate, ScalarExpression::Binary { - op: BinaryOperator::Lt, - left_expr: Box::new(ScalarExpression::Binary { - op: BinaryOperator::Minus, - left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c1_col))), - right_expr: Box::new(ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(-1))))), + op: BinaryOperator::Gt, + left_expr: Box::new(ScalarExpression::Unary { + op: UnaryOperator::Minus, + expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::Plus, + left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c1_col))), + right_expr: Box::new(ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(1))))), + ty: LogicalType::Integer, + }), ty: LogicalType::Integer, }), right_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c2_col))), diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 3e5560bb..92c47ab8 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -64,6 +64,10 @@ impl Operator { .iter() .map(ScalarExpression::unpack_alias) .filter(|expr| matches!(expr, ScalarExpression::InputRef { .. })) + .sorted_by_key(|expr| match expr { + ScalarExpression::InputRef { index, .. } => index, + _ => unreachable!() + }) .cloned() .collect_vec() } diff --git a/src/planner/operator/sort.rs b/src/planner/operator/sort.rs index 6d7bc2cd..0006ab20 100644 --- a/src/planner/operator/sort.rs +++ b/src/planner/operator/sort.rs @@ -3,15 +3,15 @@ use crate::expression::ScalarExpression; #[derive(Debug, PartialEq, Clone)] pub struct SortField { pub expr: ScalarExpression, - pub desc: bool, + pub asc: bool, pub nulls_first: bool, } impl SortField { - pub fn new(expr: ScalarExpression, desc: bool, nulls_first: bool) -> Self { + pub fn new(expr: ScalarExpression, asc: bool, nulls_first: bool) -> Self { SortField { expr, - desc, + asc, nulls_first, } } diff --git a/src/storage/kip.rs b/src/storage/kip.rs index bb08d65c..50d6008d 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -498,12 +498,14 @@ mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false), + None )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false) + ColumnDesc::new(LogicalType::Boolean, false, false), + None )), ]; diff --git a/src/storage/memory.rs b/src/storage/memory.rs index aaf241a8..76af970a 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -162,10 +162,12 @@ impl Transaction for MemTable { } } + #[allow(unused_variables)] fn read_by_index(&self, bounds: Bounds, projection: Projections, index_meta: IndexMetaRef, binaries: Vec) -> Result, StorageError> { todo!() } + #[allow(unused_variables)] fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError> { todo!() } @@ -278,12 +280,14 @@ pub(crate) mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false), + None )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false) + ColumnDesc::new(LogicalType::Boolean, false, false), + None )), ]; diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index e8afa7fa..a038cb8f 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -260,12 +260,14 @@ mod tests { ColumnCatalog::new( "c1".into(), false, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false), + None ), ColumnCatalog::new( "c2".into(), false, - ColumnDesc::new(LogicalType::Decimal(None,None), false, false) + ColumnDesc::new(LogicalType::Decimal(None,None), false, false), + None ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap(); @@ -356,7 +358,8 @@ mod tests { column_datatype: LogicalType::Invalid, is_primary: false, is_unique: false, - } + }, + None ); col.table_name = Some(Arc::new(table_name.to_string())); diff --git a/src/types/tuple.rs b/src/types/tuple.rs index dc6db97c..267fd5cd 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -125,62 +125,74 @@ mod tests { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true, false) + ColumnDesc::new(LogicalType::Integer, true, false), + None )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::UInteger, false, false) + ColumnDesc::new(LogicalType::UInteger, false, false), + None )), Arc::new(ColumnCatalog::new( "c3".to_string(), false, - ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false) + ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false), + None )), Arc::new(ColumnCatalog::new( "c4".to_string(), false, - ColumnDesc::new(LogicalType::Smallint, false, false) + ColumnDesc::new(LogicalType::Smallint, false, false), + None )), Arc::new(ColumnCatalog::new( "c5".to_string(), false, - ColumnDesc::new(LogicalType::USmallint, false, false) + ColumnDesc::new(LogicalType::USmallint, false, false), + None )), Arc::new(ColumnCatalog::new( "c6".to_string(), false, - ColumnDesc::new(LogicalType::Float, false, false) + ColumnDesc::new(LogicalType::Float, false, false), + None )), Arc::new(ColumnCatalog::new( "c7".to_string(), false, - ColumnDesc::new(LogicalType::Double, false, false) + ColumnDesc::new(LogicalType::Double, false, false), + None )), Arc::new(ColumnCatalog::new( "c8".to_string(), false, - ColumnDesc::new(LogicalType::Tinyint, false, false) + ColumnDesc::new(LogicalType::Tinyint, false, false), + None )), Arc::new(ColumnCatalog::new( "c9".to_string(), false, - ColumnDesc::new(LogicalType::UTinyint, false, false) + ColumnDesc::new(LogicalType::UTinyint, false, false), + None )), Arc::new(ColumnCatalog::new( "c10".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false, false) + ColumnDesc::new(LogicalType::Boolean, false, false), + None )), Arc::new(ColumnCatalog::new( "c11".to_string(), false, - ColumnDesc::new(LogicalType::DateTime, false, false) + ColumnDesc::new(LogicalType::DateTime, false, false), + None )), Arc::new(ColumnCatalog::new( "c12".to_string(), false, - ColumnDesc::new(LogicalType::Date, false, false) + ColumnDesc::new(LogicalType::Date, false, false), + None )), ]; diff --git a/tests/slt/aggregation.slt b/tests/slt/aggregation.slt new file mode 100644 index 00000000..0c92a170 --- /dev/null +++ b/tests/slt/aggregation.slt @@ -0,0 +1,77 @@ +statement ok +create table t(id int primary key,v1 int not null, v2 int not null, v3 double not null) + +statement ok +insert into t values(0,1,4,2.5), (1,2,3,3.2), (2,3,4,4.7), (3,4,3,5.1) + +query I +select sum(v1) + sum(v2) from t +---- +24 + +query I +select sum(v1) as a from t +---- +10 + +query IR +select sum(v1), sum(v3) from t +---- +10 15.499999761581421 + +# query IR +# select sum(v1+v2),sum(v1+v3) from t +# ---- +# 24 25.5 + +# SimpleMinTest + +query I +select min(v1) from t +---- +1 + +# SimpleMaxTest + +query I +select max(v1) from t +---- +4 + +# SimpleMaxTest1 + +query I +select max(v1) from t where v2>3 +---- +3 + +# SimpleCountTest + +query I +select count(v1) from t +---- +4 + +# SimpleAvgTest + +# query R +# select avg(v2) from t +# ---- +# 3.5 + +# SumGroupTest + +query II rowsort +select sum(v1), v2 from t group by v2 +---- +4 4 +6 3 + +query II +select sum(v1) as a, v2 from t group by v2 order by a +---- +4 4 +6 3 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/basic_test.slt b/tests/slt/basic_test.slt new file mode 100644 index 00000000..e7473788 --- /dev/null +++ b/tests/slt/basic_test.slt @@ -0,0 +1,84 @@ +# query I +# select 1 +# ---- +# 1 + +# query R +# select 10000.00::FLOAT + 234.567::FLOAT +# ---- +# 10234.567 + +# query R +# select 100.0::DOUBLE/8.0::DOUBLE +# ---- +# 12.5 + +# query B +# select 2>1 +# ---- +# true + +# query B +# select 3>4 +# ---- +# false + +# query T +# select DATE '2001-02-16' +# ---- +# 2001-02-16 + +subtest NullType + +statement ok +create table t(id int primary key,v1 int null) + +statement ok +insert into t values(0, null) + +query T +select * from t +---- +0 null + +statement ok +drop table t + +subtest MultiRowsMultiColumn + +statement ok +create table t(id int primary key, v1 int not null, v2 int not null, v3 int not null) + +statement ok +insert into t values(0,1,4,2), (1,2,3,3), (2,3,4,4), (3,4,3,5) + +query II rowsort +select v1,v3 from t where v2 > 3 +---- +1 2 +3 4 + +statement ok +drop table t + +subtest SyntaxError + +statement error +SELECT * FORM dish + +subtest CharArray + +statement ok +create table t (id int primary key, name VARCHAR NOT NULL) + +statement ok +insert into t values (0,'text1'), (1,'text2') + +query T rowsort +select * from t +---- +0 text1 +1 text2 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/count.slt b/tests/slt/count.slt new file mode 100644 index 00000000..f8930bb9 --- /dev/null +++ b/tests/slt/count.slt @@ -0,0 +1,49 @@ +statement ok +create table t(id int primary key, v int) + +statement ok +insert into t values (0,1), (1,2), (2,3), (3,4), (4,5), (5,6), (6,7), (7,8) + +query I +select count(*) from t +---- +8 + +query I +select count(*)+1 from t; +---- +9 + +query I +select 2*count(*) from t; +---- +16 + +query I +select -count(*) from t; +---- +-8 + +query I +select count(*)+min(v) from t; +---- +9 + +query I +select count(*) as 'total' from t where v > 5 +---- +3 + +statement ok +delete from t where v = 7 + +query I +select count(*) from t where v > 5 +---- +2 + +# FIXME +# query I +# select count(*) from t where 0 = 1 +# ---- +# 0 \ No newline at end of file diff --git a/tests/slt/create.slt b/tests/slt/create.slt new file mode 100644 index 00000000..741acc10 --- /dev/null +++ b/tests/slt/create.slt @@ -0,0 +1,2 @@ +statement ok +create table t(id int primary key, v1 int, v2 int, v3 int) \ No newline at end of file diff --git a/tests/slt/create_table.slt b/tests/slt/create_table.slt deleted file mode 100644 index eddd85e4..00000000 --- a/tests/slt/create_table.slt +++ /dev/null @@ -1,39 +0,0 @@ -statement ok -create table t1(v1 varchar, v2 varchar, v3 varchar); - -statement ok -insert into t1 values('a', 'b', 'c'); - -statement error -create table t1(v1 int); - - -statement ok -create table t2(v1 boolean, v2 tinyint, v3 smallint, v4 int, v5 bigint, v6 float, v7 double, v8 varchar); - -statement ok -insert into t2 values(true, 1, 2, 3, 4, 5.1, 6.2, '7'); - - -statement ok -create table t3(v1 boolean, v2 tinyint unsigned, v3 smallint unsigned, v4 int unsigned, v5 bigint unsigned, v6 float, v7 double, v8 varchar); - -statement ok -insert into t3 values(true, 1, 2, 3, 4, 5.1, 6.2, '7'); - - -statement ok -create table t4(v1 int); - -statement ok -select v1 from t4; - - -statement ok -create table read_csv_table as select * from read_csv('tests/csv/t2.csv'); - - -query I -select a from read_csv_table limit 1; ----- -10 diff --git a/tests/slt/decimal b/tests/slt/decimal deleted file mode 100644 index 4d566b0a..00000000 --- a/tests/slt/decimal +++ /dev/null @@ -1,6 +0,0 @@ - -statement ok -CREATE TABLE mytable ( title varchar(256) primary key, cost decimal(4,2)); - -statement ok -INSERT INTO mytable (title, cost) VALUES ('A', 1.00); \ No newline at end of file diff --git a/tests/slt/delete.slt b/tests/slt/delete.slt new file mode 100644 index 00000000..2e66b16c --- /dev/null +++ b/tests/slt/delete.slt @@ -0,0 +1,25 @@ +statement ok +create table t(id int primary key, v1 int, v2 int, v3 int) + +statement ok +insert into t values (0,1,10,100) + +statement ok +insert into t values (1,1,10,100), (2,2,20,200), (3,3,30,300), (4,4,40,400) + +statement ok +delete from t where v1 = 1 + +query III rowsort +select * from t; +---- +2 2 20 200 +3 3 30 300 +4 4 40 400 + +statement ok +delete from t + +query III rowsort +select * from t +---- \ No newline at end of file diff --git a/tests/slt/distinct.slt b/tests/slt/distinct.slt new file mode 100644 index 00000000..9a33bac5 --- /dev/null +++ b/tests/slt/distinct.slt @@ -0,0 +1,24 @@ +statement ok +CREATE TABLE test (id int primary key, x int, y int); + +statement ok +INSERT INTO test VALUES (0, 1, 1), (1, 2, 2), (2, 1, 1), (3, 3, 3); + +query II +SELECT DISTINCT x FROM test ORDER BY x, id; +---- +1 +2 +3 + + +query I +SELECT DISTINCT sum(x) FROM test ORDER BY sum(x); +---- +7 + + +# ORDER BY items must appear in the select list +# if SELECT DISTINCT is specified +statement error +SELECT DISTINCT x FROM test ORDER BY y; \ No newline at end of file diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt new file mode 100644 index 00000000..ccd46821 --- /dev/null +++ b/tests/slt/filter.slt @@ -0,0 +1,71 @@ +statement ok +create table t (v1 int not null primary key, v2 int not null); + +statement ok +insert into t values (1, 1), (4, 6), (3, 2), (2, 1) + +query I rowsort +select v1 from t where v1 > 2 +---- +3 +4 + +query I +select v2 from t where 3 > v1 +---- +1 +1 + +statement ok +drop table t + +statement ok +create table t(v1 int not null primary key, v2 int not null) + +statement ok +insert into t values(-3, -3), (-2, -2), (-1, -1), (0, 0), (1,1), (2, 2) + +statement ok +insert into t values(-8, -8), (-7, -7), (-6, -6), (3, 3), (7, 7), (8, 8), (9, 9) + +query I +select v1 from t where v1 > 2 and v1 < 4 +---- +3 + +query I +select v2 from t where (-7 < v1 or 9 <= v1) and (v1 = 3) +---- +3 + +query I rowsort +select v2 from t where (-8 < v1 and v1 <= -7) or (v1 >= 1 and 2 > v1) +---- +-7 +1 + +query I rowsort +select v2 from t where ((v1 >= -8 and -4 >= v1) or (v1 >= 0 and 5 > v1)) and ((v1 > 0 and v1 <= 1) or (v1 > -8 and v1 < -6)) +---- +-7 +1 + +query I rowsort +select v2 from t where (-7 < v1 or 9 <= v1) and (v2 = 3) +---- +3 + +query I rowsort +select v2 from t where (-8 < v1 and v2 <= -7) or (v1 >= 1 and 2 > v2) +---- +-7 +1 + +query I rowsort +select v2 from t where ((v2 >= -8 and -4 >= v1) or (v1 >= 0 and 5 > v2)) and ((v2 > 0 and v1 <= 1) or (v1 > -8 and v2 < -6)) +---- +-7 +1 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/filter_null.slt b/tests/slt/filter_null.slt new file mode 100644 index 00000000..db45712f --- /dev/null +++ b/tests/slt/filter_null.slt @@ -0,0 +1,36 @@ +statement ok +create table t(id int primary key, v1 int, v2 int not null) + +statement ok +insert into t values (0, 2, 4), (1, 1, 3), (2, 3, 4), (3, 4, 3); + +query II +select * from t where v1 > 1 +---- +0 2 4 +2 3 4 +3 4 3 + +query II +select * from t where v1 < 2 +---- +1 1 3 + +statement ok +drop table t + +statement ok +create table t(id int primary key, v1 int null, v2 int) + +statement ok +insert into t values (0, 2, 4), (1, null, 3), (2, 3, 4), (3, 4, 3) + +query II +select * from t where v1 > 1 +---- +0 2 4 +2 3 4 +3 4 3 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/group_by.slt b/tests/slt/group_by.slt new file mode 100644 index 00000000..7036f3b7 --- /dev/null +++ b/tests/slt/group_by.slt @@ -0,0 +1,62 @@ +statement ok +create table t (id int primary key, v1 int, v2 int) + +statement ok +insert into t values (0,1,1), (1,2,1), (2,3,2), (3,4,2), (4,5,3) + +# TODO: check on binder +# statement error +# select v2 + 1, v1 from t group by v2 + 1 + +# statement error +# select v2 + 1 as a, v1 as b from t group by a + +# statement error +# select v2, v2 + 1, sum(v1) from t group by v2 + 1 + +# statement error +# select v2 + 2 + count(*) from t group by v2 + 1 + +# statement error +# select v2 + count(*) from t group by v2 order by v1; + +query II rowsort +select v2 + 1, sum(v1) from t group by v2 + 1 +---- +2 3 +3 7 +4 5 + +query III rowsort +select sum(v1), v2 + 1 as a, count(*) from t group by a +---- +3 2 2 +5 4 1 +7 3 2 + +query III rowsort +select v2, v2 + 1, sum(v1) from t group by v2 + 1, v2 +---- +1 2 3 +2 3 7 +3 4 5 + +query III rowsort +select v2, v2 + 1, sum(v1) from t group by v2 + 1, v2 order by v2 +---- +1 2 3 +2 3 7 +3 4 5 + +# TODO +# query I rowsort +# select v1 + 1 + count(*) from t group by v1 + 1 +# ---- +# 3 +# 4 +# 5 +# 6 +# 7 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/having.slt b/tests/slt/having.slt new file mode 100644 index 00000000..89d25713 --- /dev/null +++ b/tests/slt/having.slt @@ -0,0 +1,35 @@ +statement ok +CREATE TABLE test (id int primary key, x int, y int); + +statement ok +INSERT INTO test VALUES (0, 1, 2), (1, 2, 2), (2, 11, 22) + +query II +select y as b, sum(x) as sum from test group by b having b = 2 +---- +2 3 + +query II +select count(x) as a, y as b from test group by b having a > 1 +---- +2 2 + +query II +select count(x) as a, y + 1 as b from test group by b having b + 1 = 24; +---- +1 23 + +# TODO: Filter pushed down to Agg +# query II +# select x from test group by x having max(y) = 22 +# ---- +# 11 + +# query II +# select y + 1 as i from test group by y + 1 having count(x) > 1 and y + 1 = 3 or y + 1 = 23 order by i; +# ---- +# 3 +# 23 + +statement error +select count(x) from test group by count(x) \ No newline at end of file diff --git a/tests/slt/insert_table.slt b/tests/slt/insert_table.slt deleted file mode 100644 index 31c6cf3a..00000000 --- a/tests/slt/insert_table.slt +++ /dev/null @@ -1,93 +0,0 @@ -# Test common insert case - -statement ok -create table t1(v1 varchar, v2 varchar, v3 varchar); - - -statement error -insert into t1(v3) values ('0','4'); - - -statement ok -insert into t1(v3, v2) values ('0','4'), ('1','5'); - - -statement ok -insert into t1 values ('2','7','9'); - - -query III -select v1, v3, v2 from t1; ----- -NULL 0 4 -NULL 1 5 -2 9 7 - - -# Test insert value cast type - -statement ok -create table t2(v1 int, v2 int, v3 int); - - -statement ok -insert into t2(v3, v2, v1) values (0, 4, 1), (1, 5, 2); - - -query III -select v3, v2, v1 from t2; ----- -0 4 1 -1 5 2 - - -# Test insert type cast - -statement ok -create table t3(v1 TINYINT UNSIGNED); - - -statement error -insert into t3(v1) values (1481); - - -# Test insert null values - -statement ok -create table t4(v1 varchar, v2 smallint unsigned, v3 bigint unsigned); - - -statement ok -insert into t4 values (NULL, 1, 2), ('', 3, NULL); - - -statement ok -insert into t4 values (NULL, NULL, NULL); - - -query III -select v1, v2, v3 from t4; ----- -NULL 1 2 -(empty) 3 NULL -NULL NULL NULL - - -# Test insert from select - -statement ok -CREATE TABLE integers(i INTEGER); - - -statement ok -INSERT INTO integers SELECT 42; - -statement ok -INSERT INTO integers SELECT null; - - -query I -SELECT * FROM integers ----- -42 -NULL diff --git a/tests/slt/join.slt b/tests/slt/join.slt new file mode 100644 index 00000000..7ca5738f --- /dev/null +++ b/tests/slt/join.slt @@ -0,0 +1,102 @@ +statement ok +create table x(id int primary key, a int, b int); + +statement ok +create table y(id int primary key, c int, d int); + +statement ok +insert into x values (0, 1, 2), (1, 1, 3); + +query IIII +select a, b, c, d from x join y on a = c; +---- + +statement ok +insert into y values (0, 1, 5), (1, 1, 6), (2, 2, 7); + +query IIII +select a, b, c, d from x join y on a = c; +---- +1 2 1 5 +1 3 1 5 +1 2 1 6 +1 3 1 6 + +statement ok +drop table x; + +statement ok +drop table y; + +statement ok +create table a(id int primary key, v1 int, v2 int); + +statement ok +create table b(id int primary key, v3 int, v4 int); + +statement ok +insert into a values (0, 1, 1), (1, 2, 2), (2, 3, 3); + +query IIII rowsort +select v1, v2, v3, v4 from a left join b on v1 = v3; +---- +1 1 null null +2 2 null null +3 3 null null + +statement ok +insert into b values (0, 1, 100), (1, 3, 300), (2, 4, 400); + +query IIII +select v1, v2, v3, v4 from a left join b on v1 = v3; +---- +1 1 1 100 +3 3 3 300 +2 2 null null + +query IIII rowsort +select v1, v2, v3, v4 from a right join b on v1 = v3; +---- +1 1 1 100 +3 3 3 300 +null null 4 400 + +query IIII rowsort +select v1, v2, v3, v4 from a full join b on v1 = v3; +---- +1 1 1 100 +2 2 null null +3 3 3 300 +null null 4 400 + +statement ok +drop table a; + +statement ok +drop table b; + +statement ok +create table a(id int primary key, v1 int, v2 int); + +statement ok +create table b(id int primary key, v3 int, v4 int, v5 int); + +statement ok +insert into a values (0, 1, 1), (1, 2, 2), (2, 3, 3); + +statement ok +insert into b values (0, 1, 1, 1), (1, 2, 2, 2), (2, 3, 3, 4), (3, 1, 1, 5); + +query IIIII rowsort +select v1, v2, v3, v4, v5 from a join b on v1 = v3 and v2 = v4; +---- +1 1 1 1 1 +1 1 1 1 5 +2 2 2 2 2 +3 3 3 3 4 + +query IIIII rowsort +select v1, v2, v3, v4, v5 from a join b on v1 = v3 and v2 = v4 and v1 < v5; +---- +1 1 1 1 5 +3 3 3 3 4 \ No newline at end of file diff --git a/tests/slt/order_by.slt b/tests/slt/order_by.slt new file mode 100644 index 00000000..9f84275b --- /dev/null +++ b/tests/slt/order_by.slt @@ -0,0 +1,83 @@ +statement ok +create table t(id int primary key, v1 int, v2 int) + +statement ok +insert into t values(0, 1, 1), (1, 4, 2), (2, 3, 3), (3, 10, 12), (4, 2, 5) + +query I +select v1 from t order by v1 asc +---- +1 +2 +3 +4 +10 + +query I +select v1 from t order by v1 desc +---- +10 +4 +3 +2 +1 + +statement ok +drop table t + + +statement ok +create table t(id int primary key, v1 int, v2 int) + +statement ok +insert into t values (0, 1, 0), (1, 2, 2), (2, 3, 15), (3, 2, 12), (4, 3, 9), (5, 1, 5) + +query II +select v1, v2 from t order by v1 asc, v2 desc +---- +1 5 +1 0 +2 12 +2 2 +3 15 +3 9 + +statement ok +drop table t + +# sort with null +statement ok +create table t(id int primary key, v1 int null, v2 int null) + +statement ok +insert into t values (0, 1, 0), (1, 2, 2), (2, null, 5), (3, 2, null) + +query II +select v1, v2 from t order by v1 asc, v2 asc +---- +null 5 +1 0 +2 null +2 2 + +statement ok +drop table t + +# sort on alias +statement ok +create table t(id int primary key, v1 int null, v2 int null) + +statement ok +insert into t values(0, 1, 1), (1, 4, 2), (2, 3, 3), (3, 10, 12), (4, 2, 5) + +query I +select v1 as a from t order by a +---- +1 +2 +3 +4 +10 + +statement ok +drop table t \ No newline at end of file diff --git a/tests/slt/select.slt b/tests/slt/select.slt deleted file mode 100644 index 65d036f3..00000000 --- a/tests/slt/select.slt +++ /dev/null @@ -1,46 +0,0 @@ -# test insert projection with cast expression - -statement ok -create table t2(v1 tinyint); - -statement ok -insert into t2(v1) values (1), (5); - - -statement ok -create table t1(v1 int, v2 int, v3 int); - -statement ok -insert into t1(v3, v2, v1) values (0, 4, 1), (1, 5, 2); - - - -query III -select t1.v1, v2 from t1; ----- -1 4 -2 5 - - - -query III -select *, t1.* from t1; ----- -1 4 0 1 4 0 -2 5 1 2 5 1 - - -# TODO: use alias function to verify output column names - -query III -select t.v1 as a, v2 as b from t1 as t; ----- -1 4 -2 5 - - - -query III -select 1, 2.3, 'πŸ˜‡', true, null; ----- -1 2.3 πŸ˜‡ true NULL diff --git a/tests/sqllogictest/Cargo.toml b/tests/sqllogictest/Cargo.toml index 59f5f380..053e4435 100644 --- a/tests/sqllogictest/Cargo.toml +++ b/tests/sqllogictest/Cargo.toml @@ -5,11 +5,9 @@ edition = "2021" [dependencies] "kip-sql" = { path = "../.." } -sqllogictest = "0.6" glob = "0.3" async-trait = "0.1" -libtest-mimic = "0.6" -anyhow = { version = "1.0.71", features = ["std"] } -[[test]] -name = "sqllogictest" -harness = false +tokio = "1.29.1" +sqllogictest = "0.14.0" +tokio-test = "0.4.2" +tempfile = "3.0.7" \ No newline at end of file diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 5d9082d5..8e1eaff6 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -1,26 +1,39 @@ -#![feature(iterator_try_collect)] -use std::sync::Arc; - +use std::time::Instant; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; use kip_sql::db::{Database, DatabaseError}; -use sqllogictest::{AsyncDB, Runner}; -use kip_sql::types::tuple::create_table; - -pub fn test_run(sqlfile: &str) { - let db = Arc::new(Database::new("./test")); - let mut tester = Runner::new(DatabaseWrapper { db }); - tester.run_file(sqlfile).unwrap() -} +use kip_sql::storage::kip::KipStorage; -struct DatabaseWrapper { - db: Arc, +pub struct KipSQL { + pub db: Database, } #[async_trait::async_trait] -impl AsyncDB for DatabaseWrapper { +impl AsyncDB for KipSQL { type Error = DatabaseError; - async fn run(&mut self, sql: &str) -> Result { - let table = create_table(&self.db.run(sql).await?); + type ColumnType = DefaultColumnType; + + async fn run(&mut self, sql: &str) -> Result, Self::Error> { + let start = Instant::now(); + let tuples = self.db.run(sql).await?; + println!("|β€” Input SQL:"); + println!(" |β€” {}", sql); + println!(" |β€” Time consuming: {:?}", start.elapsed()); + + if tuples.is_empty() { + return Ok(DBOutput::StatementComplete(0)); + } - Ok(format!("{}", table)) + let types = vec![DefaultColumnType::Any; tuples[0].columns.len()]; + let rows = tuples + .into_iter() + .map(|tuple| { + tuple + .values + .into_iter() + .map(|value| format!("{}", value)) + .collect() + }) + .collect(); + Ok(DBOutput::Rows { types, rows }) } } diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs new file mode 100644 index 00000000..7438db19 --- /dev/null +++ b/tests/sqllogictest/src/main.rs @@ -0,0 +1,35 @@ +use std::path::Path; +use sqllogictest::Runner; +use tempfile::TempDir; +use kip_sql::db::Database; +use sqllogictest_test::KipSQL; + +#[tokio::main] +async fn main() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join(".."); + std::env::set_current_dir(path).unwrap(); + + println!("KipSQL Test Start!\n"); + const SLT_PATTERN: &str = "tests/slt/**/*.slt"; + + let slt_files = glob::glob(SLT_PATTERN).expect("failed to find slt files"); + for slt_file in slt_files { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + + let filepath = slt_file + .expect("failed to read slt file") + .to_str() + .unwrap() + .to_string(); + println!("-> Now the test file is: {}", filepath); + + let db = Database::with_kipdb(temp_dir.path()).await + .expect("init db error"); + let mut tester = Runner::new(KipSQL { db }); + + if let Err(err) = tester.run_file_async(filepath).await { + panic!("test error: {}", err); + } + println!("-> Pass!\n\n") + } +} \ No newline at end of file diff --git a/tests/sqllogictest/tests/sqllogictest.rs b/tests/sqllogictest/tests/sqllogictest.rs deleted file mode 100644 index f3d48fd4..00000000 --- a/tests/sqllogictest/tests/sqllogictest.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::path::Path; - -use libtest_mimic::{Arguments, Trial}; -use sqllogictest_test::test_run; - -fn main() { - let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join(".."); - std::env::set_current_dir(path).unwrap(); - - const SLT_PATTERN: &str = "tests/slt/**/*.slt"; - - let args = Arguments::from_args(); - let mut tests = vec![]; - - let slt_files = glob::glob(SLT_PATTERN).expect("failed to find slt files"); - for slt_file in slt_files { - let filepath = slt_file.expect("failed to read slt file"); - let filename = filepath - .file_stem() - .expect("failed to get file name") - .to_str() - .unwrap() - .to_string(); - let filepath = filepath.to_str().unwrap().to_string(); - - let test = Trial::test(filename, move || { - test_run(filepath.as_str()); - Ok(()) - }); - - tests.push(test); - } - - libtest_mimic::run(&args, tests).exit(); -}