From fff8a351375c1e3474fb5b033da9dc97c8f62c4c Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 19 Feb 2024 01:40:19 +0800 Subject: [PATCH] feat: support `if()`/`ifnull()`/`nullif()`/`coalesce()`/`Case ... When ...` --- src/binder/aggregate.rs | 92 ++++++++- src/binder/expr.rs | 140 +++++++++++++- src/binder/select.rs | 10 +- src/db.rs | 12 +- src/execution/volcano/dql/filter.rs | 12 +- src/execution/volcano/mod.rs | 34 ++-- src/expression/evaluator.rs | 84 +++++++- src/expression/mod.rs | 284 +++++++++++++++++++++++++--- src/expression/simplify.rs | 61 +++++- src/expression/value_compute.rs | 32 ++-- src/marcos/mod.rs | 6 +- src/types/value.rs | 11 ++ tests/slt/crdb/condition.slt | 76 ++++++++ tests/slt/dummy.slt | 120 ++++++++++++ tests/slt/filter_null.slt | 14 +- tests/slt/having.slt | 1 + tests/slt/subquery.slt | 2 - tests/sqllogictest/src/lib.rs | 6 +- tests/sqllogictest/src/main.rs | 29 ++- 19 files changed, 926 insertions(+), 100 deletions(-) create mode 100644 tests/slt/crdb/condition.slt diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 95631d96..6c51fd44 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -139,11 +139,52 @@ impl<'a, T: Transaction> Binder<'a, T> { ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) - | ScalarExpression::Function(ScalarFunction { args, .. }) => { + | ScalarExpression::Function(ScalarFunction { args, .. }) + | ScalarExpression::Coalesce { exprs: args, .. } => { for expr in args { self.visit_column_agg_expr(expr)?; } } + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + self.visit_column_agg_expr(condition)?; + self.visit_column_agg_expr(left_expr)?; + self.visit_column_agg_expr(right_expr)?; + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. 
+ } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + self.visit_column_agg_expr(left_expr)?; + self.visit_column_agg_expr(right_expr)?; + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + if let Some(expr) = operand_expr { + self.visit_column_agg_expr(expr)?; + } + for (expr_1, expr_2) in expr_pairs { + self.visit_column_agg_expr(expr_1)?; + self.visit_column_agg_expr(expr_2)?; + } + if let Some(expr) = else_expr { + self.visit_column_agg_expr(expr)?; + } + } } Ok(()) @@ -318,12 +359,59 @@ impl<'a, T: Transaction> Binder<'a, T> { ScalarExpression::Constant(_) => Ok(()), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) - | ScalarExpression::Function(ScalarFunction { args, .. }) => { + | ScalarExpression::Function(ScalarFunction { args, .. }) + | ScalarExpression::Coalesce { exprs: args, .. } => { for expr in args { self.validate_having_orderby(expr)?; } Ok(()) } + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + self.validate_having_orderby(condition)?; + self.validate_having_orderby(left_expr)?; + self.validate_having_orderby(right_expr)?; + + Ok(()) + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + self.validate_having_orderby(left_expr)?; + self.validate_having_orderby(right_expr)?; + + Ok(()) + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. 
+ } => { + if let Some(expr) = operand_expr { + self.validate_having_orderby(expr)?; + } + for (expr_1, expr_2) in expr_pairs { + self.validate_having_orderby(expr_1)?; + self.validate_having_orderby(expr_2)?; + } + if let Some(expr) = else_expr { + self.validate_having_orderby(expr)?; + } + + Ok(()) + } } } } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 2ab2566c..970d3bf5 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -122,6 +122,44 @@ impl<'a, T: Transaction> Binder<'a, T> { } Ok(ScalarExpression::Tuple(bond_exprs)) } + Expr::Case { + operand, + conditions, + results, + else_result, + } => { + let mut operand_expr = None; + let mut ty = LogicalType::SqlNull; + if let Some(expr) = operand { + operand_expr = Some(Box::new(self.bind_expr(expr)?)); + } + let mut expr_pairs = Vec::with_capacity(conditions.len()); + for i in 0..conditions.len() { + let result = self.bind_expr(&results[i])?; + let result_ty = result.return_type(); + + if result_ty != LogicalType::SqlNull { + if ty == LogicalType::SqlNull { + ty = result_ty; + } else if ty != result_ty { + return Err(DatabaseError::Incomparable(ty, result_ty)); + } + } + expr_pairs.push((self.bind_expr(&conditions[i])?, result)) + } + + let mut else_expr = None; + if let Some(expr) = else_result { + else_expr = Some(Box::new(self.bind_expr(expr)?)); + } + + Ok(ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + ty, + }) + } _ => { todo!() } @@ -272,14 +310,20 @@ impl<'a, T: Transaction> Binder<'a, T> { match function_name.as_str() { "count" => { + if args.len() != 1 { + return Err(DatabaseError::MisMatch("number of count() parameters", "1")); + } return Ok(ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Count, args, ty: LogicalType::Integer, - }) + }); } "sum" => { + if args.len() != 1 { + return Err(DatabaseError::MisMatch("number of sum() parameters", "1")); + } let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { @@ -290,6 +334,9 @@ 
impl<'a, T: Transaction> Binder<'a, T> { }); } "min" => { + if args.len() != 1 { + return Err(DatabaseError::MisMatch("number of min() parameters", "1")); + } let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { @@ -300,6 +347,9 @@ impl<'a, T: Transaction> Binder<'a, T> { }); } "max" => { + if args.len() != 1 { + return Err(DatabaseError::MisMatch("number of max() parameters", "1")); + } let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { @@ -310,6 +360,9 @@ impl<'a, T: Transaction> Binder<'a, T> { }); } "avg" => { + if args.len() != 1 { + return Err(DatabaseError::MisMatch("number of avg() parameters", "1")); + } let ty = args[0].return_type(); return Ok(ScalarExpression::AggCall { @@ -319,6 +372,77 @@ impl<'a, T: Transaction> Binder<'a, T> { ty, }); } + "if" => { + if args.len() != 3 { + return Err(DatabaseError::MisMatch("number of if() parameters", "3")); + } + let ty = Self::return_type(&args[1], &args[2])?; + let right_expr = Box::new(args.pop().unwrap()); + let left_expr = Box::new(args.pop().unwrap()); + let condition = Box::new(args.pop().unwrap()); + + return Ok(ScalarExpression::If { + condition, + left_expr, + right_expr, + ty, + }); + } + "nullif" => { + if args.len() != 2 { + return Err(DatabaseError::MisMatch( + "number of nullif() parameters", + "3", + )); + } + let ty = Self::return_type(&args[0], &args[1])?; + let right_expr = Box::new(args.pop().unwrap()); + let left_expr = Box::new(args.pop().unwrap()); + + return Ok(ScalarExpression::NullIf { + left_expr, + right_expr, + ty, + }); + } + "ifnull" => { + if args.len() != 2 { + return Err(DatabaseError::MisMatch( + "number of ifnull() parameters", + "3", + )); + } + let ty = Self::return_type(&args[0], &args[1])?; + let right_expr = Box::new(args.pop().unwrap()); + let left_expr = Box::new(args.pop().unwrap()); + + return Ok(ScalarExpression::IfNull { + left_expr, + right_expr, + ty, + }); + } + "coalesce" => { + let mut ty = LogicalType::SqlNull; + + if 
!args.is_empty() { + ty = args[0].return_type(); + + for arg in args.iter() { + let temp_ty = arg.return_type(); + + if temp_ty == LogicalType::SqlNull { + continue; + } + if ty == LogicalType::SqlNull && temp_ty != LogicalType::SqlNull { + ty = temp_ty; + } else if ty != temp_ty { + return Err(DatabaseError::Incomparable(ty, temp_ty)); + } + } + } + return Ok(ScalarExpression::Coalesce { exprs: args, ty }); + } _ => (), } let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec(); @@ -336,6 +460,20 @@ impl<'a, T: Transaction> Binder<'a, T> { Err(DatabaseError::NotFound("function", summary.name)) } + fn return_type( + expr_1: &ScalarExpression, + expr_2: &ScalarExpression, + ) -> Result { + let temp_ty_1 = expr_1.return_type(); + let temp_ty_2 = expr_2.return_type(); + + match (temp_ty_1, temp_ty_2) { + (LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull), + (ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty), + (ty_1, ty_2) => LogicalType::max_logical_type(&ty_1, &ty_2), + } + } + fn bind_is_null( &mut self, expr: &Expr, diff --git a/src/binder/select.rs b/src/binder/select.rs index bfb6e887..a2a7c5c8 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -273,12 +273,16 @@ impl<'a, T: Transaction> Binder<'a, T> { columns: alias_column, }) = alias { - let table_alias = Arc::new(name.value.to_lowercase()); - if tables.len() > 1 { todo!("Implement virtual tables for multiple table aliases"); } - self.register_alias(alias_column, table_alias.to_string(), tables.remove(0))?; + let table_alias = Arc::new(name.value.to_lowercase()); + + self.register_alias( + alias_column, + table_alias.to_string(), + tables.pop().unwrap(), + )?; (Some(table_alias), plan) } else { diff --git a/src/db.rs b/src/db.rs index 8bca72ec..729b409e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,4 @@ use ahash::HashMap; -use sqlparser::ast::Statement; use std::path::PathBuf; use std::sync::Arc; @@ -101,8 +100,7 @@ impl Database { /// 
Run SQL queries. pub async fn run>(&self, sql: T) -> Result, DatabaseError> { let transaction = self.storage.transaction().await?; - let (plan, _) = - Self::build_plan::(sql, &transaction, &self.functions)?; + let plan = Self::build_plan::(sql, &transaction, &self.functions)?; Self::run_volcano(transaction, plan).await } @@ -133,9 +131,9 @@ impl Database { sql: V, transaction: &::TransactionType, functions: &Functions, - ) -> Result<(LogicalPlan, Statement), DatabaseError> { + ) -> Result { // parse - let mut stmts = parse_sql(sql)?; + let stmts = parse_sql(sql)?; if stmts.is_empty() { return Err(DatabaseError::EmptyStatement); } @@ -154,7 +152,7 @@ impl Database { Self::default_optimizer(source_plan).find_best(Some(&transaction.meta_loader()))?; // println!("best_plan plan: {:#?}", best_plan); - Ok((best_plan, stmts.remove(0))) + Ok(best_plan) } pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { @@ -241,7 +239,7 @@ pub struct DBTransaction { impl DBTransaction { pub async fn run>(&mut self, sql: T) -> Result, DatabaseError> { - let (plan, _) = + let plan = Database::::build_plan::(sql, &self.inner, &self.functions)?; let mut stream = build_write(plan, &mut self.inner); diff --git a/src/execution/volcano/dql/filter.rs b/src/execution/volcano/dql/filter.rs index 84bae47a..34eb1bc1 100644 --- a/src/execution/volcano/dql/filter.rs +++ b/src/execution/volcano/dql/filter.rs @@ -5,7 +5,6 @@ use crate::planner::operator::filter::FilterOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::tuple::Tuple; -use crate::types::value::DataValue; use futures_async_stream::try_stream; pub struct Filter { @@ -33,14 +32,9 @@ impl Filter { #[for_await] for tuple in build_read(input, transaction) { let tuple = tuple?; - if let DataValue::Boolean(option) = predicate.eval(&tuple)?.as_ref() { - if let Some(true) = option { - yield tuple; - } else { - continue; - } - } else { - unreachable!("only bool"); + + if 
predicate.eval(&tuple)?.is_true()? { + yield tuple; } } } diff --git a/src/execution/volcano/mod.rs b/src/execution/volcano/mod.rs index c585bf37..35263f5a 100644 --- a/src/execution/volcano/mod.rs +++ b/src/execution/volcano/mod.rs @@ -57,7 +57,7 @@ pub fn build_read(plan: LogicalPlan, transaction: &T) -> BoxedEx match operator { Operator::Dummy => Dummy {}.execute(transaction), Operator::Aggregate(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); if op.groupby_exprs.is_empty() { SimpleAggExecutor::from((op, input)).execute(transaction) @@ -66,18 +66,18 @@ pub fn build_read(plan: LogicalPlan, transaction: &T) -> BoxedEx } } Operator::Filter(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Filter::from((op, input)).execute(transaction) } Operator::Join(op) => { - let left_input = childrens.remove(0); - let right_input = childrens.remove(0); + let right_input = childrens.pop().unwrap(); + let left_input = childrens.pop().unwrap(); HashJoin::from((op, left_input, right_input)).execute(transaction) } Operator::Project(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Projection::from((op, input)).execute(transaction) } @@ -93,26 +93,26 @@ pub fn build_read(plan: LogicalPlan, transaction: &T) -> BoxedEx } } Operator::Sort(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Sort::from((op, input)).execute(transaction) } Operator::Limit(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Limit::from((op, input)).execute(transaction) } Operator::Values(op) => Values::from(op).execute(transaction), Operator::Show => ShowTables.execute(transaction), Operator::Explain => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Explain::from(input).execute(transaction) } Operator::Describe(op) => Describe::from(op).execute(transaction), Operator::Union(_) => { - let left_input = 
childrens.remove(0); - let right_input = childrens.remove(0); + let right_input = childrens.pop().unwrap(); + let left_input = childrens.pop().unwrap(); Union::from((left_input, right_input)).execute(transaction) } @@ -130,27 +130,27 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo match operator { Operator::Insert(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Insert::from((op, input)).execute_mut(transaction) } Operator::Update(op) => { - let input = childrens.remove(0); - let values = childrens.remove(0); + let values = childrens.pop().unwrap(); + let input = childrens.pop().unwrap(); Update::from((op, input, values)).execute_mut(transaction) } Operator::Delete(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Delete::from((op, input)).execute_mut(transaction) } Operator::AddColumn(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); AddColumn::from((op, input)).execute_mut(transaction) } Operator::DropColumn(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); DropColumn::from((op, input)).execute_mut(transaction) } Operator::CreateTable(op) => CreateTable::from(op).execute_mut(transaction), @@ -162,7 +162,7 @@ pub fn build_write(plan: LogicalPlan, transaction: &mut T) -> Bo todo!() } Operator::Analyze(op) => { - let input = childrens.remove(0); + let input = childrens.pop().unwrap(); Analyze::from((op, input)).execute_mut(transaction) } diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 8f2ebbd7..f722c963 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -1,6 +1,6 @@ use crate::errors::DatabaseError; use crate::expression::function::ScalarFunction; -use crate::expression::{AliasType, ScalarExpression}; +use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, ValueRef}; use 
crate::types::LogicalType; @@ -182,6 +182,88 @@ impl ScalarExpression { inner.eval(args, tuple)?.cast(inner.return_type())?, )), ScalarExpression::Empty => unreachable!(), + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + if condition.eval(tuple)?.is_true()? { + left_expr.eval(tuple) + } else { + right_expr.eval(tuple) + } + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } => { + let value = left_expr.eval(tuple)?; + + if value.is_null() { + return right_expr.eval(tuple); + } + Ok(value) + } + ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + let value = left_expr.eval(tuple)?; + + if right_expr.eval(tuple)? == value { + return Ok(NULL_VALUE.clone()); + } + Ok(value) + } + ScalarExpression::Coalesce { exprs, .. } => { + let mut value = None; + + for expr in exprs { + let temp = expr.eval(tuple)?; + + if !temp.is_null() { + value = Some(temp); + break; + } + } + Ok(value.unwrap_or_else(|| NULL_VALUE.clone())) + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + let mut operand_value = None; + let mut result = None; + + if let Some(expr) = operand_expr { + operand_value = Some(expr.eval(tuple)?); + } + for (when_expr, result_expr) in expr_pairs { + let when_value = when_expr.eval(tuple)?; + let is_true = if let Some(operand_value) = &operand_value { + operand_value + .binary_op(&when_value, &BinaryOperator::Eq)? + .is_true()? + } else { + when_value.is_true()? 
+ }; + if is_true { + result = Some(result_expr.eval(tuple)?); + break; + } + } + if result.is_none() { + if let Some(expr) = else_expr { + result = Some(expr.eval(tuple)?); + } + } + Ok(result.unwrap_or_else(|| NULL_VALUE.clone())) + } } } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index bc452f1a..a1823dc6 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -86,6 +86,32 @@ pub enum ScalarExpression { }, Tuple(Vec), Function(ScalarFunction), + If { + condition: Box, + left_expr: Box, + right_expr: Box, + ty: LogicalType, + }, + IfNull { + left_expr: Box, + right_expr: Box, + ty: LogicalType, + }, + NullIf { + left_expr: Box, + right_expr: Box, + ty: LogicalType, + }, + Coalesce { + exprs: Vec, + ty: LogicalType, + }, + CaseWhen { + operand_expr: Option>, + expr_pairs: Vec<(ScalarExpression, ScalarExpression)>, + else_expr: Option>, + ty: LogicalType, + }, } impl ScalarExpression { @@ -130,7 +156,9 @@ impl ScalarExpression { left_expr.try_reference(output_exprs); right_expr.try_reference(output_exprs); } - ScalarExpression::AggCall { args, .. } => { + ScalarExpression::AggCall { args, .. } + | ScalarExpression::Coalesce { exprs: args, .. } + | ScalarExpression::Tuple(args) => { for arg in args { arg.try_reference(output_exprs); } @@ -168,13 +196,48 @@ impl ScalarExpression { ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) | ScalarExpression::Reference { .. } => (), - ScalarExpression::Tuple(exprs) => { - for expr in exprs { + ScalarExpression::Function(function) => { + for expr in function.args.iter_mut() { expr.try_reference(output_exprs); } } - ScalarExpression::Function(function) => { - for expr in function.args.iter_mut() { + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + condition.try_reference(output_exprs); + left_expr.try_reference(output_exprs); + right_expr.try_reference(output_exprs); + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. 
+ } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + left_expr.try_reference(output_exprs); + right_expr.try_reference(output_exprs); + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + if let Some(expr) = operand_expr { + expr.try_reference(output_exprs); + } + for (expr_1, expr_2) in expr_pairs { + expr_1.try_reference(output_exprs); + expr_2.try_reference(output_exprs); + } + if let Some(expr) = else_expr { expr.try_reference(output_exprs); } } @@ -193,7 +256,8 @@ impl ScalarExpression { .. } => left_expr.has_count_star() || right_expr.has_count_star(), ScalarExpression::AggCall { args, .. } - | ScalarExpression::Function(ScalarFunction { args, .. }) => { + | ScalarExpression::Function(ScalarFunction { args, .. }) + | ScalarExpression::Coalesce { exprs: args, .. } => { args.iter().any(Self::has_count_star) } ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) => false, @@ -224,30 +288,82 @@ impl ScalarExpression { ScalarExpression::Empty => unreachable!(), ScalarExpression::Reference { expr, .. } => expr.has_count_star(), ScalarExpression::Tuple(args) => args.iter().any(Self::has_count_star), + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + condition.has_count_star() + || left_expr.has_count_star() + || right_expr.has_count_star() + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => left_expr.has_count_star() || right_expr.has_count_star(), + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. 
+ } => { + matches!( + operand_expr.as_ref().map(|expr| expr.has_count_star()), + Some(true) + ) || expr_pairs + .iter() + .any(|(expr_1, expr_2)| expr_1.has_count_star() || expr_2.has_count_star()) + || matches!( + else_expr.as_ref().map(|expr| expr.has_count_star()), + Some(true) + ) + } } } pub fn return_type(&self) -> LogicalType { match self { - Self::Constant(v) => v.logical_type(), - Self::ColumnRef(col) => *col.datatype(), - Self::Binary { + ScalarExpression::Constant(v) => v.logical_type(), + ScalarExpression::ColumnRef(col) => *col.datatype(), + ScalarExpression::Binary { ty: return_type, .. - } => *return_type, - Self::Unary { + } + | ScalarExpression::Unary { ty: return_type, .. - } => *return_type, - Self::TypeCast { + } + | ScalarExpression::TypeCast { ty: return_type, .. - } => *return_type, - Self::AggCall { + } + | ScalarExpression::AggCall { + ty: return_type, .. + } + | ScalarExpression::If { ty: return_type, .. - } => *return_type, - Self::IsNull { .. } | Self::In { .. } | ScalarExpression::Between { .. } => { - LogicalType::Boolean } - Self::SubString { .. } => LogicalType::Varchar(None), - Self::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => { + | ScalarExpression::IfNull { + ty: return_type, .. + } + | ScalarExpression::NullIf { + ty: return_type, .. + } + | ScalarExpression::Coalesce { + ty: return_type, .. + } + | ScalarExpression::CaseWhen { + ty: return_type, .. + } => *return_type, + ScalarExpression::IsNull { .. } + | ScalarExpression::In { .. } + | ScalarExpression::Between { .. } => LogicalType::Boolean, + ScalarExpression::SubString { .. } => LogicalType::Varchar(None), + ScalarExpression::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => { expr.return_type() } ScalarExpression::Empty => unreachable!(), @@ -288,7 +404,8 @@ impl ScalarExpression { } ScalarExpression::AggCall { args, .. } | ScalarExpression::Function(ScalarFunction { args, .. 
}) - | ScalarExpression::Tuple(args) => { + | ScalarExpression::Tuple(args) + | ScalarExpression::Coalesce { exprs: args, .. } => { for expr in args { columns_collect(expr, vec, only_column_ref) } @@ -324,6 +441,46 @@ impl ScalarExpression { } ScalarExpression::Constant(_) => (), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + columns_collect(condition, vec, only_column_ref); + columns_collect(left_expr, vec, only_column_ref); + columns_collect(right_expr, vec, only_column_ref); + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + columns_collect(left_expr, vec, only_column_ref); + columns_collect(right_expr, vec, only_column_ref); + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + if let Some(expr) = operand_expr { + columns_collect(expr, vec, only_column_ref); + } + for (expr_1, expr_2) in expr_pairs { + columns_collect(expr_1, vec, only_column_ref); + columns_collect(expr_2, vec, only_column_ref); + } + if let Some(expr) = else_expr { + columns_collect(expr, vec, only_column_ref); + } + } } } let mut exprs = Vec::new(); @@ -373,8 +530,40 @@ impl ScalarExpression { } ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) - | ScalarExpression::Function(ScalarFunction { args, .. }) => { - args.iter().any(Self::has_agg_call) + | ScalarExpression::Function(ScalarFunction { args, .. }) + | ScalarExpression::Coalesce { exprs: args, .. } => args.iter().any(Self::has_agg_call), + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => condition.has_agg_call() || left_expr.has_agg_call() || right_expr.has_agg_call(), + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. 
+ } => left_expr.has_agg_call() || right_expr.has_agg_call(), + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + matches!( + operand_expr.as_ref().map(|expr| expr.has_agg_call()), + Some(true) + ) || expr_pairs + .iter() + .any(|(expr_1, expr_2)| expr_1.has_agg_call() || expr_2.has_agg_call()) + || matches!( + else_expr.as_ref().map(|expr| expr.has_agg_call()), + Some(true) + ) } } } @@ -483,6 +672,55 @@ impl ScalarExpression { let args_str = args.iter().map(|expr| expr.output_name()).join(", "); format!("{}({})", inner.summary().name, args_str) } + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + format!("if {} ({}, {})", condition, left_expr, right_expr) + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } => { + format!("ifnull({}, {})", left_expr, right_expr) + } + ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + format!("ifnull({}, {})", left_expr, right_expr) + } + ScalarExpression::Coalesce { exprs, .. } => { + let exprs_str = exprs.iter().map(|expr| expr.output_name()).join(", "); + format!("coalesce({})", exprs_str) + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + let op = |tag: &str, expr: &Option>| { + expr.as_ref() + .map(|expr| format!("{}{} ", tag, expr.output_name())) + .unwrap_or_default() + }; + let expr_pairs_str = expr_pairs + .iter() + .map(|(when_expr, then_expr)| format!("when {} then {}", when_expr, then_expr)) + .join(" "); + + format!( + "case {}{} {}end", + op("", operand_expr), + expr_pairs_str, + op("else ", else_expr) + ) + } } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 83f3d7c7..680c3277 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -428,7 +428,8 @@ impl ScalarExpression { } ScalarExpression::AggCall { args, .. } | ScalarExpression::Tuple(args) - | ScalarExpression::Function(ScalarFunction { args, .. 
}) => args + | ScalarExpression::Function(ScalarFunction { args, .. }) + | ScalarExpression::Coalesce { exprs: args, .. } => args .iter() .any(|expr| expr.exist_column(table_name, col_id)), ScalarExpression::In { expr, args, .. } => { @@ -464,6 +465,50 @@ impl ScalarExpression { } ScalarExpression::Constant(_) => false, ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), + ScalarExpression::If { + condition, + left_expr, + right_expr, + .. + } => { + condition.exist_column(table_name, col_id) + || left_expr.exist_column(table_name, col_id) + || right_expr.exist_column(table_name, col_id) + } + ScalarExpression::IfNull { + left_expr, + right_expr, + .. + } + | ScalarExpression::NullIf { + left_expr, + right_expr, + .. + } => { + left_expr.exist_column(table_name, col_id) + || right_expr.exist_column(table_name, col_id) + } + ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + .. + } => { + matches!( + operand_expr + .as_ref() + .map(|expr| expr.exist_column(table_name, col_id)), + Some(true) + ) || expr_pairs.iter().any(|(expr_1, expr_2)| { + expr_1.exist_column(table_name, col_id) + || expr_2.exist_column(table_name, col_id) + }) || matches!( + else_expr + .as_ref() + .map(|expr| expr.exist_column(table_name, col_id)), + Some(true) + ) + } } } @@ -995,7 +1040,12 @@ impl ScalarExpression { | ScalarExpression::In { .. } | ScalarExpression::Between { .. } | ScalarExpression::SubString { .. } - | ScalarExpression::Function(_) => expr.convert_binary(table_name, id), + | ScalarExpression::Function(_) + | ScalarExpression::If { .. } + | ScalarExpression::IfNull { .. } + | ScalarExpression::NullIf { .. } + | ScalarExpression::Coalesce { .. } + | ScalarExpression::CaseWhen { .. } => expr.convert_binary(table_name, id), ScalarExpression::Tuple(_) | ScalarExpression::Reference { .. 
} | ScalarExpression::Empty => unreachable!(), @@ -1004,7 +1054,12 @@ impl ScalarExpression { // FIXME: support `convert_binary` ScalarExpression::Tuple(_) | ScalarExpression::AggCall { .. } - | ScalarExpression::Function(_) => Ok(None), + | ScalarExpression::Function(_) + | ScalarExpression::If { .. } + | ScalarExpression::IfNull { .. } + | ScalarExpression::NullIf { .. } + | ScalarExpression::Coalesce { .. } + | ScalarExpression::CaseWhen { .. } => Ok(None), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), } } diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index 6037aed1..0f29e3a9 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -120,7 +120,6 @@ macro_rules! numeric_binary_compute { BinaryOperator::Eq => { let value = match ($left.cast($unified_type)?, $right.cast($unified_type)?) { ($compute_type(Some(v1)), $compute_type(Some(v2))) => Some(v1 == v2), - ($compute_type(None), $compute_type(None)) => Some(true), (_, _) => None, }; @@ -129,7 +128,6 @@ macro_rules! numeric_binary_compute { BinaryOperator::NotEq => { let value = match ($left.cast($unified_type)?, $right.cast($unified_type)?) 
{ ($compute_type(Some(v1)), $compute_type(Some(v2))) => Some(v1 != v2), - ($compute_type(None), $compute_type(None)) => Some(false), (_, _) => None, }; @@ -443,7 +441,6 @@ impl DataValue { (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) => { Some(v1 == v2) } - (DataValue::Decimal(None), DataValue::Decimal(None)) => Some(true), (_, _) => None, }; @@ -454,7 +451,6 @@ impl DataValue { (DataValue::Decimal(Some(v1)), DataValue::Decimal(Some(v2))) => { Some(v1 != v2) } - (DataValue::Decimal(None), DataValue::Decimal(None)) => Some(false), (_, _) => None, }; @@ -486,6 +482,22 @@ impl DataValue { DataValue::Boolean(value) } + BinaryOperator::Eq => { + let value = match (left_value, right_value) { + (Some(v1), Some(v2)) => Some(v1 == v2), + (_, _) => None, + }; + + DataValue::Boolean(value) + } + BinaryOperator::NotEq => { + let value = match (left_value, right_value) { + (Some(v1), Some(v2)) => Some(v1 != v2), + (_, _) => None, + }; + + DataValue::Boolean(value) + } _ => return Err(DatabaseError::UnsupportedBinaryOperator(unified_type, *op)), } } @@ -533,7 +545,6 @@ impl DataValue { BinaryOperator::Eq => { let value = match (left_value, right_value) { (Some(v1), Some(v2)) => Some(v1 == v2), - (None, None) => Some(true), (_, _) => None, }; @@ -542,7 +553,6 @@ impl DataValue { BinaryOperator::NotEq => { let value = match (left_value, right_value) { (Some(v1), Some(v2)) => Some(v1 != v2), - (None, None) => Some(false), (_, _) => None, }; @@ -569,7 +579,6 @@ impl DataValue { BinaryOperator::Eq => { let value = match (left_value, right_value) { (Some(v1), Some(v2)) => Some(v1 == v2), - (None, None) => Some(true), (_, _) => None, }; @@ -578,7 +587,6 @@ impl DataValue { BinaryOperator::NotEq => { let value = match (left_value, right_value) { (Some(v1), Some(v2)) => Some(v1 != v2), - (None, None) => Some(false), (_, _) => None, }; @@ -1069,7 +1077,7 @@ mod test { &DataValue::Int32(None), &BinaryOperator::Eq )?, - DataValue::Boolean(Some(true)) + 
DataValue::Boolean(None) ); Ok(()) @@ -1181,7 +1189,7 @@ mod test { &DataValue::Int64(None), &BinaryOperator::Eq )?, - DataValue::Boolean(Some(true)) + DataValue::Boolean(None) ); Ok(()) @@ -1293,7 +1301,7 @@ mod test { &DataValue::Float64(None), &BinaryOperator::Eq )?, - DataValue::Boolean(Some(true)) + DataValue::Boolean(None) ); Ok(()) @@ -1405,7 +1413,7 @@ mod test { &DataValue::Float32(None), &BinaryOperator::Eq )?, - DataValue::Boolean(Some(true)) + DataValue::Boolean(None) ); Ok(()) diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs index 20d7b5cf..ca9b3447 100644 --- a/src/marcos/mod.rs +++ b/src/marcos/mod.rs @@ -69,13 +69,13 @@ macro_rules! implement_from_tuple { macro_rules! function { ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> $return_ty:expr => $closure:expr) => { #[derive(Debug)] - struct $struct_name { + pub(crate) struct $struct_name { summary: FunctionSummary } impl $struct_name { - fn new() -> Arc { - let function_name = stringify!($function_name); + pub(crate) fn new() -> Arc { + let function_name = stringify!($function_name).to_lowercase(); let mut arg_types = Vec::new(); $({ diff --git a/src/types/value.rs b/src/types/value.rs index f05397a3..8e579b7c 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -581,6 +581,17 @@ impl DataValue { Ok(()) } + pub fn is_true(&self) -> Result { + if self.is_null() { + return Ok(false); + } + if let DataValue::Boolean(option) = self { + Ok(matches!(option, Some(true))) + } else { + Err(DatabaseError::InvalidType) + } + } + pub fn cast(self, to: &LogicalType) -> Result { match self { DataValue::Null => match to { diff --git a/tests/slt/crdb/condition.slt b/tests/slt/crdb/condition.slt new file mode 100644 index 00000000..4b104c91 --- /dev/null +++ b/tests/slt/crdb/condition.slt @@ -0,0 +1,76 @@ +query IT +SELECT IF(1 = 2, NULL, 1), IF(2 = 2, NULL, 2) +---- +1 null + +query ITT +SELECT NULLIF(1, 2), NULLIF(2, 2), NULLIF(NULL, NULL) +---- +1 null null + +query IIII +SELECT IFNULL(1, 
2), IFNULL(NULL, 2), COALESCE(1, 2), COALESCE(NULL, 2) +---- +1 2 1 2 + +statement ok +DROP TABLE IF EXISTS t + +statement ok +CREATE TABLE t (a INT PRIMARY KEY) + +statement ok +INSERT INTO t VALUES (1), (2), (3) + +query IT +SELECT a, CASE WHEN a = 1 THEN 'one' WHEN a = 2 THEN 'two' ELSE 'other' END FROM t ORDER BY a +---- +1 one +2 two +3 other + +query IT +SELECT a, CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM t ORDER BY a +---- +1 one +2 two +3 other + +query III +SELECT a, NULLIF(a, 2), IF(a = 2, NULL, a) FROM t ORDER BY a +---- +1 1 1 +2 null null +3 3 3 + +query TTTT +SELECT CASE WHEN false THEN 'one' WHEN true THEN 'two' ELSE 'three' END, CASE 1 WHEN 2 THEN 'two' WHEN 1 THEN 'one' ELSE 'three' END, CASE WHEN false THEN 'one' ELSE 'three' END, CASE WHEN false THEN 'one' END +---- +two one three null + +query TTTTT +SELECT CASE WHEN 1 = 1 THEN 'one' END, CASE false WHEN 0 = 1 THEN 'one' END, CASE 1 WHEN 2 THEN 'one' ELSE 'three' END, CASE NULL WHEN true THEN 'one' WHEN false THEN 'two' WHEN NULL THEN 'three' ELSE 'four' END, CASE WHEN false THEN 'one' WHEN true THEN 'two' END +---- +one one three four two + +statement ok +DROP TABLE IF EXISTS tt1; + +statement ok +create table tt1(id int primary key, a boolean, b int); + +statement ok +insert into tt1 values(0, true, 1),(1, false, 2); + +# sqlparser-rs unsupported +# query T +# SELECT CASE WHEN tt1.a THEN last_value([{}, {}, {}]) OVER (PARTITION BY FALSE) WHEN tt1.a THEN [{}] END FROM tt1; +# ---- +# [{},{},{}] +# null + +statement ok +DROP TABLE tt1 + +statement ok +DROP TABLE t \ No newline at end of file diff --git a/tests/slt/dummy.slt b/tests/slt/dummy.slt index 7b0f189d..7367c309 100644 --- a/tests/slt/dummy.slt +++ b/tests/slt/dummy.slt @@ -11,6 +11,11 @@ SELECT 'a' ---- a +query B +SELECT NULL=NULL +---- +null + query B SELECT NOT(1=1) ---- @@ -36,6 +41,121 @@ SELECT NOT(TRUE) ---- false +query I +SELECT IF(TRUE, 1, 2) +---- +1 + +query I +SELECT IF(FALSE, 1, 2) +---- +2 + +query 
I +SELECT IF(NULL, 1, 2) +---- +2 + +query I +SELECT NULLIF(1, 2) +---- +1 + +query I +SELECT NULLIF(NULL, 2) +---- +null + +query I +SELECT NULLIF(NULL, NULL) +---- +null + +query I +SELECT IFNULL(1, NULL) +---- +1 + +query I +SELECT IFNULL(NULL, 1) +---- +1 + +query I +SELECT IFNULL(NULL, NULL) +---- +null + +query I +SELECT COALESCE(1) +---- +1 + +query I +SELECT COALESCE(1, 2, 3) +---- +1 + +query I +SELECT COALESCE(NULL, 2, 3) +---- +2 + +query I +SELECT COALESCE(NULL, 2, NULL) +---- +2 + +query I +SELECT COALESCE(NULL, NULL) +---- +null + +query I +SELECT COALESCE() +---- +null + +query I +SELECT CASE 1 WHEN 0 THEN 0 WHEN 1 THEN 1 WHEN NULL THEN 9 ELSE 2 END +---- +1 + +query I +SELECT CASE 3 WHEN 0 THEN 0 WHEN 1 THEN 1 WHEN NULL THEN 9 ELSE 2 END +---- +2 + +query I +SELECT CASE 3 WHEN 0 THEN 0 WHEN 1 THEN 1 WHEN NULL THEN 9 END +---- +null + +query I +SELECT CASE 3 WHEN 0 THEN 0 WHEN 1 THEN 1 WHEN NULL THEN 9 ELSE 2 END +---- +2 + +query I +SELECT CASE WHEN TRUE THEN 0 WHEN FALSE THEN 1 WHEN FALSE THEN 9 ELSE 2 END +---- +0 + +query I +SELECT CASE WHEN FALSE THEN 0 WHEN FALSE THEN 1 WHEN FALSE THEN 9 ELSE 2 END +---- +2 + +query I +SELECT CASE WHEN FALSE THEN 0 WHEN FALSE THEN 1 WHEN FALSE THEN 9 END +---- +null + +query I +SELECT CASE WHEN FALSE THEN 0 WHEN 1=1 THEN 1 WHEN FALSE THEN 9 ELSE 2 END +---- +1 + # issue: https://github.com/sqlparser-rs/sqlparser-rs/issues/362 # query T # SELECT 'That\'s good.' 
diff --git a/tests/slt/filter_null.slt b/tests/slt/filter_null.slt index 1bfd0ad3..01d3bde9 100644 --- a/tests/slt/filter_null.slt +++ b/tests/slt/filter_null.slt @@ -4,14 +4,14 @@ create table t(id int primary key, v1 int, v2 int not null) statement ok insert into t values (0, 2, 4), (1, 1, 3), (2, 3, 4), (3, 4, 3); -query II +query III select * from t where v1 > 1 ---- 0 2 4 2 3 4 3 4 3 -query II +query III select * from t where v1 < 2 ---- 1 1 3 @@ -25,24 +25,28 @@ create table t(id int primary key, v1 int null, v2 int) statement ok insert into t values (0, 2, 4), (1, null, 3), (2, 3, 4), (3, 4, 3) -query II +query III select * from t where v1 > 1 ---- 0 2 4 2 3 4 3 4 3 -query II +query III select * from t where v1 is null ---- 1 null 3 -query II +query III select * from t where v1 is not null ---- 0 2 4 2 3 4 3 4 3 +query III +select * from t where null +---- + statement ok drop table t \ No newline at end of file diff --git a/tests/slt/having.slt b/tests/slt/having.slt index 5454867c..dae1aaff 100644 --- a/tests/slt/having.slt +++ b/tests/slt/having.slt @@ -24,6 +24,7 @@ select x from test group by x having max(y) = 22 ---- 11 +# FIXME # query II # select y + 1 as i from test group by y + 1 having count(x) > 1 and y + 1 = 3 or y + 1 = 23 order by i; # ---- diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index 215cc6bb..5df37dbc 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -1,5 +1,3 @@ -# Test subquery - statement ok create table t1(id int primary key, a int not null, b int not null); diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 432ef4f2..cd392a8e 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -4,12 +4,12 @@ use fnck_sql::storage::kip::KipStorage; use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; use std::time::Instant; -pub struct KipSQL { +pub struct SQLBase { pub db: Database, } #[async_trait::async_trait] -impl AsyncDB for KipSQL { +impl 
AsyncDB for SQLBase { type Error = DatabaseError; type ColumnType = DefaultColumnType; @@ -17,7 +17,7 @@ impl AsyncDB for KipSQL { let start = Instant::now(); let tuples = self.db.run(sql).await?; println!("|— Input SQL: {}", sql); - println!(" |— Time consuming: {:?}", start.elapsed()); + println!(" |— time spent: {:?}", start.elapsed()); if tuples.is_empty() { return Ok(DBOutput::StatementComplete(0)); diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs index 5b4f8cee..ecb9fb8c 100644 --- a/tests/sqllogictest/src/main.rs +++ b/tests/sqllogictest/src/main.rs @@ -1,10 +1,11 @@ use fnck_sql::db::DataBaseBuilder; use sqllogictest::Runner; -use sqllogictest_test::KipSQL; +use sqllogictest_test::SQLBase; use std::fs::File; use std::io; use std::io::Write; use std::path::Path; +use std::time::Instant; use tempfile::TempDir; #[tokio::main] @@ -16,26 +17,36 @@ async fn main() { println!("FnckSQL Test Start!\n"); init_20000_row_csv().expect("failed to init csv"); + let mut file_num = 0; + let start = Instant::now(); for slt_file in glob::glob(SLT_PATTERN).expect("failed to find slt files") { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let filepath = slt_file - .expect("failed to read slt file") - .to_str() - .unwrap() - .to_string(); - println!("-> Now the test file is: {}", filepath); + let filepath = slt_file.expect("failed to read slt file"); + println!( + "-> Now the test file is: {}, num: {}", + filepath.display(), + file_num + ); let db = DataBaseBuilder::path(temp_dir.path()) .build() .await .expect("init db error"); - let mut tester = Runner::new(KipSQL { db }); + let mut tester = Runner::new(SQLBase { db }); if let Err(err) = tester.run_file_async(filepath).await { panic!("test error: {}", err); } - println!("-> Pass!\n\n") + println!("-> Pass!\n"); + file_num += 1; + } + println!("Passed all tests for a total of {} files!!!", file_num); + println!("|- Total time spent: {:?}", 
start.elapsed()); + if cfg!(debug_assertions) { + println!("|- Debug mode"); + } else { + println!("|- Release mode"); } }