Skip to content

Commit

Permalink
feat: support if()/ifnull()/nullif()/coalesce()/`Case ... Whe…
Browse files Browse the repository at this point in the history
…n ...`
  • Loading branch information
KKould committed Feb 18, 2024
1 parent 3907fcd commit fff8a35
Show file tree
Hide file tree
Showing 19 changed files with 926 additions and 100 deletions.
92 changes: 90 additions & 2 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,52 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.visit_column_agg_expr(expr)?;
}
}
ScalarExpression::If {
condition,
left_expr,
right_expr,
..
} => {
self.visit_column_agg_expr(condition)?;
self.visit_column_agg_expr(left_expr)?;
self.visit_column_agg_expr(right_expr)?;
}
ScalarExpression::IfNull {
left_expr,
right_expr,
..
}
| ScalarExpression::NullIf {
left_expr,
right_expr,
..
} => {
self.visit_column_agg_expr(left_expr)?;
self.visit_column_agg_expr(right_expr)?;
}
ScalarExpression::CaseWhen {
operand_expr,
expr_pairs,
else_expr,
..
} => {
if let Some(expr) = operand_expr {
self.visit_column_agg_expr(expr)?;
}
for (expr_1, expr_2) in expr_pairs {
self.visit_column_agg_expr(expr_1)?;
self.visit_column_agg_expr(expr_2)?;
}
if let Some(expr) = else_expr {
self.visit_column_agg_expr(expr)?;
}
}
}

Ok(())
Expand Down Expand Up @@ -318,12 +359,59 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.validate_having_orderby(expr)?;
}
Ok(())
}
ScalarExpression::If {
condition,
left_expr,
right_expr,
..
} => {
self.validate_having_orderby(condition)?;
self.validate_having_orderby(left_expr)?;
self.validate_having_orderby(right_expr)?;

Ok(())
}
ScalarExpression::IfNull {
left_expr,
right_expr,
..
}
| ScalarExpression::NullIf {
left_expr,
right_expr,
..
} => {
self.validate_having_orderby(left_expr)?;
self.validate_having_orderby(right_expr)?;

Ok(())
}
ScalarExpression::CaseWhen {
operand_expr,
expr_pairs,
else_expr,
..
} => {
if let Some(expr) = operand_expr {
self.validate_having_orderby(expr)?;
}
for (expr_1, expr_2) in expr_pairs {
self.validate_having_orderby(expr_1)?;
self.validate_having_orderby(expr_2)?;
}
if let Some(expr) = else_expr {
self.validate_having_orderby(expr)?;
}

Ok(())
}
}
}
}
140 changes: 139 additions & 1 deletion src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,44 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
Ok(ScalarExpression::Tuple(bond_exprs))
}
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
let mut operand_expr = None;
let mut ty = LogicalType::SqlNull;
if let Some(expr) = operand {
operand_expr = Some(Box::new(self.bind_expr(expr)?));
}
let mut expr_pairs = Vec::with_capacity(conditions.len());
for i in 0..conditions.len() {
let result = self.bind_expr(&results[i])?;
let result_ty = result.return_type();

if result_ty != LogicalType::SqlNull {
if ty == LogicalType::SqlNull {
ty = result_ty;
} else if ty != result_ty {
return Err(DatabaseError::Incomparable(ty, result_ty));
}
}
expr_pairs.push((self.bind_expr(&conditions[i])?, result))
}

let mut else_expr = None;
if let Some(expr) = else_result {
else_expr = Some(Box::new(self.bind_expr(expr)?));
}

Ok(ScalarExpression::CaseWhen {
operand_expr,
expr_pairs,
else_expr,
ty,
})
}
_ => {
todo!()
}
Expand Down Expand Up @@ -272,14 +310,20 @@ impl<'a, T: Transaction> Binder<'a, T> {

match function_name.as_str() {
"count" => {
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of count() parameters", "1"));
}
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Count,
args,
ty: LogicalType::Integer,
})
});
}
"sum" => {
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of sum() parameters", "1"));
}
let ty = args[0].return_type();

return Ok(ScalarExpression::AggCall {
Expand All @@ -290,6 +334,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
});
}
"min" => {
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of min() parameters", "1"));
}
let ty = args[0].return_type();

return Ok(ScalarExpression::AggCall {
Expand All @@ -300,6 +347,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
});
}
"max" => {
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of max() parameters", "1"));
}
let ty = args[0].return_type();

return Ok(ScalarExpression::AggCall {
Expand All @@ -310,6 +360,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
});
}
"avg" => {
if args.len() != 1 {
return Err(DatabaseError::MisMatch("number of avg() parameters", "1"));
}
let ty = args[0].return_type();

return Ok(ScalarExpression::AggCall {
Expand All @@ -319,6 +372,77 @@ impl<'a, T: Transaction> Binder<'a, T> {
ty,
});
}
"if" => {
if args.len() != 3 {
return Err(DatabaseError::MisMatch("number of if() parameters", "3"));
}
let ty = Self::return_type(&args[1], &args[2])?;
let right_expr = Box::new(args.pop().unwrap());
let left_expr = Box::new(args.pop().unwrap());
let condition = Box::new(args.pop().unwrap());

return Ok(ScalarExpression::If {
condition,
left_expr,
right_expr,
ty,
});
}
"nullif" => {
if args.len() != 2 {
return Err(DatabaseError::MisMatch(
"number of nullif() parameters",
"3",
));
}
let ty = Self::return_type(&args[0], &args[1])?;
let right_expr = Box::new(args.pop().unwrap());
let left_expr = Box::new(args.pop().unwrap());

return Ok(ScalarExpression::NullIf {
left_expr,
right_expr,
ty,
});
}
"ifnull" => {
if args.len() != 2 {
return Err(DatabaseError::MisMatch(
"number of ifnull() parameters",
"3",
));
}
let ty = Self::return_type(&args[0], &args[1])?;
let right_expr = Box::new(args.pop().unwrap());
let left_expr = Box::new(args.pop().unwrap());

return Ok(ScalarExpression::IfNull {
left_expr,
right_expr,
ty,
});
}
"coalesce" => {
let mut ty = LogicalType::SqlNull;

if !args.is_empty() {
ty = args[0].return_type();

for arg in args.iter() {
let temp_ty = arg.return_type();

if temp_ty == LogicalType::SqlNull {
continue;
}
if ty == LogicalType::SqlNull && temp_ty != LogicalType::SqlNull {
ty = temp_ty;
} else if ty != temp_ty {
return Err(DatabaseError::Incomparable(ty, temp_ty));
}
}
}
return Ok(ScalarExpression::Coalesce { exprs: args, ty });
}
_ => (),
}
let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec();
Expand All @@ -336,6 +460,20 @@ impl<'a, T: Transaction> Binder<'a, T> {
Err(DatabaseError::NotFound("function", summary.name))
}

fn return_type(
expr_1: &ScalarExpression,
expr_2: &ScalarExpression,
) -> Result<LogicalType, DatabaseError> {
let temp_ty_1 = expr_1.return_type();
let temp_ty_2 = expr_2.return_type();

match (temp_ty_1, temp_ty_2) {
(LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull),
(ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty),
(ty_1, ty_2) => LogicalType::max_logical_type(&ty_1, &ty_2),
}
}

fn bind_is_null(
&mut self,
expr: &Expr,
Expand Down
10 changes: 7 additions & 3 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
columns: alias_column,
}) = alias
{
let table_alias = Arc::new(name.value.to_lowercase());

if tables.len() > 1 {
todo!("Implement virtual tables for multiple table aliases");
}
self.register_alias(alias_column, table_alias.to_string(), tables.remove(0))?;
let table_alias = Arc::new(name.value.to_lowercase());

self.register_alias(
alias_column,
table_alias.to_string(),
tables.pop().unwrap(),
)?;

(Some(table_alias), plan)
} else {
Expand Down
12 changes: 5 additions & 7 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use ahash::HashMap;
use sqlparser::ast::Statement;
use std::path::PathBuf;
use std::sync::Arc;

Expand Down Expand Up @@ -101,8 +100,7 @@ impl<S: Storage> Database<S> {
/// Run SQL queries.
pub async fn run<T: AsRef<str>>(&self, sql: T) -> Result<Vec<Tuple>, DatabaseError> {
let transaction = self.storage.transaction().await?;
let (plan, _) =
Self::build_plan::<T, S::TransactionType>(sql, &transaction, &self.functions)?;
let plan = Self::build_plan::<T, S::TransactionType>(sql, &transaction, &self.functions)?;

Self::run_volcano(transaction, plan).await
}
Expand Down Expand Up @@ -133,9 +131,9 @@ impl<S: Storage> Database<S> {
sql: V,
transaction: &<S as Storage>::TransactionType,
functions: &Functions,
) -> Result<(LogicalPlan, Statement), DatabaseError> {
) -> Result<LogicalPlan, DatabaseError> {
// parse
let mut stmts = parse_sql(sql)?;
let stmts = parse_sql(sql)?;
if stmts.is_empty() {
return Err(DatabaseError::EmptyStatement);
}
Expand All @@ -154,7 +152,7 @@ impl<S: Storage> Database<S> {
Self::default_optimizer(source_plan).find_best(Some(&transaction.meta_loader()))?;
// println!("best_plan plan: {:#?}", best_plan);

Ok((best_plan, stmts.remove(0)))
Ok(best_plan)
}

pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer {
Expand Down Expand Up @@ -241,7 +239,7 @@ pub struct DBTransaction<S: Storage> {

impl<S: Storage> DBTransaction<S> {
pub async fn run<T: AsRef<str>>(&mut self, sql: T) -> Result<Vec<Tuple>, DatabaseError> {
let (plan, _) =
let plan =
Database::<S>::build_plan::<T, S::TransactionType>(sql, &self.inner, &self.functions)?;
let mut stream = build_write(plan, &mut self.inner);

Expand Down
Loading

0 comments on commit fff8a35

Please sign in to comment.