diff --git a/.gitignore b/.gitignore index a9fc4802..8f3c9856 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ Cargo.lock /.vscode /.idea /.obsidian +.DS_Store \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 7cd8c4ba..223881a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "kip-sql" -version = "0.0.1-alpha.0" +version = "0.0.1-alpha.3" edition = "2021" authors = ["Kould ", "Xwg "] description = "build the SQL layer of KipDB database" @@ -37,7 +37,7 @@ ahash = "0.8.3" lazy_static = "1.4.0" comfy-table = "7.0.1" bytes = "1.5.0" -kip_db = "0.1.2-alpha.15" +kip_db = "0.1.2-alpha.16" async-recursion = "1.0.5" rust_decimal = "1" csv = "1" diff --git a/README.md b/README.md index e49e2cb5..fc2947bf 100755 --- a/README.md +++ b/README.md @@ -48,6 +48,33 @@ Storage Support: ![demo](./static/images/demo.png) ### Features +- ORM Mapping +```rust +#[derive(Debug, Clone, Default)] +pub struct Post { + pub post_title: String, + pub post_date: NaiveDateTime, + pub post_body: String, +} + +implement_from_tuple!(Post, ( + post_title: String => |post: &mut Post, value: DataValue| { + if let Some(title) = value.utf8() { + post.post_title = title; + } + }, + post_date: NaiveDateTime => |post: &mut Post, value: DataValue| { + if let Some(date_time) = value.datetime() { + post.post_date = date_time; + } + }, + post_body: String => |post: &mut Post, value: DataValue| { + if let Some(body) = value.utf8() { + post.post_body = body; + } + } +)); +``` - SQL field options - not null - null @@ -119,4 +146,4 @@ Storage Support: ### Thanks For - [Fedomn/sqlrs](https://github.com/Fedomn/sqlrs): 主要参考资料,Optimizer、Executor均参考自sqlrs的设计 -- [systemxlabs/tinysql](https://github.com/systemxlabs/tinysql) +- [systemxlabs/bustubx](https://github.com/systemxlabs/bustubx) diff --git a/rust-toolchain b/rust-toolchain index 63c8d19a..3accb08b 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2023-04-07 \ No newline at end of file +nightly-2023-09-29 \ No newline at end of file diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index 36e51f77..7b21de5a 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -1,17 +1,15 @@ -use std::collections::HashSet; use ahash::RandomState; use itertools::Itertools; use sqlparser::ast::{Expr, OrderByExpr}; +use std::collections::HashSet; -use crate::{ - expression::ScalarExpression, - planner::{ - operator::{aggregate::AggregateOperator, sort::SortField}, - }, -}; use crate::binder::{BindError, InputRefType}; use crate::planner::LogicalPlan; use crate::storage::Storage; +use crate::{ + expression::ScalarExpression, + planner::operator::{aggregate::AggregateOperator, sort::SortField}, +}; use super::Binder; @@ -40,7 +38,8 @@ impl Binder { select_list: &mut [ScalarExpression], groupby: &[Expr], ) -> Result<(), BindError> { - self.validate_groupby_illegal_column(select_list, groupby).await?; + self.validate_groupby_illegal_column(select_list, groupby) + .await?; for gb in groupby { let mut expr = self.bind_expr(gb).await?; @@ -89,7 +88,11 @@ impl Binder { Ok((return_having, return_orderby)) } - fn visit_column_agg_expr(&mut self, expr: &mut ScalarExpression, is_select: bool) -> Result<(), BindError> { + fn visit_column_agg_expr( + &mut self, + expr: &mut ScalarExpression, + is_select: bool, + ) -> Result<(), BindError> { match expr { ScalarExpression::AggCall { ty: return_type, .. @@ -97,16 +100,13 @@ impl Binder { let ty = return_type.clone(); if is_select { let index = self.context.input_ref_index(InputRefType::AggCall); - let input_ref = ScalarExpression::InputRef { - index, - ty, - }; + let input_ref = ScalarExpression::InputRef { index, ty }; match std::mem::replace(expr, input_ref) { ScalarExpression::AggCall { kind, args, ty, - distinct + distinct, } => { self.context.agg_calls.push(ScalarExpression::AggCall { distinct, @@ -125,14 +125,13 @@ impl Binder { .find_position(|agg_expr| agg_expr == &expr) .ok_or_else(|| BindError::AggMiss(format!("{:?}", expr)))?; - let _ = std::mem::replace(expr, ScalarExpression::InputRef { - index, - ty, - }); + let _ = std::mem::replace(expr, ScalarExpression::InputRef { index, ty }); } } - ScalarExpression::TypeCast { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, + ScalarExpression::TypeCast { expr, .. } => { + self.visit_column_agg_expr(expr, is_select)? + } ScalarExpression::IsNull { expr } => self.visit_column_agg_expr(expr, is_select)?, ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr, is_select)?, @@ -185,7 +184,8 @@ impl Binder { group_raw_exprs.push(expr); } } - let mut group_raw_set: HashSet<&ScalarExpression, RandomState> = HashSet::from_iter(group_raw_exprs.iter()); + let mut group_raw_set: HashSet<&ScalarExpression, RandomState> = + HashSet::from_iter(group_raw_exprs.iter()); for expr in select_items { if expr.has_agg_call(&self.context) { @@ -195,19 +195,17 @@ impl Binder { group_raw_set.remove(expr); if !group_raw_exprs.iter().contains(expr) { - return Err(BindError::AggMiss( - format!( - "{:?} must appear in the GROUP BY clause or be used in an aggregate function", - expr - ) - )); + return Err(BindError::AggMiss(format!( + "{:?} must appear in the GROUP BY clause or be used in an aggregate function", + expr + ))); } } if !group_raw_set.is_empty() { - return Err(BindError::AggMiss( - format!("In the GROUP BY clause the field must be in the select clause") - )); + return Err(BindError::AggMiss(format!( + "In the GROUP BY clause the field must be in the select clause" + ))); } Ok(()) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 179ea77c..61bcdd7b 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -1,14 +1,14 @@ -use std::collections::HashSet; -use std::sync::Arc; use itertools::Itertools; use sqlparser::ast::{ColumnDef, ObjectName, TableConstraint}; +use std::collections::HashSet; +use std::sync::Arc; use super::Binder; -use crate::binder::{BindError, lower_case_name, split_name}; +use crate::binder::{lower_case_name, split_name, BindError}; use crate::catalog::ColumnCatalog; -use crate::planner::LogicalPlan; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; impl Binder { @@ -36,24 +36,19 @@ impl Binder { .map(|col| ColumnCatalog::from(col.clone())) .collect_vec(); - let primary_key_count = columns - .iter() - .filter(|col| col.desc.is_primary) - .count(); + let primary_key_count = columns.iter().filter(|col| col.desc.is_primary).count(); if primary_key_count != 1 { return Err(BindError::InvalidTable( - "The primary key field must exist and have at least one".to_string() + "The primary key field must exist and have at least one".to_string(), )); } let plan = LogicalPlan { - operator: Operator::CreateTable( - CreateTableOperator { - table_name, - columns - } - ), + operator: Operator::CreateTable(CreateTableOperator { + table_name, + columns, + }), childrens: vec![], }; Ok(plan) @@ -62,12 +57,12 @@ impl Binder { #[cfg(test)] mod tests { - use tempfile::TempDir; use super::*; use crate::binder::BinderContext; use crate::catalog::ColumnDesc; use crate::storage::kip::KipStorage; use crate::types::LogicalType; + use tempfile::TempDir; #[tokio::test] async fn test_create_bind() { @@ -84,13 +79,18 @@ mod tests { assert_eq!(op.table_name, Arc::new("t1".to_string())); assert_eq!(op.columns[0].name, "id".to_string()); assert_eq!(op.columns[0].nullable, false); - assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true, false)); + assert_eq!( + op.columns[0].desc, + ColumnDesc::new(LogicalType::Integer, true, false) + ); assert_eq!(op.columns[1].name, "name".to_string()); assert_eq!(op.columns[1].nullable, true); - assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false)); + assert_eq!( + op.columns[1].desc, + ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false) + ); } - _ => unreachable!() + _ => unreachable!(), } - } } diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 9de4b5d0..247712d9 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -1,9 +1,9 @@ -use sqlparser::ast::{Expr, TableFactor, TableWithJoins}; -use crate::binder::{Binder, BindError, lower_case_name, split_name}; -use crate::planner::LogicalPlan; +use crate::binder::{lower_case_name, split_name, BindError, Binder}; use crate::planner::operator::delete::DeleteOperator; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; +use sqlparser::ast::{Expr, TableFactor, TableWithJoins}; impl Binder { pub(crate) async fn bind_delete( @@ -21,15 +21,11 @@ impl Binder { } Ok(LogicalPlan { - operator: Operator::Delete( - DeleteOperator { - table_name - } - ), + operator: Operator::Delete(DeleteOperator { table_name }), childrens: vec![plan], }) } else { unreachable!("only table") } } -} \ No newline at end of file +} diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index 33f95f20..fc0e1d59 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -1,7 +1,7 @@ use crate::binder::Binder; use crate::expression::ScalarExpression; -use crate::planner::LogicalPlan; use crate::planner::operator::aggregate::AggregateOperator; +use crate::planner::LogicalPlan; use crate::storage::Storage; impl Binder { @@ -12,4 +12,4 @@ impl Binder { ) -> LogicalPlan { AggregateOperator::new(children, vec![], select_list) } -} \ No newline at end of file +} diff --git a/src/binder/drop_table.rs b/src/binder/drop_table.rs index a88b2d9f..45e17b2a 100644 --- a/src/binder/drop_table.rs +++ b/src/binder/drop_table.rs @@ -1,28 +1,21 @@ -use std::sync::Arc; -use sqlparser::ast::ObjectName; -use crate::binder::{Binder, BindError, lower_case_name, split_name}; -use crate::planner::LogicalPlan; +use crate::binder::{lower_case_name, split_name, BindError, Binder}; use crate::planner::operator::drop_table::DropTableOperator; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; +use sqlparser::ast::ObjectName; +use std::sync::Arc; impl Binder { - pub(crate) fn bind_drop_table( - &mut self, - name: &ObjectName - ) -> Result { + pub(crate) fn bind_drop_table(&mut self, name: &ObjectName) -> Result { let name = lower_case_name(&name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); let plan = LogicalPlan { - operator: Operator::DropTable( - DropTableOperator { - table_name - } - ), + operator: Operator::DropTable(DropTableOperator { table_name }), childrens: vec![], }; Ok(plan) } -} \ No newline at end of file +} diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 441eea74..da845f2b 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -1,28 +1,31 @@ use crate::binder::BindError; +use crate::expression::agg::AggKind; +use async_recursion::async_recursion; use itertools::Itertools; -use sqlparser::ast::{BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator}; +use sqlparser::ast::{ + BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator, +}; use std::slice; use std::sync::Arc; -use async_recursion::async_recursion; -use crate::expression::agg::AggKind; use super::Binder; use crate::expression::ScalarExpression; use crate::storage::Storage; -use crate::types::LogicalType; use crate::types::value::DataValue; +use crate::types::LogicalType; impl Binder { #[async_recursion] pub(crate) async fn bind_expr(&mut self, expr: &Expr) -> Result { match expr { Expr::Identifier(ident) => { - self.bind_column_ref_from_identifiers(slice::from_ref(ident), None).await + self.bind_column_ref_from_identifiers(slice::from_ref(ident), None) + .await } Expr::CompoundIdentifier(idents) => { self.bind_column_ref_from_identifiers(idents, None).await } - Expr::BinaryOp { left, right, op} => { + Expr::BinaryOp { left, right, op } => { self.bind_binary_op_internal(left, right, op).await } Expr::Value(v) => Ok(ScalarExpression::Constant(Arc::new(v.into()))), @@ -85,7 +88,10 @@ impl Binder { } if got_column.is_none() { if let Some(expr) = self.context.aliases.get(column_name) { - return Ok(ScalarExpression::Alias { expr: Box::new(expr.clone()), alias: column_name.clone() }); + return Ok(ScalarExpression::Alias { + expr: Box::new(expr.clone()), + alias: column_name.clone(), + }); } } let column_catalog = @@ -104,19 +110,23 @@ impl Binder { let right_expr = Box::new(self.bind_expr(right).await?); let ty = match op { - BinaryOperator::Plus | BinaryOperator::Minus | BinaryOperator::Multiply | - BinaryOperator::Divide | BinaryOperator::Modulo => { - LogicalType::max_logical_type( - &left_expr.return_type(), - &right_expr.return_type() - )? + BinaryOperator::Plus + | BinaryOperator::Minus + | BinaryOperator::Multiply + | BinaryOperator::Divide + | BinaryOperator::Modulo => { + LogicalType::max_logical_type(&left_expr.return_type(), &right_expr.return_type())? } - BinaryOperator::Gt | BinaryOperator::Lt | BinaryOperator::GtEq | - BinaryOperator::LtEq | BinaryOperator::Eq | BinaryOperator::NotEq | - BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => { - LogicalType::Boolean - }, - _ => todo!() + BinaryOperator::Gt + | BinaryOperator::Lt + | BinaryOperator::GtEq + | BinaryOperator::LtEq + | BinaryOperator::Eq + | BinaryOperator::NotEq + | BinaryOperator::And + | BinaryOperator::Or + | BinaryOperator::Xor => LogicalType::Boolean, + _ => todo!(), }; Ok(ScalarExpression::Binary { @@ -157,37 +167,37 @@ impl Binder { match arg_expr { FunctionArgExpr::Expr(expr) => args.push(self.bind_expr(expr).await?), FunctionArgExpr::Wildcard => args.push(Self::wildcard_expr()), - _ => todo!() + _ => todo!(), } } let ty = args[0].return_type(); Ok(match func.name.to_string().to_lowercase().as_str() { - "count" => ScalarExpression::AggCall{ + "count" => ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Count, args, ty: LogicalType::Integer, }, - "sum" => ScalarExpression::AggCall{ + "sum" => ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Sum, args, ty, }, - "min" => ScalarExpression::AggCall{ + "min" => ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Min, args, ty, }, - "max" => ScalarExpression::AggCall{ + "max" => ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Max, args, ty, }, - "avg" => ScalarExpression::AggCall{ + "avg" => ScalarExpression::AggCall { distinct: func.distinct, kind: AggKind::Avg, args, diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 4b82755c..cdf3ee14 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -1,16 +1,16 @@ -use std::slice; -use std::sync::Arc; -use sqlparser::ast::{Expr, Ident, ObjectName}; -use crate::binder::{Binder, BindError, lower_case_name, split_name}; +use crate::binder::{lower_case_name, split_name, BindError, Binder}; use crate::catalog::ColumnRef; -use crate::expression::ScalarExpression; use crate::expression::value_compute::unary_op; -use crate::planner::LogicalPlan; +use crate::expression::ScalarExpression; use crate::planner::operator::insert::InsertOperator; -use crate::planner::operator::Operator; use crate::planner::operator::values::ValuesOperator; +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; use crate::types::value::{DataValue, ValueRef}; +use sqlparser::ast::{Expr, Ident, ObjectName}; +use std::slice; +use std::sync::Arc; impl Binder { pub(crate) async fn bind_insert( @@ -18,7 +18,7 @@ impl Binder { name: ObjectName, idents: &[Ident], expr_rows: &Vec>, - is_overwrite: bool + is_overwrite: bool, ) -> Result { let name = lower_case_name(&name); let (_, name) = split_name(&name)?; @@ -32,12 +32,15 @@ impl Binder { } else { let bind_table_name = Some(table_name.to_string()); for ident in idents { - match self.bind_column_ref_from_identifiers( - slice::from_ref(ident), - bind_table_name.as_ref() - ).await? { + match self + .bind_column_ref_from_identifiers( + slice::from_ref(ident), + bind_table_name.as_ref(), + ) + .await? + { ScalarExpression::ColumnRef(catalog) => columns.push(catalog), - _ => unreachable!() + _ => unreachable!(), } } } @@ -51,10 +54,9 @@ impl Binder { ScalarExpression::Constant(value) => { // Check if the value length is too long value.check_len(columns[i].datatype())?; - let cast_value = DataValue::clone(value) - .cast(columns[i].datatype())?; + let cast_value = DataValue::clone(value).cast(columns[i].datatype())?; row.push(Arc::new(cast_value)) - }, + } ScalarExpression::Unary { expr, op, .. } => { if let ScalarExpression::Constant(value) = expr.as_ref() { row.push(Arc::new(unary_op(value, op)?)) @@ -71,30 +73,28 @@ impl Binder { let values_plan = self.bind_values(rows, columns); Ok(LogicalPlan { - operator: Operator::Insert( - InsertOperator { - table_name, - is_overwrite, - } - ), + operator: Operator::Insert(InsertOperator { + table_name, + is_overwrite, + }), childrens: vec![values_plan], }) } else { - Err(BindError::InvalidTable(format!("not found table {}", table_name))) + Err(BindError::InvalidTable(format!( + "not found table {}", + table_name + ))) } } pub(crate) fn bind_values( &mut self, rows: Vec>, - columns: Vec + columns: Vec, ) -> LogicalPlan { LogicalPlan { - operator: Operator::Values(ValuesOperator { - rows, - columns, - }), + operator: Operator::Values(ValuesOperator { rows, columns }), childrens: vec![], } } -} \ No newline at end of file +} diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 45251524..2c23f032 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -1,23 +1,23 @@ pub mod aggregate; mod create_table; -pub mod expr; -mod select; -mod insert; -mod update; mod delete; -mod drop_table; -mod truncate; mod distinct; +mod drop_table; +pub mod expr; +mod insert; +mod select; mod show; pub mod copy; +mod truncate; +mod update; -use std::collections::BTreeMap; use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement}; +use std::collections::BTreeMap; -use crate::catalog::{DEFAULT_SCHEMA_NAME, CatalogError, TableName, TableCatalog}; +use crate::catalog::{CatalogError, TableCatalog, TableName, DEFAULT_SCHEMA_NAME}; use crate::expression::ScalarExpression; -use crate::planner::LogicalPlan; use crate::planner::operator::join::JoinType; +use crate::planner::LogicalPlan; use crate::storage::Storage; use crate::types::errors::TypeError; @@ -49,12 +49,8 @@ impl BinderContext { // Tips: The order of this index is based on Aggregate being bound first. pub fn input_ref_index(&self, ty: InputRefType) -> usize { match ty { - InputRefType::AggCall => { - self.agg_calls.len() - }, - InputRefType::GroupBy => { - self.agg_calls.len() + self.group_by_exprs.len() - } + InputRefType::AggCall => self.agg_calls.len(), + InputRefType::GroupBy => self.agg_calls.len() + self.group_by_exprs.len(), } } @@ -83,37 +79,47 @@ impl Binder { pub async fn bind(mut self, stmt: &Statement) -> Result { let plan = match stmt { Statement::Query(query) => self.bind_query(query).await?, - Statement::CreateTable { name, columns, constraints, .. } => { - self.bind_create_table(name, &columns, &constraints)? + Statement::CreateTable { + name, + columns, + constraints, + .. + } => self.bind_create_table(name, &columns, &constraints)?, + Statement::Drop { + object_type, names, .. + } => match object_type { + ObjectType::Table => self.bind_drop_table(&names[0])?, + _ => todo!(), }, - Statement::Drop { object_type, names, .. } => { - match object_type { - ObjectType::Table => { - self.bind_drop_table(&names[0])? - } - _ => todo!() - } - } - Statement::Insert { table_name, columns, source, overwrite, .. } => { + Statement::Insert { + table_name, + columns, + source, + overwrite, + .. + } => { if let SetExpr::Values(values) = source.body.as_ref() { - self.bind_insert( - table_name.to_owned(), - columns, - &values.rows, - *overwrite - ).await? + self.bind_insert(table_name.to_owned(), columns, &values.rows, *overwrite) + .await? } else { todo!() } } - Statement::Update { table, selection, assignments, .. } => { + Statement::Update { + table, + selection, + assignments, + .. + } => { if !table.joins.is_empty() { unimplemented!() } else { self.bind_update(table, selection, assignments).await? } } - Statement::Delete { from, selection, .. } => { + Statement::Delete { + from, selection, .. + } => { let table = &from[0]; if !table.joins.is_empty() { @@ -184,35 +190,61 @@ pub enum BindError { #[cfg(test)] pub mod test { - use std::path::PathBuf; - use std::sync::Arc; - use tempfile::TempDir; - use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::planner::LogicalPlan; - use crate::types::LogicalType::Integer; use crate::binder::{Binder, BinderContext}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::execution::ExecutorError; + use crate::planner::LogicalPlan; use crate::storage::kip::KipStorage; use crate::storage::{Storage, StorageError}; + use crate::types::LogicalType::Integer; + use std::path::PathBuf; + use std::sync::Arc; + use tempfile::TempDir; - pub(crate) async fn build_test_catalog(path: impl Into + Send) -> Result { + pub(crate) async fn build_test_catalog( + path: impl Into + Send, + ) -> Result { let storage = KipStorage::new(path).await?; - let _ = storage.create_table( - Arc::new("t1".to_string()), - vec![ - ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true, false), None), - ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false, true), None), - ] - ).await?; - - let _ = storage.create_table( - Arc::new("t2".to_string()), - vec![ - ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true, false), None), - ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false, false), None), - ] - ).await?; + let _ = storage + .create_table( + Arc::new("t1".to_string()), + vec![ + ColumnCatalog::new( + "c1".to_string(), + false, + ColumnDesc::new(Integer, true, false), + None, + ), + ColumnCatalog::new( + "c2".to_string(), + false, + ColumnDesc::new(Integer, false, true), + None, + ), + ], + ) + .await?; + + let _ = storage + .create_table( + Arc::new("t2".to_string()), + vec![ + ColumnCatalog::new( + "c3".to_string(), + false, + ColumnDesc::new(Integer, true, false), + None, + ), + ColumnCatalog::new( + "c4".to_string(), + false, + ColumnDesc::new(Integer, false, false), + None, + ), + ], + ) + .await?; Ok(storage) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 8e37a775..585aadff 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1,7 +1,7 @@ +use async_recursion::async_recursion; use std::borrow::Borrow; use std::collections::HashMap; use std::sync::Arc; -use async_recursion::async_recursion; use crate::{ expression::ScalarExpression, @@ -17,18 +17,23 @@ use crate::{ use super::Binder; -use crate::catalog::{ColumnCatalog, DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, TableCatalog, TableName}; -use itertools::Itertools; -use sqlparser::ast; -use sqlparser::ast::{Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins}; use crate::binder::BindError; +use crate::catalog::{ + ColumnCatalog, TableCatalog, TableName, DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, +}; use crate::execution::executor::dql::join::joins_nullable; use crate::expression::BinaryOperator; -use crate::planner::LogicalPlan; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; +use crate::planner::LogicalPlan; use crate::storage::Storage; use crate::types::LogicalType; +use itertools::Itertools; +use sqlparser::ast; +use sqlparser::ast::{ + Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, + SelectItem, SetExpr, TableFactor, TableWithJoins, +}; impl Binder { #[async_recursion] @@ -74,13 +79,16 @@ impl Binder { self.extract_select_aggregate(&mut select_list)?; if !select.group_by.is_empty() { - self.extract_group_by_aggregate(&mut select_list, &select.group_by).await?; + self.extract_group_by_aggregate(&mut select_list, &select.group_by) + .await?; } let mut having_orderby = (None, None); if select.having.is_some() || !orderby.is_empty() { - having_orderby = self.extract_having_orderby_aggregate(&select.having, orderby).await?; + having_orderby = self + .extract_having_orderby_aggregate(&select.having, orderby) + .await?; } if !self.context.agg_calls.is_empty() || !self.context.group_by_exprs.is_empty() { @@ -108,7 +116,10 @@ impl Binder { Ok(plan) } - pub(crate) async fn bind_table_ref(&mut self, from: &[TableWithJoins]) -> Result { + pub(crate) async fn bind_table_ref( + &mut self, + from: &[TableWithJoins], + ) -> Result { assert!(from.len() < 2, "not support yet."); if from.is_empty() { return Ok(LogicalPlan { @@ -129,7 +140,11 @@ impl Binder { Ok(plan) } - async fn bind_single_table_ref(&mut self, table: &TableFactor, joint_type: Option) -> Result<(TableName, LogicalPlan), BindError> { + async fn bind_single_table_ref( + &mut self, + table: &TableFactor, + joint_type: Option, + ) -> Result<(TableName, LogicalPlan), BindError> { let plan_with_name = match table { TableFactor::Table { name, alias, .. } => { let obj_name = name @@ -157,7 +172,11 @@ impl Binder { Ok(plan_with_name) } - pub(crate) async fn _bind_single_table_ref(&mut self, joint_type: Option, table: &str) -> Result<(Arc, LogicalPlan), BindError> { + pub(crate) async fn _bind_single_table_ref( + &mut self, + joint_type: Option, + table: &str, + ) -> Result<(Arc, LogicalPlan), BindError> { let table_name = Arc::new(table.to_string()); if self.context.bind_table.contains_key(&table_name) { @@ -171,9 +190,14 @@ impl Binder { .await .ok_or_else(|| BindError::InvalidTable(format!("bind table {}", table)))?; - self.context.bind_table.insert(table_name.clone(), (table_catalog.clone(), joint_type)); + self.context + .bind_table + .insert(table_name.clone(), (table_catalog.clone(), joint_type)); - Ok((table_name.clone(), ScanOperator::new(table_name, &table_catalog))) + Ok(( + table_name.clone(), + ScanOperator::new(table_name, &table_catalog), + )) } /// Normalize select item. @@ -182,7 +206,10 @@ impl Binder { /// - Qualified name with wildcard, e.g. `SELECT t.* FROM t,t1` /// - Scalar expression or aggregate expression, e.g. `SELECT COUNT(*) + 1 AS count FROM t` /// - async fn normalize_select_item(&mut self, items: &[SelectItem]) -> Result, BindError> { + async fn normalize_select_item( + &mut self, + items: &[SelectItem], + ) -> Result, BindError> { let mut select_items = vec![]; for item in items.iter().enumerate() { @@ -213,7 +240,8 @@ impl Binder { async fn bind_all_column_refs(&mut self) -> Result, BindError> { let mut exprs = vec![]; for table_name in self.context.bind_table.keys().cloned() { - let table = self.context + let table = self + .context .storage .table(&table_name) .await @@ -226,7 +254,12 @@ impl Binder { Ok(exprs) } - async fn bind_join(&mut self, left_table: TableName, left: LogicalPlan, join: &Join) -> Result { + async fn bind_join( + &mut self, + left_table: TableName, + left: LogicalPlan, + join: &Join, + ) -> Result { let Join { relation, join_operator, @@ -241,25 +274,30 @@ impl Binder { _ => unimplemented!(), }; - let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type)).await?; + let (right_table, right) = self + .bind_single_table_ref(relation, Some(join_type)) + .await?; - let left_table = self.context.storage + let left_table = self + .context + .storage .table(&left_table) .await .cloned() .ok_or_else(|| BindError::InvalidTable(format!("Left: {} not found", left_table)))?; - let right_table = self.context.storage + let right_table = self + .context + .storage .table(&right_table) .await .cloned() .ok_or_else(|| BindError::InvalidTable(format!("Right: {} not found", right_table)))?; let on = match joint_condition { - Some(constraint) => self.bind_join_constraint( - &left_table, - &right_table, - constraint - ).await?, + Some(constraint) => { + self.bind_join_constraint(&left_table, &right_table, constraint) + .await? + } None => JoinCondition::None, }; @@ -300,11 +338,7 @@ impl Binder { } } - fn bind_sort( - &mut self, - children: LogicalPlan, - sort_fields: Vec, - ) -> LogicalPlan { + fn bind_sort(&mut self, children: LogicalPlan, sort_fields: Vec) -> LogicalPlan { LogicalPlan { operator: Operator::Sort(SortOperator { sort_fields, @@ -328,9 +362,17 @@ impl Binder { ScalarExpression::Constant(dv) => match dv.as_ref() { DataValue::Int32(Some(v)) if *v > 0 => limit = *v as usize, DataValue::Int64(Some(v)) if *v > 0 => limit = *v as usize, - _ => return Err(BindError::InvalidColumn("invalid limit expression.".to_owned())), + _ => { + return Err(BindError::InvalidColumn( + "invalid limit expression.".to_owned(), + )) + } }, - _ => return Err(BindError::InvalidColumn("invalid limit expression.".to_owned())), + _ => { + return Err(BindError::InvalidColumn( + "invalid limit expression.".to_owned(), + )) + } } } @@ -340,9 +382,17 @@ impl Binder { ScalarExpression::Constant(dv) => match dv.as_ref() { DataValue::Int32(Some(v)) if *v > 0 => offset = *v as usize, DataValue::Int64(Some(v)) if *v > 0 => offset = *v as usize, - _ => return Err(BindError::InvalidColumn("invalid limit expression.".to_owned())), + _ => { + return Err(BindError::InvalidColumn( + "invalid limit expression.".to_owned(), + )) + } }, - _ => return Err(BindError::InvalidColumn("invalid limit expression.".to_owned())), + _ => { + return Err(BindError::InvalidColumn( + "invalid limit expression.".to_owned(), + )) + } } } @@ -351,10 +401,7 @@ impl Binder { Ok(LimitOperator::new(offset, limit, children)) } - pub fn extract_select_join( - &mut self, - select_items: &mut [ScalarExpression], - ) { + pub fn extract_select_join(&mut self, select_items: &mut [ScalarExpression]) { let bind_tables = &self.context.bind_table; if bind_tables.len() < 2 { return; @@ -403,7 +450,8 @@ impl Binder { // expression that didn't match equi-join pattern let mut filter = vec![]; - self.extract_join_keys(expr, &mut on_keys, &mut filter, left_table, right_table).await?; + self.extract_join_keys(expr, &mut on_keys, &mut filter, left_table, right_table) + .await?; // combine multiple filter exprs into one BinaryExpr let join_filter = filter @@ -481,14 +529,16 @@ impl Binder { accum_filter, left_schema, right_schema, - ).await?; + ) + .await?; self.extract_join_keys( right, accum, accum_filter, left_schema, right_schema, - ).await?; + ) + .await?; } } _other => { @@ -513,54 +563,29 @@ mod tests { #[tokio::test] async fn test_select_bind() -> Result<(), ExecutorError> { let plan_1 = select_sql_run("select * from t1").await?; - println!( - "just_col:\n {:#?}", - plan_1 - ); + println!("just_col:\n {:#?}", plan_1); let plan_2 = select_sql_run("select t1.c1, t1.c2 from t1").await?; - println!( - "table_with_col:\n {:#?}", - plan_2 - ); + println!("table_with_col:\n {:#?}", plan_2); let plan_3 = select_sql_run("select t1.c1, t1.c2 from t1 where c1 > 2").await?; - println!( - "table_with_col_and_c1_compare_constant:\n {:#?}", - plan_3 - ); + println!("table_with_col_and_c1_compare_constant:\n {:#?}", plan_3); let plan_4 = select_sql_run("select t1.c1, t1.c2 from t1 where c1 > c2").await?; - println!( - "table_with_col_and_c1_compare_c2:\n {:#?}", - plan_4 - ); + println!("table_with_col_and_c1_compare_c2:\n {:#?}", plan_4); let plan_5 = select_sql_run("select avg(t1.c1) from t1").await?; - println!( - "table_with_col_and_c1_avg:\n {:#?}", - plan_5 - ); - let plan_6 = select_sql_run("select t1.c1, t1.c2 from t1 where (t1.c1 - t1.c2) > 1").await?; - println!( - "table_with_col_nested:\n {:#?}", - plan_6 - ); + println!("table_with_col_and_c1_avg:\n {:#?}", plan_5); + let plan_6 = + select_sql_run("select t1.c1, t1.c2 from t1 where (t1.c1 - t1.c2) > 1").await?; + println!("table_with_col_nested:\n {:#?}", plan_6); let plan_7 = select_sql_run("select * from t1 limit 1").await?; - println!( - "limit:\n {:#?}", - plan_7 - ); + println!("limit:\n {:#?}", plan_7); let plan_8 = select_sql_run("select * from t1 offset 2").await?; - println!( - "offset:\n {:#?}", - plan_8 - ); + println!("offset:\n {:#?}", plan_8); - let plan_9 = select_sql_run("select c1, c3 from t1 inner join t2 on c1 = c3 and c1 > 1").await?; - println!( - "join:\n {:#?}", - plan_9 - ); + let plan_9 = + select_sql_run("select c1, c3 from t1 inner join t2 on c1 = c3 and c1 > 1").await?; + println!("join:\n {:#?}", plan_9); Ok(()) } -} \ No newline at end of file +} diff --git a/src/binder/show.rs b/src/binder/show.rs index 04ca6657..31cf4935 100644 --- a/src/binder/show.rs +++ b/src/binder/show.rs @@ -1,19 +1,15 @@ -use crate::binder::{Binder, BindError}; -use crate::planner::LogicalPlan; -use crate::planner::operator::Operator; +use crate::binder::{BindError, Binder}; use crate::planner::operator::show::ShowTablesOperator; +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; impl Binder { - pub(crate) fn bind_show_tables( - &mut self, - ) -> Result { + pub(crate) fn bind_show_tables(&mut self) -> Result { let plan = LogicalPlan { - operator: Operator::Show( - ShowTablesOperator {} - ), + operator: Operator::Show(ShowTablesOperator {}), childrens: vec![], }; Ok(plan) } -} \ No newline at end of file +} diff --git a/src/binder/truncate.rs b/src/binder/truncate.rs index 6de1df72..b670ad0b 100644 --- a/src/binder/truncate.rs +++ b/src/binder/truncate.rs @@ -1,28 +1,24 @@ -use std::sync::Arc; -use sqlparser::ast::ObjectName; -use crate::binder::{Binder, BindError, lower_case_name, split_name}; -use crate::planner::LogicalPlan; -use crate::planner::operator::Operator; +use crate::binder::{lower_case_name, split_name, BindError, Binder}; use crate::planner::operator::truncate::TruncateOperator; +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; +use sqlparser::ast::ObjectName; +use std::sync::Arc; impl Binder { pub(crate) async fn bind_truncate( &mut self, - name: &ObjectName + name: &ObjectName, ) -> Result { let name = lower_case_name(&name); let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); let plan = LogicalPlan { - operator: Operator::Truncate( - TruncateOperator { - table_name - } - ), + operator: Operator::Truncate(TruncateOperator { table_name }), childrens: vec![], }; Ok(plan) } -} \ No newline at end of file +} diff --git a/src/binder/update.rs b/src/binder/update.rs index ea99b4cc..d6504562 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -1,20 +1,20 @@ -use std::slice; -use std::sync::Arc; -use sqlparser::ast::{Assignment, Expr, TableFactor, TableWithJoins}; -use crate::binder::{Binder, BindError, lower_case_name, split_name}; +use crate::binder::{lower_case_name, split_name, BindError, Binder}; use crate::expression::ScalarExpression; -use crate::planner::LogicalPlan; -use crate::planner::operator::Operator; use crate::planner::operator::update::UpdateOperator; +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; use crate::types::value::ValueRef; +use sqlparser::ast::{Assignment, Expr, TableFactor, TableWithJoins}; +use std::slice; +use std::sync::Arc; impl Binder { pub(crate) async fn bind_update( &mut self, to: &TableWithJoins, selection: &Option, - assignments: &[Assignment] + assignments: &[Assignment], ) -> Result { if let TableFactor::Table { name, .. } = &to.relation { let name = lower_case_name(&name); @@ -39,16 +39,19 @@ impl Binder { }?; for ident in &assignment.id { - match self.bind_column_ref_from_identifiers( - slice::from_ref(&ident), - bind_table_name.as_ref() - ).await? { + match self + .bind_column_ref_from_identifiers( + slice::from_ref(&ident), + bind_table_name.as_ref(), + ) + .await? + { ScalarExpression::ColumnRef(catalog) => { value.check_len(catalog.datatype())?; columns.push(catalog); row.push(value.clone()); - }, - _ => unreachable!() + } + _ => unreachable!(), } } } @@ -56,11 +59,7 @@ impl Binder { let values_plan = self.bind_values(vec![row], columns); Ok(LogicalPlan { - operator: Operator::Update( - UpdateOperator { - table_name - } - ), + operator: Operator::Update(UpdateOperator { table_name }), childrens: vec![plan, values_plan], }) } else { diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 95adc35d..c3206160 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,8 +1,8 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; -use sqlparser::ast::{ColumnDef, ColumnOption}; use crate::catalog::TableName; use crate::expression::ScalarExpression; +use serde::{Deserialize, Serialize}; +use sqlparser::ast::{ColumnDef, ColumnOption}; +use std::sync::Arc; use crate::types::{ColumnId, LogicalType}; @@ -23,7 +23,7 @@ impl ColumnCatalog { column_name: String, nullable: bool, column_desc: ColumnDesc, - ref_expr: Option + ref_expr: Option, ) -> ColumnCatalog { ColumnCatalog { id: None, @@ -35,7 +35,7 @@ impl ColumnCatalog { } } - pub(crate) fn new_dummy(column_name: String)-> ColumnCatalog { + pub(crate) fn new_dummy(column_name: String) -> ColumnCatalog { ColumnCatalog { id: Some(0), name: column_name, @@ -61,7 +61,7 @@ impl From for ColumnCatalog { let mut column_desc = ColumnDesc::new( LogicalType::try_from(column_def.data_type).unwrap(), false, - false + false, ); let mut nullable = false; @@ -79,8 +79,8 @@ impl From for ColumnCatalog { } else { column_desc.is_unique = true; } - }, - _ => todo!() + } + _ => todo!(), } } @@ -97,7 +97,11 @@ pub struct ColumnDesc { } impl ColumnDesc { - pub(crate) const fn new(column_datatype: LogicalType, is_primary: bool, is_unique: bool) -> ColumnDesc { + pub(crate) const fn new( + column_datatype: LogicalType, + is_primary: bool, + is_unique: bool, + ) -> ColumnDesc { ColumnDesc { column_datatype, is_primary, diff --git a/src/catalog/root.rs b/src/catalog/root.rs index d1738591..9112dfb3 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -33,21 +33,16 @@ impl RootCatalog { if self.table_idxs.contains_key(&table_name) { return Err(CatalogError::Duplicated("column", table_name.to_string())); } - let table = TableCatalog::new( - table_name.clone(), - columns - )?; + let table = TableCatalog::new(table_name.clone(), columns)?; self.table_idxs.insert(table_name.clone(), table); Ok(table_name) } - pub(crate) fn drop_table( - &mut self, - table_name: &String, - ) -> Result<(), CatalogError> { - self.table_idxs.retain(|name, _| name.as_str() != table_name); + pub(crate) fn drop_table(&mut self, table_name: &String) -> Result<(), CatalogError> { + self.table_idxs + .retain(|name, _| name.as_str() != table_name); Ok(()) } @@ -55,10 +50,10 @@ impl RootCatalog { #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::types::LogicalType; + use std::sync::Arc; #[test] fn test_root_catalog() { @@ -68,13 +63,13 @@ mod tests { "a".to_string(), false, ColumnDesc::new(LogicalType::Integer, false, false), - None + None, ); let col1 = ColumnCatalog::new( "b".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), - None + None, ); let col_catalogs = vec![col0, col1]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 391496f6..1d075791 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -2,8 +2,8 @@ use std::collections::BTreeMap; use std::sync::Arc; use crate::catalog::{CatalogError, ColumnCatalog, ColumnRef}; -use crate::types::ColumnId; use crate::types::index::{IndexMeta, IndexMetaRef}; +use crate::types::ColumnId; pub type TableName = Arc; @@ -13,7 +13,7 @@ pub struct TableCatalog { /// Mapping from column names to column ids column_idxs: BTreeMap, pub(crate) columns: BTreeMap, - pub indexes: Vec + pub indexes: Vec, } impl TableCatalog { @@ -43,9 +43,7 @@ impl TableCatalog { } pub(crate) fn all_columns_with_id(&self) -> Vec<(&ColumnId, &ColumnRef)> { - self.columns - .iter() - .collect() + self.columns.iter().collect() } pub(crate) fn all_columns(&self) -> Vec { @@ -56,10 +54,7 @@ impl TableCatalog { } /// Add a column to the table catalog. - pub(crate) fn add_column( - &mut self, - mut col: ColumnCatalog, - ) -> Result { + pub(crate) fn add_column(&mut self, mut col: ColumnCatalog) -> Result { if self.column_idxs.contains_key(&col.name) { return Err(CatalogError::Duplicated("column", col.name.clone())); } @@ -85,7 +80,7 @@ impl TableCatalog { pub(crate) fn new( name: TableName, - columns: Vec + columns: Vec, ) -> Result { let mut table_catalog = TableCatalog { name, @@ -104,7 +99,7 @@ impl TableCatalog { pub(crate) fn new_with_indexes( name: TableName, columns: Vec, - indexes: Vec + indexes: Vec, ) -> Result { let mut catalog = TableCatalog::new(name, columns)?; catalog.indexes = indexes; @@ -125,8 +120,18 @@ mod tests { // | 1 | true | // | 2 | false | fn test_table_catalog() { - let col0 = ColumnCatalog::new("a".into(), false, ColumnDesc::new(LogicalType::Integer, false, false), None); - let col1 = ColumnCatalog::new("b".into(), false, ColumnDesc::new(LogicalType::Boolean, false, false), None); + let col0 = ColumnCatalog::new( + "a".into(), + false, + ColumnDesc::new(LogicalType::Integer, false, false), + None, + ); + let col1 = ColumnCatalog::new( + "b".into(), + false, + ColumnDesc::new(LogicalType::Boolean, false, false), + None, + ); let col_catalogs = vec![col0, col1]; let table_catalog = TableCatalog::new(Arc::new("test".to_string()), col_catalogs).unwrap(); @@ -134,8 +139,12 @@ mod tests { assert_eq!(table_catalog.contains_column(&"b".to_string()), true); assert_eq!(table_catalog.contains_column(&"c".to_string()), false); - let col_a_id = table_catalog.get_column_id_by_name(&"a".to_string()).unwrap(); - let col_b_id = table_catalog.get_column_id_by_name(&"b".to_string()).unwrap(); + let col_a_id = table_catalog + .get_column_id_by_name(&"a".to_string()) + .unwrap(); + let col_b_id = table_catalog + .get_column_id_by_name(&"b".to_string()) + .unwrap(); assert!(col_a_id < col_b_id); let column_catalog = table_catalog.get_column_by_id(&col_a_id).unwrap(); diff --git a/src/db.rs b/src/db.rs index e31e2660..1ee0d0cd 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,18 +1,18 @@ -use std::path::PathBuf; use sqlparser::parser::ParserError; +use std::path::PathBuf; use crate::binder::{BindError, Binder, BinderContext}; -use crate::execution::ExecutorError; use crate::execution::executor::{build, try_collect}; +use crate::execution::ExecutorError; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; -use crate::optimizer::OptimizerError; use crate::optimizer::rule::RuleImpl; +use crate::optimizer::OptimizerError; use crate::parser::parse_sql; use crate::planner::LogicalPlan; -use crate::storage::{Storage, StorageError}; use crate::storage::kip::KipStorage; use crate::storage::memory::MemStorage; +use crate::storage::{Storage, StorageError}; use crate::types::tuple::Tuple; pub struct Database { @@ -64,8 +64,7 @@ impl Database { let source_plan = binder.bind(&stmts[0]).await?; // println!("source_plan plan: {:#?}", source_plan); - let best_plan = Self::default_optimizer(source_plan) - .find_best()?; + let best_plan = Self::default_optimizer(source_plan).find_best()?; // println!("best_plan plan: {:#?}", best_plan); let mut stream = build(best_plan, &self.storage); @@ -78,23 +77,23 @@ impl Database { .batch( "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .batch( "Predicate Pushdown".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::PushPredicateThroughJoin, - RuleImpl::PushPredicateIntoScan - ] + RuleImpl::PushPredicateIntoScan, + ], ) .batch( "Column Pruning".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::PushProjectThroughChild, - RuleImpl::PushProjectIntoScan - ] + RuleImpl::PushProjectIntoScan, + ], ) .batch( "Limit Pushdown".to_string(), @@ -109,10 +108,7 @@ impl Database { .batch( "Combine Operators".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![ - RuleImpl::CollapseProject, - RuleImpl::CombineFilter - ] + vec![RuleImpl::CollapseProject, RuleImpl::CombineFilter], ) } } @@ -141,7 +137,7 @@ pub enum DatabaseError { ExecutorError( #[source] #[from] - ExecutorError + ExecutorError, ), #[error("Internal error: {0}")] InternalError(String), @@ -149,19 +145,19 @@ pub enum DatabaseError { OptimizerError( #[source] #[from] - OptimizerError - ) + OptimizerError, + ), } #[cfg(test)] mod test { - use std::sync::Arc; - use tempfile::TempDir; use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; use crate::db::{Database, DatabaseError}; use crate::storage::{Storage, StorageError}; - use crate::types::LogicalType; use crate::types::tuple::create_table; + use crate::types::LogicalType; + use std::sync::Arc; + use tempfile::TempDir; async fn build_table(storage: &impl Storage) -> Result { let columns = vec![ @@ -169,17 +165,19 @@ mod test { "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, true, false), - None + None, ), ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), - None + None, ), ]; - Ok(storage.create_table(Arc::new("t1".to_string()), columns).await?) + Ok(storage + .create_table(Arc::new("t1".to_string()), columns) + .await?) } #[tokio::test] @@ -199,12 +197,20 @@ mod test { let kipsql = Database::with_kipdb(temp_dir.path()).await?; let _ = kipsql.run("create table t1 (a int primary key, b int unique null, k int, z varchar unique null)").await?; - let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?; + let _ = kipsql + .run("create table t2 (c int primary key, d int unsigned null, e datetime)") + .await?; let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p')").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; - let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?; - let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?; - let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?; + let _ = kipsql + .run("create table t3 (a int primary key, b decimal(4,2))") + .await?; + let _ = kipsql + .run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)") + .await?; + let _ = kipsql + .run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)") + .await?; println!("show tables:"); let tuples_show_tables = kipsql.run("show tables").await?; @@ -231,7 +237,9 @@ mod test { println!("{}", create_table(&tuples_limit)); println!("inner join:"); - let tuples_inner_join = kipsql.run("select * from t1 inner join t2 on a = c").await?; + let tuples_inner_join = kipsql + .run("select * from t1 inner join t2 on a = c") + .await?; println!("{}", create_table(&tuples_inner_join)); println!("left join:"); @@ -239,7 +247,9 @@ mod test { println!("{}", create_table(&tuples_left_join)); println!("right join:"); - let tuples_right_join = kipsql.run("select * from t1 right join t2 on a = c").await?; + let tuples_right_join = kipsql + .run("select * from t1 right join t2 on a = c") + .await?; println!("{}", create_table(&tuples_right_join)); println!("full join:"); @@ -275,7 +285,9 @@ mod test { println!("{}", create_table(&tuples_min_max_agg)); println!("group agg:"); - let tuples_group_agg = kipsql.run("select c, max(d) from t2 group by c having c = 1").await?; + let tuples_group_agg = kipsql + .run("select c, max(d) from t2 group by c having c = 1") + .await?; println!("{}", create_table(&tuples_group_agg)); println!("alias:"); @@ -283,7 +295,9 @@ mod test { println!("{}", create_table(&tuples_group_agg)); println!("alias agg:"); - let tuples_group_agg = kipsql.run("select c, max(d) as max_d from t2 group by c having c = 1").await?; + let tuples_group_agg = kipsql + .run("select c, max(d) as max_d from t2 group by c having c = 1") + .await?; println!("{}", create_table(&tuples_group_agg)); println!("time max:"); @@ -291,10 +305,15 @@ mod test { println!("{}", create_table(&tuples_time_max)); println!("time where:"); - let tuples_time_where_t2 = kipsql.run("select (c + 1) from t2 where e > '2021-05-20'").await?; + let tuples_time_where_t2 = kipsql + .run("select (c + 1) from t2 where e > '2021-05-20'") + .await?; println!("{}", create_table(&tuples_time_where_t2)); - assert!(kipsql.run("select max(d) from t2 group by c").await.is_err()); + assert!(kipsql + .run("select max(d) from t2 group by c") + .await + .is_err()); println!("distinct t1:"); let tuples_distinct_t1 = kipsql.run("select distinct b, k from t1").await?; @@ -307,12 +326,17 @@ mod test { println!("{}", create_table(&update_after_full_t1)); println!("insert overwrite t1:"); - let _ = kipsql.run("insert overwrite t1 (a, b, k) values (-99, 1, 0)").await?; + let _ = kipsql + .run("insert overwrite t1 (a, b, k) values (-99, 1, 0)") + .await?; println!("after t1:"); let insert_overwrite_after_full_t1 = kipsql.run("select * from t1").await?; println!("{}", create_table(&insert_overwrite_after_full_t1)); - assert!(kipsql.run("insert overwrite t1 (a, b, k) values (-1, 1, 0)").await.is_err()); + assert!(kipsql + .run("insert overwrite t1 (a, b, k) values (-1, 1, 0)") + .await + .is_err()); println!("delete t1 with filter:"); let _ = kipsql.run("delete from t1 where b = 0").await?; diff --git a/src/execution/executor/ddl/create_table.rs b/src/execution/executor/ddl/create_table.rs index 89f3c6a4..708a0625 100644 --- a/src/execution/executor/ddl/create_table.rs +++ b/src/execution/executor/ddl/create_table.rs @@ -1,20 +1,18 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::create_table::CreateTableOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; +use futures_async_stream::try_stream; pub struct CreateTable { - op: CreateTableOperator + op: CreateTableOperator, } impl From for CreateTable { fn from(op: CreateTableOperator) -> Self { - CreateTable { - op - } + CreateTable { op } } } @@ -33,4 +31,4 @@ impl CreateTable { let tuple = tuple_builder.push_result("CREATE TABLE SUCCESS", format!("{}", table_name).as_str())?; yield tuple; } -} \ No newline at end of file +} diff --git a/src/execution/executor/ddl/drop_table.rs b/src/execution/executor/ddl/drop_table.rs index df666593..3541f57b 100644 --- a/src/execution/executor/ddl/drop_table.rs +++ b/src/execution/executor/ddl/drop_table.rs @@ -1,19 +1,17 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::drop_table::DropTableOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub struct DropTable { - op: DropTableOperator + op: DropTableOperator, } impl From for DropTable { fn from(op: DropTableOperator) -> Self { - DropTable { - op - } + DropTable { op } } } @@ -30,4 +28,4 @@ impl DropTable { storage.drop_table(&table_name).await?; } -} \ No newline at end of file +} diff --git a/src/execution/executor/ddl/mod.rs b/src/execution/executor/ddl/mod.rs index 56d6c5ba..9c5a45a1 100644 --- a/src/execution/executor/ddl/mod.rs +++ b/src/execution/executor/ddl/mod.rs @@ -1,3 +1,3 @@ pub(crate) mod create_table; pub(crate) mod drop_table; -pub(crate) mod truncate; \ No newline at end of file +pub(crate) mod truncate; diff --git a/src/execution/executor/ddl/truncate.rs b/src/execution/executor/ddl/truncate.rs index 287893c0..5be57612 100644 --- a/src/execution/executor/ddl/truncate.rs +++ b/src/execution/executor/ddl/truncate.rs @@ -1,19 +1,17 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::truncate::TruncateOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub struct Truncate { - op: TruncateOperator + op: TruncateOperator, } impl From for Truncate { fn from(op: TruncateOperator) -> Self { - Truncate { - op - } + Truncate { op } } } @@ -30,4 +28,4 @@ impl Truncate { storage.drop_data(&table_name).await?; } -} \ No newline at end of file +} diff --git a/src/execution/executor/dml/delete.rs b/src/execution/executor/dml/delete.rs index e6ef8cf8..32f8a974 100644 --- a/src/execution/executor/dml/delete.rs +++ b/src/execution/executor/dml/delete.rs @@ -1,5 +1,3 @@ -use futures_async_stream::try_stream; -use itertools::Itertools; use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; @@ -7,6 +5,8 @@ use crate::planner::operator::delete::DeleteOperator; use crate::storage::{Storage, Transaction}; use crate::types::index::Index; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; +use itertools::Itertools; pub struct Delete { table_name: TableName, @@ -15,10 +15,7 @@ pub struct Delete { impl From<(DeleteOperator, BoxedExecutor)> for Delete { fn from((DeleteOperator { table_name }, input): (DeleteOperator, BoxedExecutor)) -> Self { - Delete { - table_name, - input, - } + Delete { table_name, input } } } @@ -40,15 +37,20 @@ impl Delete { .all_columns() .into_iter() .enumerate() - .filter_map(|(i, col)| col.desc.is_unique - .then(|| col.id.and_then(|col_id| { - table_catalog.get_unique_index(&col_id) - .map(|index_meta| (i, index_meta)) - })) - .flatten()) + .filter_map(|(i, col)| { + col.desc + .is_unique + .then(|| { + col.id.and_then(|col_id| { + table_catalog + .get_unique_index(&col_id) + .map(|index_meta| (i, index_meta)) + }) + }) + .flatten() + }) .collect_vec(); - #[for_await] for tuple in input { let tuple: Tuple = tuple?; @@ -73,4 +75,4 @@ impl Delete { transaction.commit().await?; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dml/insert.rs b/src/execution/executor/dml/insert.rs index ae6a876b..c7ec4f9a 100644 --- a/src/execution/executor/dml/insert.rs +++ b/src/execution/executor/dml/insert.rs @@ -1,6 +1,3 @@ -use std::collections::HashMap; -use std::sync::Arc; -use futures_async_stream::try_stream; use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; @@ -9,15 +6,26 @@ use crate::storage::{Storage, Transaction}; use crate::types::index::Index; use crate::types::tuple::Tuple; use crate::types::value::DataValue; +use futures_async_stream::try_stream; +use std::collections::HashMap; +use std::sync::Arc; pub struct Insert { table_name: TableName, input: BoxedExecutor, - is_overwrite: bool + is_overwrite: bool, } impl From<(InsertOperator, BoxedExecutor)> for Insert { - fn from((InsertOperator { table_name, is_overwrite }, input): (InsertOperator, BoxedExecutor)) -> Self { + fn from( + ( + InsertOperator { + table_name, + is_overwrite, + }, + input, + ): (InsertOperator, BoxedExecutor), + ) -> Self { Insert { table_name, input, @@ -35,16 +43,23 @@ impl Executor for Insert { impl Insert { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { - let Insert { table_name, input, is_overwrite } = self; + let Insert { + table_name, + input, + is_overwrite, + } = self; let mut primary_key_index = None; let mut unique_values = HashMap::new(); - if let (Some(table_catalog), Some(mut transaction)) = - (storage.table(&table_name).await, storage.transaction(&table_name).await) - { + if let (Some(table_catalog), Some(mut transaction)) = ( + storage.table(&table_name).await, + storage.transaction(&table_name).await, + ) { #[for_await] for tuple in input { - let Tuple { columns, values, .. } = tuple?; + let Tuple { + columns, values, .. + } = tuple?; let mut tuple_map = HashMap::new(); for (i, value) in values.into_iter().enumerate() { let col = &columns[i]; @@ -55,22 +70,22 @@ impl Insert { } } let primary_col_id = primary_key_index.get_or_insert_with(|| { - columns.iter() + columns + .iter() .find(|col| col.desc.is_primary) .map(|col| col.id.unwrap()) .unwrap() }); let all_columns = table_catalog.all_columns_with_id(); - let tuple_id = tuple_map.get(primary_col_id) - .cloned() - .unwrap(); + let tuple_id = tuple_map.get(primary_col_id).cloned().unwrap(); let mut tuple = Tuple { id: Some(tuple_id.clone()), columns: Vec::with_capacity(all_columns.len()), values: Vec::with_capacity(all_columns.len()), }; for (col_id, col) in all_columns { - let value = tuple_map.remove(col_id) + let value = tuple_map + .remove(col_id) .unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))); if col.desc.is_unique && !value.is_null() { @@ -80,7 +95,10 @@ impl Insert { .push((tuple_id.clone(), value.clone())) } if value.is_null() && !col.nullable { - return Err(ExecutorError::InternalError(format!("Non-null fields do not allow null values to be passed in: {:?}", col))); + return Err(ExecutorError::InternalError(format!( + "Non-null fields do not allow null values to be passed in: {:?}", + col + ))); } tuple.columns.push(col.clone()); @@ -106,4 +124,4 @@ impl Insert { transaction.commit().await?; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dml/mod.rs b/src/execution/executor/dml/mod.rs index 2044b636..fcc4031e 100644 --- a/src/execution/executor/dml/mod.rs +++ b/src/execution/executor/dml/mod.rs @@ -1,5 +1,5 @@ +pub(crate) mod delete; pub(crate) mod insert; pub(crate) mod update; -pub(crate) mod delete; pub(crate) mod copy_from_file; pub(crate) mod copy_to_file; diff --git a/src/execution/executor/dml/update.rs b/src/execution/executor/dml/update.rs index 0c1320a2..9fbc0e56 100644 --- a/src/execution/executor/dml/update.rs +++ b/src/execution/executor/dml/update.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; -use futures_async_stream::try_stream; use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; @@ -7,19 +5,27 @@ use crate::planner::operator::update::UpdateOperator; use crate::storage::{Storage, Transaction}; use crate::types::index::Index; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; +use std::collections::HashMap; pub struct Update { table_name: TableName, input: BoxedExecutor, - values: BoxedExecutor + values: BoxedExecutor, } impl From<(UpdateOperator, BoxedExecutor, BoxedExecutor)> for Update { - fn from((UpdateOperator { table_name }, input, values): (UpdateOperator, BoxedExecutor, BoxedExecutor)) -> Self { + fn from( + (UpdateOperator { table_name }, input, values): ( + UpdateOperator, + BoxedExecutor, + BoxedExecutor, + ), + ) -> Self { Update { table_name, input, - values + values, } } } @@ -33,7 +39,11 @@ impl Executor for Update { impl Update { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { - let Update { table_name, input, values } = self; + let Update { + table_name, + input, + values, + } = self; if let Some(mut transaction) = storage.transaction(&table_name).await { let table_catalog = storage.table(&table_name).await.unwrap(); @@ -42,7 +52,9 @@ impl Update { // only once #[for_await] for tuple in values { - let Tuple { columns, values, .. } = tuple?; + let Tuple { + columns, values, .. + } = tuple?; for i in 0..columns.len() { value_map.insert(columns[i].id, values[i].clone()); } @@ -61,7 +73,9 @@ impl Update { is_overwrite = false; } if column.desc.is_unique && value != &tuple.values[i] { - if let Some(index_meta) = table_catalog.get_unique_index(&column.id.unwrap()) { + if let Some(index_meta) = + table_catalog.get_unique_index(&column.id.unwrap()) + { let mut index = Index { id: index_meta.id, column_values: vec![tuple.values[i].clone()], @@ -70,7 +84,11 @@ impl Update { if !value.is_null() { index.column_values[0] = value.clone(); - transaction.add_index(index, vec![tuple.id.clone().unwrap()], true)?; + transaction.add_index( + index, + vec![tuple.id.clone().unwrap()], + true, + )?; } } } @@ -85,4 +103,4 @@ impl Update { transaction.commit().await?; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/avg.rs b/src/execution/executor/dql/aggregate/avg.rs index 205a1e0c..599a1295 100644 --- a/src/execution/executor/dql/aggregate/avg.rs +++ b/src/execution/executor/dql/aggregate/avg.rs @@ -1,15 +1,15 @@ -use std::sync::Arc; -use crate::execution::executor::dql::aggregate::Accumulator; use crate::execution::executor::dql::aggregate::sum::SumAccumulator; +use crate::execution::executor::dql::aggregate::Accumulator; use crate::execution::ExecutorError; -use crate::expression::BinaryOperator; use crate::expression::value_compute::binary_op; -use crate::types::LogicalType; +use crate::expression::BinaryOperator; use crate::types::value::{DataValue, ValueRef}; +use crate::types::LogicalType; +use std::sync::Arc; pub struct AvgAccumulator { inner: SumAccumulator, - count: usize + count: usize, } impl AvgAccumulator { @@ -32,8 +32,7 @@ impl Accumulator for AvgAccumulator { } fn evaluate(&self) -> Result { - let value = self.inner - .evaluate()?; + let value = self.inner.evaluate()?; let quantity = if value.logical_type().is_signed_numeric() { DataValue::Int64(Some(self.count as i64)) @@ -41,12 +40,10 @@ impl Accumulator for AvgAccumulator { DataValue::UInt32(Some(self.count as u32)) }; - Ok(Arc::new( - binary_op( - &value, - &quantity, - &BinaryOperator::Divide - )? - )) + Ok(Arc::new(binary_op( + &value, + &quantity, + &BinaryOperator::Divide, + )?)) } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/count.rs b/src/execution/executor/dql/aggregate/count.rs index 22f2f766..d649b022 100644 --- a/src/execution/executor/dql/aggregate/count.rs +++ b/src/execution/executor/dql/aggregate/count.rs @@ -1,9 +1,9 @@ -use std::collections::HashSet; -use std::sync::Arc; -use ahash::RandomState; use crate::execution::executor::dql::aggregate::Accumulator; use crate::execution::ExecutorError; use crate::types::value::{DataValue, ValueRef}; +use ahash::RandomState; +use std::collections::HashSet; +use std::sync::Arc; pub struct CountAccumulator { result: i32, @@ -51,6 +51,8 @@ impl Accumulator for DistinctCountAccumulator { } fn evaluate(&self) -> Result { - Ok(Arc::new(DataValue::Int32(Some(self.distinct_values.len() as i32)))) + Ok(Arc::new(DataValue::Int32(Some( + self.distinct_values.len() as i32 + )))) } } diff --git a/src/execution/executor/dql/aggregate/hash_agg.rs b/src/execution/executor/dql/aggregate/hash_agg.rs index 08348182..8cfc5779 100644 --- a/src/execution/executor/dql/aggregate/hash_agg.rs +++ b/src/execution/executor/dql/aggregate/hash_agg.rs @@ -1,14 +1,14 @@ -use ahash::{HashMap, HashMapExt}; -use futures_async_stream::try_stream; -use itertools::Itertools; -use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::executor::dql::aggregate::create_accumulators; +use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; use crate::types::value::ValueRef; +use ahash::{HashMap, HashMapExt}; +use futures_async_stream::try_stream; +use itertools::Itertools; pub struct HashAggExecutor { pub agg_calls: Vec, @@ -17,7 +17,15 @@ pub struct HashAggExecutor { } impl From<(AggregateOperator, BoxedExecutor)> for HashAggExecutor { - fn from((AggregateOperator { agg_calls, groupby_exprs }, input): (AggregateOperator, BoxedExecutor)) -> Self { + fn from( + ( + AggregateOperator { + agg_calls, + groupby_exprs, + }, + input, + ): (AggregateOperator, BoxedExecutor), + ) -> Self { HashAggExecutor { agg_calls, groupby_exprs, @@ -53,7 +61,8 @@ impl HashAggExecutor { }); // 2.1 evaluate agg exprs and collect the result values for later accumulators. - let values: Vec = self.agg_calls + let values: Vec = self + .agg_calls .iter() .map(|expr| { if let ScalarExpression::AggCall { args, .. } = expr { @@ -64,7 +73,8 @@ impl HashAggExecutor { }) .try_collect()?; - let group_keys: Vec = self.groupby_exprs + let group_keys: Vec = self + .groupby_exprs .iter() .map(|expr| expr.eval_column(&tuple)) .try_collect()?; @@ -82,11 +92,10 @@ impl HashAggExecutor { if let Some(group_and_agg_columns) = group_and_agg_columns_option { for (group_keys, accs) in group_hash_accs { // Tips: Accumulator First - let values: Vec = accs.iter() + let values: Vec = accs + .iter() .map(|acc| acc.evaluate()) - .chain(group_keys - .into_iter() - .map(|key| Ok(key))) + .chain(group_keys.into_iter().map(|key| Ok(key))) .try_collect()?; yield Tuple { @@ -101,23 +110,22 @@ impl HashAggExecutor { #[cfg(test)] mod test { - use std::sync::Arc; - use itertools::Itertools; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::execution::executor::dql::aggregate::hash_agg::HashAggExecutor; - use crate::execution::executor::dql::values::Values; - use crate::execution::executor::{Executor, try_collect}; use crate::execution::executor::dql::test::build_integers; + use crate::execution::executor::dql::values::Values; + use crate::execution::executor::{try_collect, Executor}; use crate::execution::ExecutorError; use crate::expression::agg::AggKind; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::operator::values::ValuesOperator; use crate::storage::memory::MemStorage; - use crate::types::LogicalType; use crate::types::tuple::create_table; use crate::types::value::DataValue; - + use crate::types::LogicalType; + use itertools::Itertools; + use std::sync::Arc; #[tokio::test] async fn test_hash_agg() -> Result<(), ExecutorError> { @@ -125,25 +133,34 @@ mod test { let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ - Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new( + "c1".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c2".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c3".to_string(), + true, + desc.clone(), + None, + )), ]; let operator = AggregateOperator { - groupby_exprs: vec![ - ScalarExpression::ColumnRef(t1_columns[0].clone()) - ], - agg_calls: vec![ - ScalarExpression::AggCall { - distinct: false, - kind: AggKind::Sum, - args: vec![ - ScalarExpression::ColumnRef(t1_columns[1].clone()) - ], - ty: LogicalType::Integer, - } - ], + groupby_exprs: vec![ScalarExpression::ColumnRef(t1_columns[0].clone())], + agg_calls: vec![ScalarExpression::AggCall { + distinct: false, + kind: AggKind::Sum, + args: vec![ScalarExpression::ColumnRef(t1_columns[1].clone())], + ty: LogicalType::Integer, + }], }; let input = Values::from(ValuesOperator { @@ -167,25 +184,25 @@ mod test { Arc::new(DataValue::Int32(Some(1))), Arc::new(DataValue::Int32(Some(2))), Arc::new(DataValue::Int32(Some(3))), - ] + ], ], columns: t1_columns, - }).execute(&mem_storage); + }) + .execute(&mem_storage); - let tuples = try_collect(&mut HashAggExecutor::from((operator, input)).execute(&mem_storage)).await?; + let tuples = + try_collect(&mut HashAggExecutor::from((operator, input)).execute(&mem_storage)) + .await?; println!("hash_agg_test: \n{}", create_table(&tuples)); assert_eq!(tuples.len(), 2); - let vec_values = tuples - .into_iter() - .map(|tuple| tuple.values) - .collect_vec(); + let vec_values = tuples.into_iter().map(|tuple| tuple.values).collect_vec(); assert!(vec_values.contains(&build_integers(vec![Some(3), Some(0)]))); assert!(vec_values.contains(&build_integers(vec![Some(5), Some(1)]))); Ok(()) } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/min_max.rs b/src/execution/executor/dql/aggregate/min_max.rs index d06b3d3b..8c45cc82 100644 --- a/src/execution/executor/dql/aggregate/min_max.rs +++ b/src/execution/executor/dql/aggregate/min_max.rs @@ -1,15 +1,15 @@ -use std::sync::Arc; use crate::execution::executor::dql::aggregate::Accumulator; use crate::execution::ExecutorError; -use crate::expression::BinaryOperator; use crate::expression::value_compute::binary_op; -use crate::types::LogicalType; +use crate::expression::BinaryOperator; use crate::types::value::{DataValue, ValueRef}; +use crate::types::LogicalType; +use std::sync::Arc; pub struct MinMaxAccumulator { inner: Option, op: BinaryOperator, - ty: LogicalType + ty: LogicalType, } impl MinMaxAccumulator { @@ -32,22 +32,25 @@ impl Accumulator for MinMaxAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), ExecutorError> { if !value.is_null() { if let Some(inner_value) = &self.inner { - if let DataValue::Boolean(Some(result)) = binary_op(&inner_value, value, &self.op)? { + if let DataValue::Boolean(Some(result)) = binary_op(&inner_value, value, &self.op)? + { result } else { unreachable!() } } else { true - }.then(|| self.inner = Some(value.clone())); + } + .then(|| self.inner = Some(value.clone())); } Ok(()) } fn evaluate(&self) -> Result { - Ok(self.inner + Ok(self + .inner .clone() .unwrap_or_else(|| Arc::new(DataValue::none(&self.ty)))) } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/mod.rs b/src/execution/executor/dql/aggregate/mod.rs index 889123d1..c40ee2c4 100644 --- a/src/execution/executor/dql/aggregate/mod.rs +++ b/src/execution/executor/dql/aggregate/mod.rs @@ -1,12 +1,14 @@ +mod avg; mod count; +pub mod hash_agg; +mod min_max; pub mod simple_agg; mod sum; -mod min_max; -mod avg; -pub mod hash_agg; use crate::execution::executor::dql::aggregate::avg::AvgAccumulator; -use crate::execution::executor::dql::aggregate::count::{CountAccumulator, DistinctCountAccumulator}; +use crate::execution::executor::dql::aggregate::count::{ + CountAccumulator, DistinctCountAccumulator, +}; use crate::execution::executor::dql::aggregate::min_max::MinMaxAccumulator; use crate::execution::executor::dql::aggregate::sum::{DistinctSumAccumulator, SumAccumulator}; use crate::execution::ExecutorError; @@ -26,7 +28,10 @@ pub trait Accumulator: Send + Sync { } fn create_accumulator(expr: &ScalarExpression) -> Box { - if let ScalarExpression::AggCall { kind, ty, distinct, .. } = expr { + if let ScalarExpression::AggCall { + kind, ty, distinct, .. + } = expr + { match (kind, distinct) { (AggKind::Count, false) => Box::new(CountAccumulator::new()), (AggKind::Count, true) => Box::new(DistinctCountAccumulator::new()), @@ -46,4 +51,4 @@ fn create_accumulator(expr: &ScalarExpression) -> Box { fn create_accumulators(exprs: &[ScalarExpression]) -> Vec> { exprs.iter().map(create_accumulator).collect() -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/simple_agg.rs b/src/execution/executor/dql/aggregate/simple_agg.rs index 0934f05a..8a00a045 100644 --- a/src/execution/executor/dql/aggregate/simple_agg.rs +++ b/src/execution/executor/dql/aggregate/simple_agg.rs @@ -1,13 +1,13 @@ -use futures_async_stream::try_stream; -use itertools::Itertools; -use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::executor::dql::aggregate::create_accumulators; +use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; use crate::types::value::ValueRef; +use futures_async_stream::try_stream; +use itertools::Itertools; pub struct SimpleAggExecutor { pub agg_calls: Vec, @@ -15,11 +15,10 @@ pub struct SimpleAggExecutor { } impl From<(AggregateOperator, BoxedExecutor)> for SimpleAggExecutor { - fn from((AggregateOperator { agg_calls, .. }, input): (AggregateOperator, BoxedExecutor)) -> Self { - SimpleAggExecutor { - agg_calls, - input, - } + fn from( + (AggregateOperator { agg_calls, .. }, input): (AggregateOperator, BoxedExecutor), + ) -> Self { + SimpleAggExecutor { agg_calls, input } } } @@ -46,13 +45,12 @@ impl SimpleAggExecutor { .collect_vec() }); - let values: Vec = self.agg_calls + let values: Vec = self + .agg_calls .iter() .map(|expr| match expr { - ScalarExpression::AggCall { args, .. } => { - args[0].eval_column(&tuple) - } - _ => unreachable!() + ScalarExpression::AggCall { args, .. } => args[0].eval_column(&tuple), + _ => unreachable!(), }) .try_collect()?; @@ -62,10 +60,7 @@ impl SimpleAggExecutor { } if let Some(columns) = columns_option { - let values: Vec = accs - .into_iter() - .map(|acc| acc.evaluate()) - .try_collect()?; + let values: Vec = accs.into_iter().map(|acc| acc.evaluate()).try_collect()?; yield Tuple { id: None, @@ -74,4 +69,4 @@ impl SimpleAggExecutor { }; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/aggregate/sum.rs b/src/execution/executor/dql/aggregate/sum.rs index cca28c70..ef3ca7b5 100644 --- a/src/execution/executor/dql/aggregate/sum.rs +++ b/src/execution/executor/dql/aggregate/sum.rs @@ -1,12 +1,12 @@ -use std::collections::HashSet; -use std::sync::Arc; -use ahash::RandomState; use crate::execution::executor::dql::aggregate::Accumulator; use crate::execution::ExecutorError; -use crate::expression::BinaryOperator; use crate::expression::value_compute::binary_op; -use crate::types::LogicalType; +use crate::expression::BinaryOperator; use crate::types::value::{DataValue, ValueRef}; +use crate::types::LogicalType; +use ahash::RandomState; +use std::collections::HashSet; +use std::sync::Arc; pub struct SumAccumulator { result: DataValue, @@ -16,18 +16,16 @@ impl SumAccumulator { pub fn new(ty: &LogicalType) -> Self { assert!(ty.is_numeric()); - Self { result: DataValue::init(&ty) } + Self { + result: DataValue::init(&ty), + } } } impl Accumulator for SumAccumulator { fn update_value(&mut self, value: &ValueRef) -> Result<(), ExecutorError> { if !value.is_null() { - self.result = binary_op( - &self.result, - value, - &BinaryOperator::Plus - )?; + self.result = binary_op(&self.result, value, &BinaryOperator::Plus)?; } Ok(()) @@ -40,7 +38,7 @@ impl Accumulator for SumAccumulator { pub struct DistinctSumAccumulator { distinct_values: HashSet, - inner: SumAccumulator + inner: SumAccumulator, } impl DistinctSumAccumulator { @@ -65,4 +63,4 @@ impl Accumulator for DistinctSumAccumulator { fn evaluate(&self) -> Result { self.inner.evaluate() } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/dummy.rs b/src/execution/executor/dql/dummy.rs index a6c342f4..9b9c7e5e 100644 --- a/src/execution/executor/dql/dummy.rs +++ b/src/execution/executor/dql/dummy.rs @@ -1,8 +1,8 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub struct Dummy {} @@ -14,5 +14,5 @@ impl Executor for Dummy { impl Dummy { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] - pub async fn _execute(self) { } -} \ No newline at end of file + pub async fn _execute(self) {} +} diff --git a/src/execution/executor/dql/filter.rs b/src/execution/executor/dql/filter.rs index a9ada4fa..a3090548 100644 --- a/src/execution/executor/dql/filter.rs +++ b/src/execution/executor/dql/filter.rs @@ -1,4 +1,3 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; @@ -6,18 +5,16 @@ use crate::planner::operator::filter::FilterOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; use crate::types::value::DataValue; +use futures_async_stream::try_stream; pub struct Filter { predicate: ScalarExpression, - input: BoxedExecutor + input: BoxedExecutor, } impl From<(FilterOperator, BoxedExecutor)> for Filter { fn from((FilterOperator { predicate, .. }, input): (FilterOperator, BoxedExecutor)) -> Self { - Filter { - predicate, - input - } + Filter { predicate, input } } } @@ -39,11 +36,11 @@ impl Filter { if let Some(true) = option { yield tuple; } else { - continue + continue; } } else { unreachable!("only bool"); } } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/index_scan.rs b/src/execution/executor/dql/index_scan.rs index cf826cd4..1fdc77f8 100644 --- a/src/execution/executor/dql/index_scan.rs +++ b/src/execution/executor/dql/index_scan.rs @@ -1,20 +1,18 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::scan::ScanOperator; use crate::storage::{Iter, Storage, Transaction}; use crate::types::errors::TypeError; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub(crate) struct IndexScan { - op: ScanOperator + op: ScanOperator, } impl From for IndexScan { fn from(op: ScanOperator) -> Self { - IndexScan { - op - } + IndexScan { op } } } @@ -27,20 +25,21 @@ impl Executor for IndexScan { impl IndexScan { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { - let ScanOperator { table_name, columns, limit, index_by, .. } = self.op; + let ScanOperator { + table_name, + columns, + limit, + index_by, + .. + } = self.op; let (index_meta, binaries) = index_by.ok_or(TypeError::InvalidType)?; if let Some(transaction) = storage.transaction(&table_name).await { - let mut iter = transaction.read_by_index( - limit, - columns, - index_meta, - binaries - )?; + let mut iter = transaction.read_by_index(limit, columns, index_meta, binaries)?; - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = iter.next_tuple()? { yield tuple; } } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/join/hash_join.rs b/src/execution/executor/dql/join/hash_join.rs index 0d3868e4..ed8b968f 100644 --- a/src/execution/executor/dql/join/hash_join.rs +++ b/src/execution/executor/dql/join/hash_join.rs @@ -1,9 +1,5 @@ -use std::sync::Arc; -use ahash::{HashMap, HashMapExt, HashSet, HashSetExt, RandomState}; -use futures_async_stream::try_stream; -use itertools::Itertools; -use crate::execution::executor::dql::join::joins_nullable; use crate::catalog::{ColumnCatalog, ColumnRef}; +use crate::execution::executor::dql::join::joins_nullable; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; @@ -12,16 +8,26 @@ use crate::storage::Storage; use crate::types::errors::TypeError; use crate::types::tuple::Tuple; use crate::types::value::DataValue; +use ahash::{HashMap, HashMapExt, HashSet, HashSetExt, RandomState}; +use futures_async_stream::try_stream; +use itertools::Itertools; +use std::sync::Arc; pub struct HashJoin { on: JoinCondition, ty: JoinType, left_input: BoxedExecutor, - right_input: BoxedExecutor + right_input: BoxedExecutor, } impl From<(JoinOperator, BoxedExecutor, BoxedExecutor)> for HashJoin { - fn from((JoinOperator { on, join_type }, left_input, right_input): (JoinOperator, BoxedExecutor, BoxedExecutor)) -> Self { + fn from( + (JoinOperator { on, join_type }, left_input, right_input): ( + JoinOperator, + BoxedExecutor, + BoxedExecutor, + ), + ) -> Self { HashJoin { on, ty: join_type, @@ -40,14 +46,22 @@ impl Executor for HashJoin { impl HashJoin { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self) { - let HashJoin { on, ty, left_input, right_input } = self; + let HashJoin { + on, + ty, + left_input, + right_input, + } = self; if ty == JoinType::Cross { unreachable!("Cross join should not be in HashJoinExecutor"); } - let ((on_left_keys, on_right_keys), filter): ((Vec, Vec), _) = match on { + let ((on_left_keys, on_right_keys), filter): ( + (Vec, Vec), + _, + ) = match on { JoinCondition::On { on, filter } => (on.into_iter().unzip(), filter), - JoinCondition::None => unreachable!("HashJoin must has on condition") + JoinCondition::None => unreachable!("HashJoin must has on condition"), }; let mut join_columns = Vec::new(); @@ -71,10 +85,7 @@ impl HashJoin { left_init_flag = true; } - left_map - .entry(hash) - .or_insert(Vec::new()) - .push(tuple); + left_map.entry(hash).or_insert(Vec::new()).push(tuple); } // probe phase @@ -102,7 +113,11 @@ impl HashJoin { .chain(tuple.values.clone()) .collect_vec(); - Tuple { id: None, columns: join_columns.clone(), values: full_values } + Tuple { + id: None, + columns: join_columns.clone(), + values: full_values, + } }) .collect_vec() } else if matches!(ty, JoinType::Right | JoinType::Full) { @@ -113,13 +128,20 @@ impl HashJoin { .chain(tuple.values) .collect_vec(); - vec![Tuple { id: None, columns: join_columns.clone(), values }] + vec![Tuple { + id: None, + columns: join_columns.clone(), + values, + }] } else { vec![] }; // on filter - if let (Some(expr), false) = (&filter, join_tuples.is_empty() || matches!(ty, JoinType::Full | JoinType::Cross)) { + if let (Some(expr), false) = ( + &filter, + join_tuples.is_empty() || matches!(ty, JoinType::Full | JoinType::Cross), + ) { let mut filter_tuples = Vec::with_capacity(join_tuples.len()); for mut tuple in join_tuples { @@ -145,7 +167,7 @@ impl HashJoin { } filter_tuples.push(tuple) } - _ => () + _ => (), } } else { filter_tuples.push(tuple) @@ -166,10 +188,15 @@ impl HashJoin { if matches!(ty, JoinType::Left | JoinType::Full) { for (hash, tuples) in left_map { if used_set.contains(&hash) { - continue + continue; } - for Tuple { mut values, columns, ..} in tuples { + for Tuple { + mut values, + columns, + .. + } in tuples + { let mut right_empties = join_columns[columns.len()..] .iter() .map(|col| Arc::new(DataValue::none(col.datatype()))) @@ -177,14 +204,20 @@ impl HashJoin { values.append(&mut right_empties); - yield Tuple { id: None, columns: join_columns.clone(), values } + yield Tuple { + id: None, + columns: join_columns.clone(), + values, + } } } } } fn columns_filling(tuple: &Tuple, join_columns: &mut Vec, force_nullable: bool) { - let mut new_columns = tuple.columns.iter() + let mut new_columns = tuple + .columns + .iter() .cloned() .map(|col| { let mut new_catalog = ColumnCatalog::clone(&col); @@ -200,7 +233,7 @@ impl HashJoin { fn hash_row( on_keys: &[ScalarExpression], hash_random_state: &RandomState, - tuple: &Tuple + tuple: &Tuple, ) -> Result { let mut values = Vec::with_capacity(on_keys.len()); @@ -214,40 +247,77 @@ impl HashJoin { #[cfg(test)] mod test { - use std::sync::Arc; use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::execution::executor::{BoxedExecutor, Executor, try_collect}; use crate::execution::executor::dql::join::hash_join::HashJoin; use crate::execution::executor::dql::test::build_integers; use crate::execution::executor::dql::values::Values; + use crate::execution::executor::{try_collect, BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::operator::values::ValuesOperator; use crate::storage::memory::MemStorage; use crate::storage::Storage; - use crate::types::LogicalType; use crate::types::tuple::create_table; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::sync::Arc; - fn build_join_values(_s: &S) -> (Vec<(ScalarExpression, ScalarExpression)>, BoxedExecutor, BoxedExecutor) { + fn build_join_values( + _s: &S, + ) -> ( + Vec<(ScalarExpression, ScalarExpression)>, + BoxedExecutor, + BoxedExecutor, + ) { let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ - Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c2".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c3".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new( + "c1".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c2".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c3".to_string(), + true, + desc.clone(), + None, + )), ]; let t2_columns = vec![ - Arc::new(ColumnCatalog::new("c4".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c5".to_string(), true, desc.clone(), None)), - Arc::new(ColumnCatalog::new("c6".to_string(), true, desc.clone(), None)), + Arc::new(ColumnCatalog::new( + "c4".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c5".to_string(), + true, + desc.clone(), + None, + )), + Arc::new(ColumnCatalog::new( + "c6".to_string(), + true, + desc.clone(), + None, + )), ]; - let on_keys = vec![ - (ScalarExpression::ColumnRef(t1_columns[0].clone()), ScalarExpression::ColumnRef(t2_columns[0].clone())) - ]; + let on_keys = vec![( + ScalarExpression::ColumnRef(t1_columns[0].clone()), + ScalarExpression::ColumnRef(t2_columns[0].clone()), + )]; let values_t1 = Values::from(ValuesOperator { rows: vec![ @@ -265,7 +335,7 @@ mod test { Arc::new(DataValue::Int32(Some(3))), Arc::new(DataValue::Int32(Some(5))), Arc::new(DataValue::Int32(Some(7))), - ] + ], ], columns: t1_columns, }); @@ -296,8 +366,6 @@ mod test { columns: t2_columns, }); - - (on_keys, values_t1.execute(_s), values_t2.execute(_s)) } @@ -307,7 +375,10 @@ mod test { let (keys, left, right) = build_join_values(&mem_storage); let op = JoinOperator { - on: JoinCondition::On { on: keys, filter: None }, + on: JoinCondition::On { + on: keys, + filter: None, + }, join_type: JoinType::Inner, }; let mut executor = HashJoin::from((op, left, right)).execute(&mem_storage); @@ -317,9 +388,18 @@ mod test { assert_eq!(tuples.len(), 3); - assert_eq!(tuples[0].values, build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)])); - assert_eq!(tuples[1].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)])); - assert_eq!(tuples[2].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)])); + assert_eq!( + tuples[0].values, + build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)]) + ); + assert_eq!( + tuples[1].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)]) + ); + assert_eq!( + tuples[2].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)]) + ); Ok(()) } @@ -330,7 +410,10 @@ mod test { let (keys, left, right) = build_join_values(&mem_storage); let op = JoinOperator { - on: JoinCondition::On { on: keys, filter: None }, + on: JoinCondition::On { + on: keys, + filter: None, + }, join_type: JoinType::Left, }; let mut executor = HashJoin::from((op, left, right)).execute(&mem_storage); @@ -340,10 +423,22 @@ mod test { assert_eq!(tuples.len(), 4); - assert_eq!(tuples[0].values, build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)])); - assert_eq!(tuples[1].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)])); - assert_eq!(tuples[2].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)])); - assert_eq!(tuples[3].values, build_integers(vec![Some(3), Some(5), Some(7), None, None, None])); + assert_eq!( + tuples[0].values, + build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)]) + ); + assert_eq!( + tuples[1].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)]) + ); + assert_eq!( + tuples[2].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)]) + ); + assert_eq!( + tuples[3].values, + build_integers(vec![Some(3), Some(5), Some(7), None, None, None]) + ); Ok(()) } @@ -354,7 +449,10 @@ mod test { let (keys, left, right) = build_join_values(&mem_storage); let op = JoinOperator { - on: JoinCondition::On { on: keys, filter: None }, + on: JoinCondition::On { + on: keys, + filter: None, + }, join_type: JoinType::Right, }; let mut executor = HashJoin::from((op, left, right)).execute(&mem_storage); @@ -364,10 +462,22 @@ mod test { assert_eq!(tuples.len(), 4); - assert_eq!(tuples[0].values, build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)])); - assert_eq!(tuples[1].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)])); - assert_eq!(tuples[2].values, build_integers(vec![None, None, None, Some(4), Some(6), Some(8)])); - assert_eq!(tuples[3].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)])); + assert_eq!( + tuples[0].values, + build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)]) + ); + assert_eq!( + tuples[1].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)]) + ); + assert_eq!( + tuples[2].values, + build_integers(vec![None, None, None, Some(4), Some(6), Some(8)]) + ); + assert_eq!( + tuples[3].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)]) + ); Ok(()) } @@ -378,7 +488,10 @@ mod test { let (keys, left, right) = build_join_values(&mem_storage); let op = JoinOperator { - on: JoinCondition::On { on: keys, filter: None }, + on: JoinCondition::On { + on: keys, + filter: None, + }, join_type: JoinType::Full, }; let mut executor = HashJoin::from((op, left, right)).execute(&mem_storage); @@ -388,12 +501,27 @@ mod test { assert_eq!(tuples.len(), 5); - assert_eq!(tuples[0].values, build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)])); - assert_eq!(tuples[1].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)])); - assert_eq!(tuples[2].values, build_integers(vec![None, None, None, Some(4), Some(6), Some(8)])); - assert_eq!(tuples[3].values, build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)])); - assert_eq!(tuples[4].values, build_integers(vec![Some(3), Some(5), Some(7), None, None, None])); + assert_eq!( + tuples[0].values, + build_integers(vec![Some(0), Some(2), Some(4), Some(0), Some(2), Some(4)]) + ); + assert_eq!( + tuples[1].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(3), Some(5)]) + ); + assert_eq!( + tuples[2].values, + build_integers(vec![None, None, None, Some(4), Some(6), Some(8)]) + ); + assert_eq!( + tuples[3].values, + build_integers(vec![Some(1), Some(3), Some(5), Some(1), Some(1), Some(1)]) + ); + assert_eq!( + tuples[4].values, + build_integers(vec![Some(3), Some(5), Some(7), None, None, None]) + ); Ok(()) } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/join/mod.rs b/src/execution/executor/dql/join/mod.rs index 9a4ddaee..12036197 100644 --- a/src/execution/executor/dql/join/mod.rs +++ b/src/execution/executor/dql/join/mod.rs @@ -10,4 +10,4 @@ pub fn joins_nullable(join_type: &JoinType) -> (bool, bool) { JoinType::Full => (true, true), JoinType::Cross => (true, true), } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/limit.rs b/src/execution/executor/dql/limit.rs index eedd02e9..ab253fd4 100644 --- a/src/execution/executor/dql/limit.rs +++ b/src/execution/executor/dql/limit.rs @@ -1,15 +1,15 @@ -use futures::StreamExt; -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::limit::LimitOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures::StreamExt; +use futures_async_stream::try_stream; pub struct Limit { offset: Option, limit: Option, - input: BoxedExecutor + input: BoxedExecutor, } impl From<(LimitOperator, BoxedExecutor)> for Limit { @@ -31,7 +31,11 @@ impl Executor for Limit { impl Limit { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self) { - let Limit { offset, limit, input } = self; + let Limit { + offset, + limit, + input, + } = self; if limit.is_some() && limit.unwrap() == 0 { return Ok(()); @@ -43,12 +47,12 @@ impl Limit { #[for_await] for (i, tuple) in input.enumerate() { if i < offset_val { - continue + continue; } else if i > offset_limit { - break + break; } yield tuple?; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/mod.rs b/src/execution/executor/dql/mod.rs index b42e5b74..464255ba 100644 --- a/src/execution/executor/dql/mod.rs +++ b/src/execution/executor/dql/mod.rs @@ -1,23 +1,23 @@ -pub(crate) mod seq_scan; -pub(crate) mod projection; -pub(crate) mod values; -pub(crate) mod filter; -pub(crate) mod sort; -pub(crate) mod limit; -pub(crate) mod join; -pub(crate) mod dummy; pub(crate) mod aggregate; +pub(crate) mod dummy; +pub(crate) mod filter; pub(crate) mod index_scan; +pub(crate) mod join; +pub(crate) mod limit; +pub(crate) mod projection; +pub(crate) mod seq_scan; +pub(crate) mod sort; +pub(crate) mod values; #[cfg(test)] pub(crate) mod test { - use std::sync::Arc; - use itertools::Itertools; use crate::types::value::{DataValue, ValueRef}; + use itertools::Itertools; + use std::sync::Arc; pub(crate) fn build_integers(ints: Vec>) -> Vec { ints.into_iter() .map(|i| Arc::new(DataValue::Int32(i))) .collect_vec() } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/projection.rs b/src/execution/executor/dql/projection.rs index e9ea87ac..debbf6ad 100644 --- a/src/execution/executor/dql/projection.rs +++ b/src/execution/executor/dql/projection.rs @@ -1,14 +1,14 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::expression::ScalarExpression; use crate::planner::operator::project::ProjectOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub struct Projection { exprs: Vec, - input: BoxedExecutor + input: BoxedExecutor, } impl From<(ProjectOperator, BoxedExecutor)> for Projection { @@ -43,7 +43,11 @@ impl Projection { columns.push(expr.output_columns(&tuple)); } - yield Tuple { id: None, columns, values, }; + yield Tuple { + id: None, + columns, + values, + }; } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/seq_scan.rs b/src/execution/executor/dql/seq_scan.rs index 5df79669..bb9f7daf 100644 --- a/src/execution/executor/dql/seq_scan.rs +++ b/src/execution/executor/dql/seq_scan.rs @@ -1,19 +1,17 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::scan::ScanOperator; -use crate::storage::{Transaction, Iter, Storage}; +use crate::storage::{Iter, Storage, Transaction}; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub(crate) struct SeqScan { - op: ScanOperator + op: ScanOperator, } impl From for SeqScan { fn from(op: ScanOperator) -> Self { - SeqScan { - op, - } + SeqScan { op } } } @@ -26,17 +24,19 @@ impl Executor for SeqScan { impl SeqScan { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { - let ScanOperator { table_name, columns, limit, .. } = self.op; + let ScanOperator { + table_name, + columns, + limit, + .. + } = self.op; if let Some(transaction) = storage.transaction(&table_name).await { - let mut iter = transaction.read( - limit, - columns - )?; + let mut iter = transaction.read(limit, columns)?; - while let Some(tuple) = iter.next_tuple()? { + while let Some(tuple) = iter.next_tuple()? { yield tuple; } } } -} \ No newline at end of file +} diff --git a/src/execution/executor/dql/sort.rs b/src/execution/executor/dql/sort.rs index 22efcd2c..355edb83 100644 --- a/src/execution/executor/dql/sort.rs +++ b/src/execution/executor/dql/sort.rs @@ -1,15 +1,15 @@ -use std::cmp::Ordering; -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; +use std::cmp::Ordering; pub struct Sort { sort_fields: Vec, limit: Option, - input: BoxedExecutor + input: BoxedExecutor, } impl From<(SortOperator, BoxedExecutor)> for Sort { @@ -31,7 +31,11 @@ impl Executor for Sort { impl Sort { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self) { - let Sort { sort_fields, limit, input } = self; + let Sort { + sort_fields, + limit, + input, + } = self; let mut tuples: Vec = vec![]; #[for_await] @@ -42,23 +46,41 @@ impl Sort { tuples.sort_by(|tuple_1, tuple_2| { let mut ordering = Ordering::Equal; - for SortField { expr, asc, nulls_first } in &sort_fields { + for SortField { + expr, + asc, + nulls_first, + } in &sort_fields + { let value_1 = expr.eval_column(tuple_1).unwrap(); let value_2 = expr.eval_column(tuple_2).unwrap(); - ordering = value_1.partial_cmp(&value_2) - .unwrap_or_else(|| match (value_1.is_null(), value_2.is_null()) { - (false, true) => if *nulls_first { Ordering::Less } else { Ordering::Greater }, - (true, false) => if *nulls_first { Ordering::Greater } else { Ordering::Less }, + ordering = value_1.partial_cmp(&value_2).unwrap_or_else(|| { + match (value_1.is_null(), value_2.is_null()) { + (false, true) => { + if *nulls_first { + Ordering::Less + } else { + Ordering::Greater + } + } + (true, false) => { + if *nulls_first { + Ordering::Greater + } else { + Ordering::Less + } + } _ => Ordering::Equal, - }); + } + }); if !*asc { ordering = ordering.reverse(); } if ordering != Ordering::Equal { - break + break; } } diff --git a/src/execution/executor/dql/values.rs b/src/execution/executor/dql/values.rs index c7841785..c94b5f37 100644 --- a/src/execution/executor/dql/values.rs +++ b/src/execution/executor/dql/values.rs @@ -1,19 +1,17 @@ -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::values::ValuesOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; pub struct Values { - op: ValuesOperator + op: ValuesOperator, } impl From for Values { fn from(op: ValuesOperator) -> Self { - Values { - op - } + Values { op } } } @@ -34,6 +32,6 @@ impl Values { columns: columns.clone(), values, }; - }; + } } -} \ No newline at end of file +} diff --git a/src/execution/executor/mod.rs b/src/execution/executor/mod.rs index f586aa37..0f4dc86a 100644 --- a/src/execution/executor/mod.rs +++ b/src/execution/executor/mod.rs @@ -1,10 +1,8 @@ +pub(crate) mod ddl; +pub(crate) mod dml; pub(crate) mod dql; -pub(crate)mod ddl; -pub(crate)mod dml; pub(crate) mod show; -use futures::stream::BoxStream; -use futures::TryStreamExt; use crate::execution::executor::ddl::create_table::CreateTable; use crate::execution::executor::ddl::drop_table::DropTable; use crate::execution::executor::ddl::truncate::Truncate; @@ -25,10 +23,12 @@ use crate::execution::executor::dql::sort::Sort; use crate::execution::executor::dql::values::Values; use crate::execution::executor::show::show_table::ShowTables; use crate::execution::ExecutorError; -use crate::planner::LogicalPlan; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Storage; use crate::types::tuple::Tuple; +use futures::stream::BoxStream; +use futures::TryStreamExt; pub type BoxedExecutor = BoxStream<'static, Result>; @@ -37,10 +37,13 @@ pub trait Executor { } pub fn build(plan: LogicalPlan, storage: &S) -> BoxedExecutor { - let LogicalPlan { operator, mut childrens } = plan; + let LogicalPlan { + operator, + mut childrens, + } = plan; match operator { - Operator::Dummy => Dummy{ }.execute(storage), + Operator::Dummy => Dummy {}.execute(storage), Operator::Aggregate(op) => { let input = build(childrens.remove(0), storage); @@ -131,4 +134,4 @@ pub async fn try_collect(executor: &mut BoxedExecutor) -> Result, Exe output.push(tuple); } Ok(output) -} \ No newline at end of file +} diff --git a/src/execution/executor/show/mod.rs b/src/execution/executor/show/mod.rs index 50ed8e8a..edc17c14 100644 --- a/src/execution/executor/show/mod.rs +++ b/src/execution/executor/show/mod.rs @@ -1,4 +1 @@ pub(crate) mod show_table; - - - diff --git a/src/execution/executor/show/show_table.rs b/src/execution/executor/show/show_table.rs index 8641cc0f..bff78db4 100644 --- a/src/execution/executor/show/show_table.rs +++ b/src/execution/executor/show/show_table.rs @@ -1,13 +1,13 @@ -use futures_async_stream::try_stream; +use crate::catalog::ColumnCatalog; +use crate::catalog::ColumnRef; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::show::ShowTablesOperator; use crate::storage::Storage; use crate::types::tuple::Tuple; -use crate::catalog::ColumnCatalog; -use crate::catalog::ColumnRef; -use std::sync::Arc; use crate::types::value::{DataValue, ValueRef}; +use futures_async_stream::try_stream; +use std::sync::Arc; pub struct ShowTables { _op: ShowTablesOperator, @@ -15,9 +15,7 @@ pub struct ShowTables { impl From for ShowTables { fn from(op: ShowTablesOperator) -> Self { - ShowTables { - _op: op - } + ShowTables { _op: op } } } @@ -33,12 +31,9 @@ impl ShowTables { let tables = storage.show_tables().await?; for table in tables { - let columns: Vec = vec![ - Arc::new(ColumnCatalog::new_dummy("TABLES".to_string())), - ]; - let values: Vec = vec![ - Arc::new(DataValue::Utf8(Some(table))), - ]; + let columns: Vec = + vec![Arc::new(ColumnCatalog::new_dummy("TABLES".to_string()))]; + let values: Vec = vec![Arc::new(DataValue::Utf8(Some(table)))]; yield Tuple { id: None, @@ -47,4 +42,4 @@ impl ShowTables { }; } } -} \ No newline at end of file +} diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 49bc412c..77764311 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -1,10 +1,10 @@ pub mod executor; -use sqlparser::parser::ParserError; use crate::binder::BindError; use crate::catalog::CatalogError; use crate::storage::StorageError; use crate::types::errors::TypeError; +use sqlparser::parser::ParserError; #[derive(thiserror::Error, Debug)] pub enum ExecutorError { @@ -24,19 +24,19 @@ pub enum ExecutorError { StorageError( #[source] #[from] - StorageError + StorageError, ), #[error("bind error: {0}")] BindError( #[source] #[from] - BindError + BindError, ), #[error("parser error: {0}")] ParserError( #[source] #[from] - ParserError + ParserError, ), #[error("Internal error: {0}")] InternalError(String), diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 457f7982..f30eba4f 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -1,16 +1,14 @@ -use std::sync::Arc; -use itertools::Itertools; -use lazy_static::lazy_static; use crate::expression::value_compute::{binary_op, unary_op}; use crate::expression::ScalarExpression; use crate::types::errors::TypeError; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, ValueRef}; +use itertools::Itertools; +use lazy_static::lazy_static; +use std::sync::Arc; lazy_static! { - static ref NULL_VALUE: ValueRef = { - Arc::new(DataValue::Null) - }; + static ref NULL_VALUE: ValueRef = { Arc::new(DataValue::Null) }; } impl ScalarExpression { @@ -23,37 +21,42 @@ impl ScalarExpression { .clone(); Ok(value) - }, - ScalarExpression::InputRef{ index, .. } => Ok(tuple.values[*index].clone()), - ScalarExpression::Alias{ expr, alias } => { + } + ScalarExpression::InputRef { index, .. } => Ok(tuple.values[*index].clone()), + ScalarExpression::Alias { expr, alias } => { if let Some(value) = Self::eval_with_name(&tuple, alias) { return Ok(value.clone()); } expr.eval_column(tuple) - }, - ScalarExpression::TypeCast{ expr, ty, .. } => { + } + ScalarExpression::TypeCast { expr, ty, .. } => { let value = expr.eval_column(tuple)?; Ok(Arc::new(DataValue::clone(&value).cast(ty)?)) } - ScalarExpression::Binary{ left_expr, right_expr, op, .. } => { + ScalarExpression::Binary { + left_expr, + right_expr, + op, + .. + } => { let left = left_expr.eval_column(tuple)?; let right = right_expr.eval_column(tuple)?; Ok(Arc::new(binary_op(&left, &right, op)?)) } - ScalarExpression::IsNull{ expr } => { + ScalarExpression::IsNull { expr } => { let value = expr.eval_column(tuple)?; Ok(Arc::new(DataValue::Boolean(Some(value.is_null())))) - }, - ScalarExpression::Unary{ expr, op, .. } => { + } + ScalarExpression::Unary { expr, op, .. } => { let value = expr.eval_column(tuple)?; Ok(Arc::new(unary_op(&value, op)?)) - }, - ScalarExpression::AggCall{ .. } => todo!() + } + ScalarExpression::AggCall { .. } => todo!(), } } @@ -64,4 +67,4 @@ impl ScalarExpression { .find_position(|tul_col| &tul_col.name == name) .map(|(i, _)| &tuple.values[i]) } -} \ No newline at end of file +} diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 31daf75c..bd6cf28d 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,23 +1,23 @@ +use itertools::Itertools; +use serde::{Deserialize, Serialize}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use itertools::Itertools; -use serde::{Deserialize, Serialize}; -use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use crate::binder::BinderContext; +use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use self::agg::AggKind; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::storage::Storage; +use crate::types::tuple::Tuple; use crate::types::value::ValueRef; use crate::types::LogicalType; -use crate::types::tuple::Tuple; pub mod agg; mod evaluator; -pub mod value_compute; pub mod simplify; +pub mod value_compute; /// ScalarExpression represnet all scalar expression in SQL. /// SELECT a+1, b FROM t1. @@ -79,9 +79,12 @@ impl ScalarExpression { ScalarExpression::TypeCast { expr, .. } => expr.nullable(), ScalarExpression::IsNull { expr } => expr.nullable(), ScalarExpression::Unary { expr, .. } => expr.nullable(), - ScalarExpression::Binary { left_expr, right_expr, .. } => - left_expr.nullable() && right_expr.nullable(), - ScalarExpression::AggCall { args, .. } => args[0].nullable() + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => left_expr.nullable() && right_expr.nullable(), + ScalarExpression::AggCall { args, .. } => args[0].nullable(), } } @@ -115,19 +118,15 @@ impl ScalarExpression { ScalarExpression::ColumnRef(col) => { vec.push(col.clone()); } - ScalarExpression::Alias { expr, .. } => { - columns_collect(&expr, vec) - } - ScalarExpression::TypeCast { expr, .. } => { - columns_collect(&expr, vec) - } - ScalarExpression::IsNull { expr, .. } => { - columns_collect(&expr, vec) - } - ScalarExpression::Unary { expr, .. } => { - columns_collect(&expr, vec) - } - ScalarExpression::Binary { left_expr, right_expr, .. } => { + ScalarExpression::Alias { expr, .. } => columns_collect(&expr, vec), + ScalarExpression::TypeCast { expr, .. } => columns_collect(&expr, vec), + ScalarExpression::IsNull { expr, .. } => columns_collect(&expr, vec), + ScalarExpression::Unary { expr, .. } => columns_collect(&expr, vec), + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => { columns_collect(left_expr, vec); columns_collect(right_expr, vec); } @@ -149,9 +148,7 @@ impl ScalarExpression { pub fn has_agg_call(&self, context: &BinderContext) -> bool { match self { - ScalarExpression::InputRef { index, .. } => { - context.agg_calls.get(*index).is_some() - }, + ScalarExpression::InputRef { index, .. } => context.agg_calls.get(*index).is_some(), ScalarExpression::AggCall { .. } => unreachable!(), ScalarExpression::Constant(_) => false, ScalarExpression::ColumnRef(_) => false, @@ -159,35 +156,37 @@ impl ScalarExpression { ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(context), ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(context), ScalarExpression::Unary { expr, .. } => expr.has_agg_call(context), - ScalarExpression::Binary { left_expr, right_expr, .. } => { - left_expr.has_agg_call(context) || right_expr.has_agg_call(context) - } + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => left_expr.has_agg_call(context) || right_expr.has_agg_call(context), } } pub fn output_columns(&self, tuple: &Tuple) -> ColumnRef { match self { - ScalarExpression::ColumnRef(col) => { - col.clone() - } - ScalarExpression::Constant(value) => { - Arc::new(ColumnCatalog::new( - format!("{}", value), - true, - ColumnDesc::new(value.logical_type(), false, false), - Some(self.clone()) - )) - } - ScalarExpression::Alias { expr, alias } => { - Arc::new(ColumnCatalog::new( - alias.to_string(), - true, - ColumnDesc::new(expr.return_type(), false, false), - Some(self.clone()) - )) - } - ScalarExpression::AggCall { kind, args, ty, distinct } => { - let args_str = args.iter() + ScalarExpression::ColumnRef(col) => col.clone(), + ScalarExpression::Constant(value) => Arc::new(ColumnCatalog::new( + format!("{}", value), + true, + ColumnDesc::new(value.logical_type(), false, false), + Some(self.clone()), + )), + ScalarExpression::Alias { expr, alias } => Arc::new(ColumnCatalog::new( + alias.to_string(), + true, + ColumnDesc::new(expr.return_type(), false, false), + Some(self.clone()), + )), + ScalarExpression::AggCall { + kind, + args, + ty, + distinct, + } => { + let args_str = args + .iter() .map(|expr| expr.output_columns(tuple).name.clone()) .join(", "); let op = |allow_distinct, distinct| { @@ -208,17 +207,15 @@ impl ScalarExpression { column_name, true, ColumnDesc::new(ty.clone(), false, false), - Some(self.clone()) + Some(self.clone()), )) } - ScalarExpression::InputRef { index, .. } => { - tuple.columns[*index].clone() - } + ScalarExpression::InputRef { index, .. } => tuple.columns[*index].clone(), ScalarExpression::Binary { left_expr, right_expr, op, - ty + ty, } => { let column_name = format!( "({} {} {})", @@ -231,27 +228,19 @@ impl ScalarExpression { column_name, true, ColumnDesc::new(ty.clone(), false, false), - Some(self.clone()) + Some(self.clone()), )) } - ScalarExpression::Unary { - expr, - op, - ty - } => { - let column_name = format!( - "{} {}", - op, - expr.output_columns(tuple).name, - ); + ScalarExpression::Unary { expr, op, ty } => { + let column_name = format!("{} {}", op, expr.output_columns(tuple).name,); Arc::new(ColumnCatalog::new( column_name, true, ColumnDesc::new(ty.clone(), false, false), - Some(self.clone()) + Some(self.clone()), )) - }, - _ => unreachable!() + } + _ => unreachable!(), } } } @@ -266,10 +255,10 @@ pub enum UnaryOperator { impl From for UnaryOperator { fn from(value: SqlUnaryOperator) -> Self { match value { - SqlUnaryOperator::Plus => UnaryOperator::Plus, + SqlUnaryOperator::Plus => UnaryOperator::Plus, SqlUnaryOperator::Minus => UnaryOperator::Minus, SqlUnaryOperator::Not => UnaryOperator::Not, - _ => unimplemented!("not support!") + _ => unimplemented!("not support!"), } } } @@ -325,7 +314,7 @@ impl fmt::Display for UnaryOperator { match self { UnaryOperator::Plus => write!(f, "+"), UnaryOperator::Minus => write!(f, "-"), - UnaryOperator::Not => write!(f, "not") + UnaryOperator::Not => write!(f, "not"), } } } @@ -349,7 +338,7 @@ impl From for BinaryOperator { SqlBinaryOperator::And => BinaryOperator::And, SqlBinaryOperator::Or => BinaryOperator::Or, SqlBinaryOperator::Xor => BinaryOperator::Xor, - _ => unimplemented!("not support!") + _ => unimplemented!("not support!"), } } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 36d1920f..25a7be94 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -1,21 +1,21 @@ -use std::cmp::Ordering; -use std::collections::{Bound, HashSet}; -use std::mem; -use std::sync::Arc; -use ahash::RandomState; -use itertools::Itertools; use crate::catalog::ColumnRef; -use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::expression::value_compute::{binary_op, unary_op}; -use crate::types::{ColumnId, LogicalType}; +use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::types::errors::TypeError; use crate::types::value::{DataValue, ValueRef}; +use crate::types::{ColumnId, LogicalType}; +use ahash::RandomState; +use itertools::Itertools; +use std::cmp::Ordering; +use std::collections::{Bound, HashSet}; +use std::mem; +use std::sync::Arc; #[derive(Debug, PartialEq, Clone)] pub enum ConstantBinary { Scope { min: Bound, - max: Bound + max: Bound, }, Eq(ValueRef), NotEq(ValueRef), @@ -23,7 +23,7 @@ pub enum ConstantBinary { // ConstantBinary in And can only be Scope\Eq\NotEq And(Vec), // ConstantBinary in Or can only be Scope\Eq\NotEq\And - Or(Vec) + Or(Vec), } impl ConstantBinary { @@ -43,7 +43,7 @@ impl ConstantBinary { } Ok(matches!((min, max), (Bound::Unbounded, Bound::Unbounded))) - }, + } ConstantBinary::Eq(val) | ConstantBinary::NotEq(val) => Ok(val.is_null()), _ => Err(TypeError::InvalidType), } @@ -60,35 +60,35 @@ impl ConstantBinary { ConstantBinary::And(mut and_binaries) => { condition_binaries.append(&mut and_binaries); } - ConstantBinary::Scope { min: Bound::Unbounded, max: Bound::Unbounded } => (), + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded, + } => (), source => condition_binaries.push(source), } } // Sort condition_binaries.sort_by(|a, b| { - let op = |binary: &ConstantBinary| { - match binary { - ConstantBinary::Scope { min, .. } => min.clone(), - ConstantBinary::Eq(val) => Bound::Included(val.clone()), - ConstantBinary::NotEq(val) => Bound::Excluded(val.clone()), - _ => unreachable!() - } + let op = |binary: &ConstantBinary| match binary { + ConstantBinary::Scope { min, .. } => min.clone(), + ConstantBinary::Eq(val) => Bound::Included(val.clone()), + ConstantBinary::NotEq(val) => Bound::Excluded(val.clone()), + _ => unreachable!(), }; - Self::bound_compared(&op(a), &op(b), true) - .unwrap_or(Ordering::Equal) + Self::bound_compared(&op(a), &op(b), true).unwrap_or(Ordering::Equal) }); let mut merged_binaries: Vec = Vec::new(); for condition in condition_binaries { - let op = |binary: &ConstantBinary| { - match binary { - ConstantBinary::Scope { min, max } => (min.clone(), max.clone()), - ConstantBinary::Eq(val) => (Bound::Unbounded, Bound::Included(val.clone())), - ConstantBinary::NotEq(val) => (Bound::Unbounded, Bound::Excluded(val.clone())), - _ => unreachable!() + let op = |binary: &ConstantBinary| match binary { + ConstantBinary::Scope { min, max } => (min.clone(), max.clone()), + ConstantBinary::Eq(val) => (Bound::Unbounded, Bound::Included(val.clone())), + ConstantBinary::NotEq(val) => { + (Bound::Unbounded, Bound::Excluded(val.clone())) } + _ => unreachable!(), }; let mut is_push = merged_binaries.is_empty(); @@ -104,13 +104,13 @@ impl ConstantBinary { if !is_lt_min && is_lt_max { let _ = mem::replace(max, condition_max); - } else if !matches!(condition, ConstantBinary::Scope {..}) { + } else if !matches!(condition, ConstantBinary::Scope { .. }) { is_push = is_lt_max; } else if is_lt_min && is_lt_max { is_push = true } - break + break; } } @@ -120,7 +120,7 @@ impl ConstantBinary { } Ok(merged_binaries) - }, + } ConstantBinary::And(binaries) => Ok(binaries), source => Ok(vec![source]), } @@ -133,13 +133,17 @@ impl ConstantBinary { binary.scope_aggregation()? } } - binary => binary._scope_aggregation()? + binary => binary._scope_aggregation()?, } Ok(()) } - fn bound_compared(left_bound: &Bound, right_bound: &Bound, is_min: bool) -> Option { + fn bound_compared( + left_bound: &Bound, + right_bound: &Bound, + is_min: bool, + ) -> Option { let op = |is_min, order: Ordering| { if is_min { order @@ -153,15 +157,13 @@ impl ConstantBinary { (Bound::Unbounded, _) => Some(op(is_min, Ordering::Less)), (_, Bound::Unbounded) => Some(op(is_min, Ordering::Greater)), (Bound::Included(left), Bound::Included(right)) => left.partial_cmp(right), - (Bound::Included(left), Bound::Excluded(right)) => { - left.partial_cmp(right) - .map(|order| order.then(op(is_min, Ordering::Less))) - }, + (Bound::Included(left), Bound::Excluded(right)) => left + .partial_cmp(right) + .map(|order| order.then(op(is_min, Ordering::Less))), (Bound::Excluded(left), Bound::Excluded(right)) => left.partial_cmp(right), - (Bound::Excluded(left), Bound::Included(right)) => { - left.partial_cmp(right) - .map(|order| order.then(op(is_min, Ordering::Greater))) - }, + (Bound::Excluded(left), Bound::Included(right)) => left + .partial_cmp(right) + .map(|order| order.then(op(is_min, Ordering::Greater))), } } @@ -172,13 +174,11 @@ impl ConstantBinary { let mut scope_max = Bound::Unbounded; let mut eq_set = HashSet::with_hasher(RandomState::new()); - let sort_op = |binary: &&ConstantBinary| { - match binary { - ConstantBinary::Scope { .. } => 3, - ConstantBinary::NotEq(_) => 2, - ConstantBinary::Eq(_) => 1, - ConstantBinary::And(_) | ConstantBinary::Or(_) => 0 - } + let sort_op = |binary: &&ConstantBinary| match binary { + ConstantBinary::Scope { .. } => 3, + ConstantBinary::NotEq(_) => 2, + ConstantBinary::Eq(_) => 1, + ConstantBinary::And(_) | ConstantBinary::Or(_) => 0, }; // Aggregate various ranges to get the minimum range @@ -186,7 +186,9 @@ impl ConstantBinary { match binary { ConstantBinary::Scope { min, max } => { // Skip if eq or noteq exists - if !eq_set.is_empty() { continue } + if !eq_set.is_empty() { + continue; + } if let Some(order) = Self::bound_compared(&scope_min, &min, true) { if order.is_lt() { @@ -202,22 +204,28 @@ impl ConstantBinary { } ConstantBinary::Eq(val) => { let _ = eq_set.insert(val.clone()); - }, + } ConstantBinary::NotEq(val) => { let _ = eq_set.remove(val); - }, - ConstantBinary::Or(_) | ConstantBinary::And(_) => return Err(TypeError::InvalidType) + } + ConstantBinary::Or(_) | ConstantBinary::And(_) => { + return Err(TypeError::InvalidType) + } } } - let eq_option = eq_set.into_iter() + let eq_option = eq_set + .into_iter() .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) .next() .map(|val| ConstantBinary::Eq(val)); if let Some(eq) = eq_option { let _ = mem::replace(self, eq); - } else if !matches!((&scope_min, &scope_max), (Bound::Unbounded, Bound::Unbounded)) { + } else if !matches!( + (&scope_min, &scope_max), + (Bound::Unbounded, Bound::Unbounded) + ) { let scope_binary = ConstantBinary::Scope { min: scope_min, max: scope_max, @@ -245,7 +253,7 @@ struct ReplaceBinary { val_expr: ScalarExpression, op: BinaryOperator, ty: LogicalType, - is_column_left: bool + is_column_left: bool, } #[derive(Debug)] @@ -263,10 +271,12 @@ impl ScalarExpression { ScalarExpression::TypeCast { expr, .. } => expr.exist_column(col_id), ScalarExpression::IsNull { expr } => expr.exist_column(col_id), ScalarExpression::Unary { expr, .. } => expr.exist_column(col_id), - ScalarExpression::Binary { left_expr, right_expr, .. } => { - left_expr.exist_column(col_id) || right_expr.exist_column(col_id) - } - _ => false + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => left_expr.exist_column(col_id) || right_expr.exist_column(col_id), + _ => false, } } @@ -274,32 +284,31 @@ impl ScalarExpression { match self { ScalarExpression::Constant(val) => Some(val.clone()), ScalarExpression::Alias { expr, .. } => expr.unpack_val(), - ScalarExpression::TypeCast { expr, ty, .. } => { - expr.unpack_val() - .and_then(|val| DataValue::clone(&val) - .cast(ty).ok() - .map(Arc::new)) - } + ScalarExpression::TypeCast { expr, ty, .. } => expr + .unpack_val() + .and_then(|val| DataValue::clone(&val).cast(ty).ok().map(Arc::new)), ScalarExpression::IsNull { expr } => { let is_null = expr.unpack_val().map(|val| val.is_null()); Some(Arc::new(DataValue::Boolean(is_null))) - }, + } ScalarExpression::Unary { expr, op, .. } => { let val = expr.unpack_val()?; - unary_op(&val, op).ok() - .map(Arc::new) - + unary_op(&val, op).ok().map(Arc::new) } - ScalarExpression::Binary { left_expr, right_expr, op, .. } => { + ScalarExpression::Binary { + left_expr, + right_expr, + op, + .. + } => { let left = left_expr.unpack_val()?; let right = right_expr.unpack_val()?; - binary_op(&left, &right, op).ok() - .map(Arc::new) + binary_op(&left, &right, op).ok().map(Arc::new) } - _ => None + _ => None, } } @@ -308,15 +317,20 @@ impl ScalarExpression { ScalarExpression::ColumnRef(col) => Some(col.clone()), ScalarExpression::Alias { expr, .. } => expr.unpack_col(is_deep), ScalarExpression::Unary { expr, .. } => expr.unpack_col(is_deep), - ScalarExpression::Binary { left_expr, right_expr, .. } => { + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => { if !is_deep { return None; } - left_expr.unpack_col(true) + left_expr + .unpack_col(true) .or_else(|| right_expr.unpack_col(true)) } - _ => None + _ => None, } } @@ -327,7 +341,12 @@ impl ScalarExpression { // Tips: Indirect expressions like `ScalarExpression::Alias` will be lost fn _simplify(&mut self, replaces: &mut Vec) -> Result<(), TypeError> { match self { - ScalarExpression::Binary { left_expr, right_expr, op, ty } => { + ScalarExpression::Binary { + left_expr, + right_expr, + op, + ty, + } => { Self::fix_expr(replaces, left_expr, right_expr, op)?; // `(c1 - 1) and (c1 + 2)` cannot fix! @@ -336,7 +355,7 @@ impl ScalarExpression { if Self::is_arithmetic(op) { match (left_expr.unpack_col(false), right_expr.unpack_col(false)) { (Some(col), None) => { - replaces.push(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), val_expr: right_expr.as_ref().clone(), op: *op, @@ -345,7 +364,7 @@ impl ScalarExpression { })); } (None, Some(col)) => { - replaces.push(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), val_expr: left_expr.as_ref().clone(), op: *op, @@ -355,12 +374,12 @@ impl ScalarExpression { } (None, None) => { if replaces.is_empty() { - return Ok(()); + return Ok(()); } match (left_expr.unpack_col(true), right_expr.unpack_col(true)) { (Some(col), None) => { - replaces.push(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), val_expr: right_expr.as_ref().clone(), op: *op, @@ -369,7 +388,7 @@ impl ScalarExpression { })); } (None, Some(col)) => { - replaces.push(Replace::Binary(ReplaceBinary{ + replaces.push(Replace::Binary(ReplaceBinary { column_expr: ScalarExpression::ColumnRef(col), val_expr: left_expr.as_ref().clone(), op: *op, @@ -380,7 +399,7 @@ impl ScalarExpression { _ => (), } } - _ => () + _ => (), } } } @@ -389,41 +408,43 @@ impl ScalarExpression { if let Some(val) = expr.unpack_val() { let _ = mem::replace(self, ScalarExpression::Constant(val)); } - }, + } ScalarExpression::IsNull { expr, .. } => { if let Some(val) = expr.unpack_val() { - let _ = mem::replace(self, ScalarExpression::Constant( - Arc::new(DataValue::Boolean(Some(val.is_null()))) - )); + let _ = mem::replace( + self, + ScalarExpression::Constant(Arc::new(DataValue::Boolean(Some( + val.is_null(), + )))), + ); } - }, + } ScalarExpression::Unary { expr, op, ty } => { if let Some(val) = expr.unpack_val() { - let new_expr = ScalarExpression::Constant( - Arc::new(unary_op(&val, op)?) - ); + let new_expr = ScalarExpression::Constant(Arc::new(unary_op(&val, op)?)); let _ = mem::replace(self, new_expr); } else { - let _ = replaces.push(Replace::Unary( - ReplaceUnary { - child_expr: expr.as_ref().clone(), - op: *op, - ty: *ty, - } - )); + let _ = replaces.push(Replace::Unary(ReplaceUnary { + child_expr: expr.as_ref().clone(), + op: *op, + ty: *ty, + })); } - }, - _ => () + } + _ => (), } Ok(()) } fn is_arithmetic(op: &mut BinaryOperator) -> bool { - matches!(op, BinaryOperator::Plus - | BinaryOperator::Divide - | BinaryOperator::Minus - | BinaryOperator::Multiply) + matches!( + op, + BinaryOperator::Plus + | BinaryOperator::Divide + | BinaryOperator::Minus + | BinaryOperator::Multiply + ) } fn fix_expr( @@ -440,13 +461,11 @@ impl ScalarExpression { while let Some(replace) = replaces.pop() { match replace { - Replace::Binary(binary) => { - Self::fix_binary(binary, left_expr, right_expr, op) - }, + Replace::Binary(binary) => Self::fix_binary(binary, left_expr, right_expr, op), Replace::Unary(unary) => { Self::fix_unary(unary, left_expr, right_expr, op); Self::fix_expr(replaces, left_expr, right_expr, op)?; - }, + } } } @@ -457,19 +476,27 @@ impl ScalarExpression { replace_unary: ReplaceUnary, col_expr: &mut Box, val_expr: &mut Box, - op: &mut BinaryOperator + op: &mut BinaryOperator, ) { - let ReplaceUnary { child_expr, op: fix_op, ty: fix_ty } = replace_unary; - let _ = mem::replace(col_expr, Box::new(child_expr)); - let _ = mem::replace(val_expr, Box::new(ScalarExpression::Unary { + let ReplaceUnary { + child_expr, op: fix_op, - expr: val_expr.clone(), ty: fix_ty, - })); - let _ = mem::replace(op, match fix_op { - UnaryOperator::Plus => *op, - UnaryOperator::Minus => { - match *op { + } = replace_unary; + let _ = mem::replace(col_expr, Box::new(child_expr)); + let _ = mem::replace( + val_expr, + Box::new(ScalarExpression::Unary { + op: fix_op, + expr: val_expr.clone(), + ty: fix_ty, + }), + ); + let _ = mem::replace( + op, + match fix_op { + UnaryOperator::Plus => *op, + UnaryOperator::Minus => match *op { BinaryOperator::Plus => BinaryOperator::Minus, BinaryOperator::Minus => BinaryOperator::Plus, BinaryOperator::Multiply => BinaryOperator::Divide, @@ -478,45 +505,45 @@ impl ScalarExpression { BinaryOperator::Lt => BinaryOperator::Gt, BinaryOperator::GtEq => BinaryOperator::LtEq, BinaryOperator::LtEq => BinaryOperator::GtEq, - source_op => source_op - } - } - UnaryOperator::Not => { - match *op { + source_op => source_op, + }, + UnaryOperator::Not => match *op { BinaryOperator::Gt => BinaryOperator::Lt, BinaryOperator::Lt => BinaryOperator::Gt, BinaryOperator::GtEq => BinaryOperator::LtEq, BinaryOperator::LtEq => BinaryOperator::GtEq, - source_op => source_op - } - } - }); + source_op => source_op, + }, + }, + ); } fn fix_binary( replace_binary: ReplaceBinary, left_expr: &mut Box, right_expr: &mut Box, - op: &mut BinaryOperator + op: &mut BinaryOperator, ) { - let ReplaceBinary { column_expr, val_expr, op: fix_op, ty: fix_ty, is_column_left } = replace_binary; - let op_flip = |op: BinaryOperator| { - match op { - BinaryOperator::Plus => BinaryOperator::Minus, - BinaryOperator::Minus => BinaryOperator::Plus, - BinaryOperator::Multiply => BinaryOperator::Divide, - BinaryOperator::Divide => BinaryOperator::Multiply, - _ => unreachable!() - } + let ReplaceBinary { + column_expr, + val_expr, + op: fix_op, + ty: fix_ty, + is_column_left, + } = replace_binary; + let op_flip = |op: BinaryOperator| match op { + BinaryOperator::Plus => BinaryOperator::Minus, + BinaryOperator::Minus => BinaryOperator::Plus, + BinaryOperator::Multiply => BinaryOperator::Divide, + BinaryOperator::Divide => BinaryOperator::Multiply, + _ => unreachable!(), }; - let comparison_flip = |op: BinaryOperator| { - match op { - BinaryOperator::Gt => BinaryOperator::Lt, - BinaryOperator::GtEq => BinaryOperator::LtEq, - BinaryOperator::Lt => BinaryOperator::Gt, - BinaryOperator::LtEq => BinaryOperator::GtEq, - source_op => source_op - } + let comparison_flip = |op: BinaryOperator| match op { + BinaryOperator::Gt => BinaryOperator::Lt, + BinaryOperator::GtEq => BinaryOperator::LtEq, + BinaryOperator::Lt => BinaryOperator::Gt, + BinaryOperator::LtEq => BinaryOperator::GtEq, + source_op => source_op, }; let (fixed_op, fixed_left_expr, fixed_right_expr) = if is_column_left { (op_flip(fix_op), right_expr.clone(), Box::new(val_expr)) @@ -528,12 +555,15 @@ impl ScalarExpression { }; let _ = mem::replace(left_expr, Box::new(column_expr)); - let _ = mem::replace(right_expr, Box::new(ScalarExpression::Binary { - op: fixed_op, - left_expr: fixed_left_expr, - right_expr: fixed_right_expr, - ty: fix_ty, - })); + let _ = mem::replace( + right_expr, + Box::new(ScalarExpression::Binary { + op: fixed_op, + left_expr: fixed_left_expr, + right_expr: fixed_right_expr, + ty: fix_ty, + }), + ); } /// The definition of Or is not the Or in the Where condition. @@ -542,51 +572,51 @@ impl ScalarExpression { /// - `ConstantBinary::Or`: Rearrange and sort the range of each OR data pub fn convert_binary(&self, col_id: &ColumnId) -> Result, TypeError> { match self { - ScalarExpression::Binary { left_expr, right_expr, op, .. } => { - match (left_expr.convert_binary(col_id)?, right_expr.convert_binary(col_id)?) { - (Some(left_binary), Some(right_binary)) => { - match (left_binary, right_binary) { - (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) - | (ConstantBinary::Or(mut left), ConstantBinary::Or(mut right)) => { - left.append(&mut right); - - Ok(Some(ConstantBinary::And(left))) - } - (ConstantBinary::And(mut left), ConstantBinary::Or(mut right)) => { - right.append(&mut left); + ScalarExpression::Binary { + left_expr, + right_expr, + op, + .. + } => { + match ( + left_expr.convert_binary(col_id)?, + right_expr.convert_binary(col_id)?, + ) { + (Some(left_binary), Some(right_binary)) => match (left_binary, right_binary) { + (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) + | (ConstantBinary::Or(mut left), ConstantBinary::Or(mut right)) => { + left.append(&mut right); + + Ok(Some(ConstantBinary::And(left))) + } + (ConstantBinary::And(mut left), ConstantBinary::Or(mut right)) => { + right.append(&mut left); - Ok(Some(ConstantBinary::Or(right))) - } - (ConstantBinary::Or(mut left), ConstantBinary::And(mut right)) => { - left.append(&mut right); + Ok(Some(ConstantBinary::Or(right))) + } + (ConstantBinary::Or(mut left), ConstantBinary::And(mut right)) => { + left.append(&mut right); - Ok(Some(ConstantBinary::Or(left))) - } - (ConstantBinary::And(mut binaries), binary) - | (binary, ConstantBinary::And(mut binaries)) => { - binaries.push(binary); + Ok(Some(ConstantBinary::Or(left))) + } + (ConstantBinary::And(mut binaries), binary) + | (binary, ConstantBinary::And(mut binaries)) => { + binaries.push(binary); - Ok(Some(ConstantBinary::And(binaries))) - } - (ConstantBinary::Or(mut binaries), binary) - | (binary, ConstantBinary::Or(mut binaries)) => { - binaries.push(binary); + Ok(Some(ConstantBinary::And(binaries))) + } + (ConstantBinary::Or(mut binaries), binary) + | (binary, ConstantBinary::Or(mut binaries)) => { + binaries.push(binary); - Ok(Some(ConstantBinary::Or(binaries))) - } - (left, right) => { - match op { - BinaryOperator::And => { - Ok(Some(ConstantBinary::And(vec![left, right]))) - } - BinaryOperator::Or => { - Ok(Some(ConstantBinary::Or(vec![left, right]))) - } - BinaryOperator::Xor => todo!(), - _ => Ok(None) - } - } + Ok(Some(ConstantBinary::Or(binaries))) } + (left, right) => match op { + BinaryOperator::And => Ok(Some(ConstantBinary::And(vec![left, right]))), + BinaryOperator::Or => Ok(Some(ConstantBinary::Or(vec![left, right]))), + BinaryOperator::Xor => todo!(), + _ => Ok(None), + }, }, (None, None) => { if let (Some(col), Some(val)) = @@ -605,7 +635,7 @@ impl ScalarExpression { (Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)), (None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)), } - }, + } ScalarExpression::Alias { expr, .. } => expr.convert_binary(col_id), ScalarExpression::TypeCast { expr, .. } => expr.convert_binary(col_id), ScalarExpression::IsNull { expr } => expr.convert_binary(col_id), @@ -620,16 +650,22 @@ impl ScalarExpression { col_id: &ColumnId, right_expr: &Box, op: &BinaryOperator, - binary: ConstantBinary + binary: ConstantBinary, ) -> Option { if matches!(op, BinaryOperator::Or) && right_expr.exist_column(col_id) { - return None + return None; } Some(binary) } - fn new_binary(col_id: &ColumnId, mut op: BinaryOperator, col: ColumnRef, val: ValueRef, is_flip: bool) -> Option { + fn new_binary( + col_id: &ColumnId, + mut op: BinaryOperator, + col: ColumnRef, + val: ValueRef, + is_flip: bool, + ) -> Option { if col.id.unwrap() != *col_id { return None; } @@ -640,56 +676,44 @@ impl ScalarExpression { BinaryOperator::Lt => BinaryOperator::Gt, BinaryOperator::GtEq => BinaryOperator::LtEq, BinaryOperator::LtEq => BinaryOperator::GtEq, - source_op => source_op + source_op => source_op, }; } match op { - BinaryOperator::Gt => { - Some(ConstantBinary::Scope { - min: Bound::Excluded(val.clone()), - max: Bound::Unbounded - }) - } - BinaryOperator::Lt => { - Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(val.clone()), - }) - } - BinaryOperator::GtEq => { - Some(ConstantBinary::Scope { - min: Bound::Included(val.clone()), - max: Bound::Unbounded - }) - } - BinaryOperator::LtEq => { - Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Included(val.clone()), - }) - } - BinaryOperator::Eq | BinaryOperator::Spaceship => { - Some(ConstantBinary::Eq(val.clone())) - }, - BinaryOperator::NotEq => { - Some(ConstantBinary::NotEq(val.clone())) - }, - _ => None + BinaryOperator::Gt => Some(ConstantBinary::Scope { + min: Bound::Excluded(val.clone()), + max: Bound::Unbounded, + }), + BinaryOperator::Lt => Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val.clone()), + }), + BinaryOperator::GtEq => Some(ConstantBinary::Scope { + min: Bound::Included(val.clone()), + max: Bound::Unbounded, + }), + BinaryOperator::LtEq => Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val.clone()), + }), + BinaryOperator::Eq | BinaryOperator::Spaceship => Some(ConstantBinary::Eq(val.clone())), + BinaryOperator::NotEq => Some(ConstantBinary::NotEq(val.clone())), + _ => None, } } } #[cfg(test)] mod test { - use std::collections::Bound; - use std::sync::Arc; use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::expression::{BinaryOperator, ScalarExpression}; use crate::expression::simplify::ConstantBinary; + use crate::expression::{BinaryOperator, ScalarExpression}; use crate::types::errors::TypeError; - use crate::types::LogicalType; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::collections::Bound; + use std::sync::Arc; #[test] fn test_convert_binary_simple() -> Result<(), TypeError> { @@ -712,7 +736,9 @@ mod test { left_expr: Box::new(ScalarExpression::Constant(val_1.clone())), right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); assert_eq!(binary_eq, ConstantBinary::Eq(val_1.clone())); @@ -721,7 +747,9 @@ mod test { left_expr: Box::new(ScalarExpression::Constant(val_1.clone())), right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); assert_eq!(binary_not_eq, ConstantBinary::NotEq(val_1.clone())); @@ -730,48 +758,68 @@ mod test { left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); - assert_eq!(binary_lt, ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(val_1.clone()) - }); + assert_eq!( + binary_lt, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_1.clone()) + } + ); let binary_lteq = ScalarExpression::Binary { op: BinaryOperator::LtEq, left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); - assert_eq!(binary_lteq, ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Included(val_1.clone()) - }); + assert_eq!( + binary_lteq, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_1.clone()) + } + ); let binary_gt = ScalarExpression::Binary { op: BinaryOperator::Gt, left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); - assert_eq!(binary_gt, ConstantBinary::Scope { - min: Bound::Excluded(val_1.clone()), - max: Bound::Unbounded - }); + assert_eq!( + binary_gt, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Unbounded + } + ); let binary_gteq = ScalarExpression::Binary { op: BinaryOperator::GtEq, left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), ty: LogicalType::Boolean, - }.convert_binary(&0)?.unwrap(); + } + .convert_binary(&0)? + .unwrap(); - assert_eq!(binary_gteq, ConstantBinary::Scope { - min: Bound::Included(val_1.clone()), - max: Bound::Unbounded - }); + assert_eq!( + binary_gteq, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Unbounded + } + ); Ok(()) } @@ -792,10 +840,7 @@ mod test { binary.scope_aggregation()?; - assert_eq!( - binary, - ConstantBinary::Eq(val_0) - ); + assert_eq!(binary, ConstantBinary::Eq(val_0)); Ok(()) } @@ -812,7 +857,6 @@ mod test { ConstantBinary::NotEq(val_1.clone()), ConstantBinary::Eq(val_2.clone()), ConstantBinary::NotEq(val_3.clone()), - ConstantBinary::NotEq(val_0.clone()), ConstantBinary::NotEq(val_1.clone()), ConstantBinary::NotEq(val_2.clone()), @@ -821,10 +865,7 @@ mod test { binary.scope_aggregation()?; - assert_eq!( - binary, - ConstantBinary::And(vec![]) - ); + assert_eq!(binary, ConstantBinary::And(vec![])); Ok(()) } @@ -839,19 +880,19 @@ mod test { let mut binary = ConstantBinary::And(vec![ ConstantBinary::Scope { min: Bound::Excluded(val_0.clone()), - max: Bound::Included(val_3.clone()) + max: Bound::Included(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_1.clone()), - max: Bound::Excluded(val_2.clone()) + max: Bound::Excluded(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Excluded(val_1.clone()), - max: Bound::Included(val_2.clone()) + max: Bound::Included(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_0.clone()), - max: Bound::Excluded(val_3.clone()) + max: Bound::Excluded(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Unbounded, @@ -882,19 +923,19 @@ mod test { let mut binary = ConstantBinary::And(vec![ ConstantBinary::Scope { min: Bound::Excluded(val_0.clone()), - max: Bound::Included(val_3.clone()) + max: Bound::Included(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_1.clone()), - max: Bound::Excluded(val_2.clone()) + max: Bound::Excluded(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Excluded(val_1.clone()), - max: Bound::Included(val_2.clone()) + max: Bound::Included(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_0.clone()), - max: Bound::Excluded(val_3.clone()) + max: Bound::Excluded(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Unbounded, @@ -907,10 +948,7 @@ mod test { binary.scope_aggregation()?; - assert_eq!( - binary, - ConstantBinary::Eq(val_0.clone()) - ); + assert_eq!(binary, ConstantBinary::Eq(val_0.clone())); Ok(()) } @@ -933,27 +971,27 @@ mod test { let binary = ConstantBinary::Or(vec![ ConstantBinary::Scope { min: Bound::Excluded(val_6.clone()), - max: Bound::Included(val_10.clone()) + max: Bound::Included(val_10.clone()), }, ConstantBinary::Scope { min: Bound::Excluded(val_0.clone()), - max: Bound::Included(val_3.clone()) + max: Bound::Included(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_1.clone()), - max: Bound::Excluded(val_2.clone()) + max: Bound::Excluded(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Excluded(val_1.clone()), - max: Bound::Included(val_2.clone()) + max: Bound::Included(val_2.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_0.clone()), - max: Bound::Excluded(val_3.clone()) + max: Bound::Excluded(val_3.clone()), }, ConstantBinary::Scope { min: Bound::Included(val_6.clone()), - max: Bound::Included(val_7.clone()) + max: Bound::Included(val_7.clone()), }, ConstantBinary::Scope { min: Bound::Unbounded, @@ -982,4 +1020,4 @@ mod test { Ok(()) } -} \ No newline at end of file +} diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index 3a2f6dfc..fe4efdc1 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -1,75 +1,72 @@ use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::errors::TypeError; -use crate::types::LogicalType; use crate::types::value::DataValue; +use crate::types::LogicalType; fn unpack_i32(value: DataValue) -> Option { match value { DataValue::Int32(inner) => inner, - _ => None + _ => None, } } fn unpack_i64(value: DataValue) -> Option { match value { DataValue::Int64(inner) => inner, - _ => None + _ => None, } } fn unpack_u32(value: DataValue) -> Option { match value { DataValue::UInt32(inner) => inner, - _ => None + _ => None, } } fn unpack_u64(value: DataValue) -> Option { match value { DataValue::UInt64(inner) => inner, - _ => None + _ => None, } } fn unpack_f64(value: DataValue) -> Option { match value { DataValue::Float64(inner) => inner, - _ => None + _ => None, } } fn unpack_f32(value: DataValue) -> Option { match value { DataValue::Float32(inner) => inner, - _ => None + _ => None, } } fn unpack_bool(value: DataValue) -> Option { match value { DataValue::Boolean(inner) => inner, - _ => None + _ => None, } } fn unpack_date(value: DataValue) -> Option { match value { DataValue::Date64(inner) => inner, - _ => None + _ => None, } } fn unpack_utf8(value: DataValue) -> Option { match value { DataValue::Utf8(inner) => inner, - _ => None + _ => None, } } -pub fn unary_op( - value: &DataValue, - op: &UnaryOperator, -) -> Result { +pub fn unary_op(value: &DataValue, op: &UnaryOperator) -> Result { let mut value_type = value.logical_type(); let mut value = value.clone(); @@ -80,32 +77,30 @@ pub fn unary_op( LogicalType::USmallint => value_type = LogicalType::Smallint, LogicalType::UInteger => value_type = LogicalType::Integer, LogicalType::UBigint => value_type = LogicalType::Bigint, - _ => unreachable!() + _ => unreachable!(), }; value = value.cast(&value_type)?; } let result = match op { UnaryOperator::Plus => value, - UnaryOperator::Minus => { - match value { - DataValue::Float32(option) => DataValue::Float32(option.map(|v| -v)), - DataValue::Float64(option) => DataValue::Float64(option.map(|v| -v)), - DataValue::Int8(option) => DataValue::Int8(option.map(|v| -v)), - DataValue::Int16(option) => DataValue::Int16(option.map(|v| -v)), - DataValue::Int32(option) => DataValue::Int32(option.map(|v| -v)), - DataValue::Int64(option) => DataValue::Int64(option.map(|v| -v)), - _ => unreachable!() - } - } - _ => unreachable!() + UnaryOperator::Minus => match value { + DataValue::Float32(option) => DataValue::Float32(option.map(|v| -v)), + DataValue::Float64(option) => DataValue::Float64(option.map(|v| -v)), + DataValue::Int8(option) => DataValue::Int8(option.map(|v| -v)), + DataValue::Int16(option) => DataValue::Int16(option.map(|v| -v)), + DataValue::Int32(option) => DataValue::Int32(option.map(|v| -v)), + DataValue::Int64(option) => DataValue::Int64(option.map(|v| -v)), + _ => unreachable!(), + }, + _ => unreachable!(), }; Ok(result) } else if matches!((value_type, op), (LogicalType::Boolean, UnaryOperator::Not)) { match value { DataValue::Boolean(option) => Ok(DataValue::Boolean(option.map(|v| !v))), - _ => unreachable!() + _ => unreachable!(), } } else { Err(TypeError::InvalidType) @@ -119,10 +114,7 @@ pub fn binary_op( right: &DataValue, op: &BinaryOperator, ) -> Result { - let unified_type = LogicalType::max_logical_type( - &left.logical_type(), - &right.logical_type(), - )?; + let unified_type = LogicalType::max_logical_type(&left.logical_type(), &right.logical_type())?; let value = match &unified_type { LogicalType::Integer => { @@ -205,15 +197,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -227,7 +213,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::Bigint => { @@ -310,15 +296,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -332,7 +312,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::UInteger => { @@ -415,15 +395,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -437,7 +411,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::UBigint => { @@ -520,15 +494,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -542,7 +510,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::Double => { @@ -625,15 +593,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -647,7 +609,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::Boolean => { @@ -673,7 +635,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::Float => { @@ -755,15 +717,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -777,12 +733,10 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } - LogicalType::SqlNull => { - DataValue::Boolean(None) - } + LogicalType::SqlNull => DataValue::Boolean(None), LogicalType::DateTime => { let left_value = unpack_date(left.clone().cast(&unified_type)?); let right_value = unpack_date(right.clone().cast(&unified_type)?); @@ -826,15 +780,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -848,7 +796,7 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } LogicalType::Varchar(None) => { @@ -894,15 +842,9 @@ pub fn binary_op( } BinaryOperator::Eq => { let value = match (left_value, right_value) { - (Some(v1), Some(v2)) => { - Some(v1 == v2) - } - (None, None) => { - Some(true) - } - (_, _) => { - None - } + (Some(v1), Some(v2)) => Some(v1 == v2), + (None, None) => Some(true), + (_, _) => None, }; DataValue::Boolean(value) @@ -916,11 +858,10 @@ pub fn binary_op( DataValue::Boolean(value) } - _ => todo!("unsupported operator") + _ => todo!("unsupported operator"), } } // Utf8 - _ => todo!("unsupported data type"), }; @@ -936,28 +877,76 @@ mod test { #[test] fn test_binary_op_arithmetic_plus() -> Result<(), TypeError> { - let plus_i32_1 = binary_op(&DataValue::Int32(None), &DataValue::Int32(None), &BinaryOperator::Plus)?; - let plus_i32_2 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(None), &BinaryOperator::Plus)?; - let plus_i32_3 = binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::Plus)?; - let plus_i32_4 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Plus)?; + let plus_i32_1 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(None), + &BinaryOperator::Plus, + )?; + let plus_i32_2 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(None), + &BinaryOperator::Plus, + )?; + let plus_i32_3 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::Plus, + )?; + let plus_i32_4 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Plus, + )?; assert_eq!(plus_i32_1, plus_i32_2); assert_eq!(plus_i32_2, plus_i32_3); assert_eq!(plus_i32_4, DataValue::Int32(Some(2))); - let plus_i64_1 = binary_op(&DataValue::Int64(None), &DataValue::Int64(None), &BinaryOperator::Plus)?; - let plus_i64_2 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(None), &BinaryOperator::Plus)?; - let plus_i64_3 = binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::Plus)?; - let plus_i64_4 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::Plus)?; + let plus_i64_1 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(None), + &BinaryOperator::Plus, + )?; + let plus_i64_2 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(None), + &BinaryOperator::Plus, + )?; + let plus_i64_3 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::Plus, + )?; + let plus_i64_4 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::Plus, + )?; assert_eq!(plus_i64_1, plus_i64_2); assert_eq!(plus_i64_2, plus_i64_3); assert_eq!(plus_i64_4, DataValue::Int64(Some(2))); - let plus_f64_1 = binary_op(&DataValue::Float64(None), &DataValue::Float64(None), &BinaryOperator::Plus)?; - let plus_f64_2 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None), &BinaryOperator::Plus)?; - let plus_f64_3 = binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::Plus)?; - let plus_f64_4 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::Plus)?; + let plus_f64_1 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(None), + &BinaryOperator::Plus, + )?; + let plus_f64_2 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(None), + &BinaryOperator::Plus, + )?; + let plus_f64_3 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Plus, + )?; + let plus_f64_4 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Plus, + )?; assert_eq!(plus_f64_1, plus_f64_2); assert_eq!(plus_f64_2, plus_f64_3); @@ -968,28 +957,76 @@ mod test { #[test] fn test_binary_op_arithmetic_minus() -> Result<(), TypeError> { - let minus_i32_1 = binary_op(&DataValue::Int32(None), &DataValue::Int32(None), &BinaryOperator::Minus)?; - let minus_i32_2 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(None), &BinaryOperator::Minus)?; - let minus_i32_3 = binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::Minus)?; - let minus_i32_4 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Minus)?; + let minus_i32_1 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(None), + &BinaryOperator::Minus, + )?; + let minus_i32_2 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(None), + &BinaryOperator::Minus, + )?; + let minus_i32_3 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::Minus, + )?; + let minus_i32_4 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Minus, + )?; assert_eq!(minus_i32_1, minus_i32_2); assert_eq!(minus_i32_2, minus_i32_3); assert_eq!(minus_i32_4, DataValue::Int32(Some(0))); - let minus_i64_1 = binary_op(&DataValue::Int64(None), &DataValue::Int64(None), &BinaryOperator::Minus)?; - let minus_i64_2 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(None), &BinaryOperator::Minus)?; - let minus_i64_3 = binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::Minus)?; - let minus_i64_4 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::Minus)?; + let minus_i64_1 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(None), + &BinaryOperator::Minus, + )?; + let minus_i64_2 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(None), + &BinaryOperator::Minus, + )?; + let minus_i64_3 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::Minus, + )?; + let minus_i64_4 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::Minus, + )?; assert_eq!(minus_i64_1, minus_i64_2); assert_eq!(minus_i64_2, minus_i64_3); assert_eq!(minus_i64_4, DataValue::Int64(Some(0))); - let minus_f64_1 = binary_op(&DataValue::Float64(None), &DataValue::Float64(None), &BinaryOperator::Minus)?; - let minus_f64_2 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None), &BinaryOperator::Minus)?; - let minus_f64_3 = binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::Minus)?; - let minus_f64_4 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::Minus)?; + let minus_f64_1 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(None), + &BinaryOperator::Minus, + )?; + let minus_f64_2 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(None), + &BinaryOperator::Minus, + )?; + let minus_f64_3 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Minus, + )?; + let minus_f64_4 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Minus, + )?; assert_eq!(minus_f64_1, minus_f64_2); assert_eq!(minus_f64_2, minus_f64_3); @@ -1000,28 +1037,76 @@ mod test { #[test] fn test_binary_op_arithmetic_multiply() -> Result<(), TypeError> { - let multiply_i32_1 = binary_op(&DataValue::Int32(None), &DataValue::Int32(None), &BinaryOperator::Multiply)?; - let multiply_i32_2 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(None), &BinaryOperator::Multiply)?; - let multiply_i32_3 = binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::Multiply)?; - let multiply_i32_4 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Multiply)?; + let multiply_i32_1 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(None), + &BinaryOperator::Multiply, + )?; + let multiply_i32_2 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(None), + &BinaryOperator::Multiply, + )?; + let multiply_i32_3 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::Multiply, + )?; + let multiply_i32_4 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Multiply, + )?; assert_eq!(multiply_i32_1, multiply_i32_2); assert_eq!(multiply_i32_2, multiply_i32_3); assert_eq!(multiply_i32_4, DataValue::Int32(Some(1))); - let multiply_i64_1 = binary_op(&DataValue::Int64(None), &DataValue::Int64(None), &BinaryOperator::Multiply)?; - let multiply_i64_2 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(None), &BinaryOperator::Multiply)?; - let multiply_i64_3 = binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::Multiply)?; - let multiply_i64_4 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::Multiply)?; + let multiply_i64_1 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(None), + &BinaryOperator::Multiply, + )?; + let multiply_i64_2 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(None), + &BinaryOperator::Multiply, + )?; + let multiply_i64_3 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::Multiply, + )?; + let multiply_i64_4 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::Multiply, + )?; assert_eq!(multiply_i64_1, multiply_i64_2); assert_eq!(multiply_i64_2, multiply_i64_3); assert_eq!(multiply_i64_4, DataValue::Int64(Some(1))); - let multiply_f64_1 = binary_op(&DataValue::Float64(None), &DataValue::Float64(None), &BinaryOperator::Multiply)?; - let multiply_f64_2 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None), &BinaryOperator::Multiply)?; - let multiply_f64_3 = binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::Multiply)?; - let multiply_f64_4 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::Multiply)?; + let multiply_f64_1 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(None), + &BinaryOperator::Multiply, + )?; + let multiply_f64_2 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(None), + &BinaryOperator::Multiply, + )?; + let multiply_f64_3 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Multiply, + )?; + let multiply_f64_4 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Multiply, + )?; assert_eq!(multiply_f64_1, multiply_f64_2); assert_eq!(multiply_f64_2, multiply_f64_3); @@ -1032,28 +1117,76 @@ mod test { #[test] fn test_binary_op_arithmetic_divide() -> Result<(), TypeError> { - let divide_i32_1 = binary_op(&DataValue::Int32(None), &DataValue::Int32(None), &BinaryOperator::Divide)?; - let divide_i32_2 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(None), &BinaryOperator::Divide)?; - let divide_i32_3 = binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::Divide)?; - let divide_i32_4 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Divide)?; + let divide_i32_1 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(None), + &BinaryOperator::Divide, + )?; + let divide_i32_2 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(None), + &BinaryOperator::Divide, + )?; + let divide_i32_3 = binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::Divide, + )?; + let divide_i32_4 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Divide, + )?; assert_eq!(divide_i32_1, divide_i32_2); assert_eq!(divide_i32_2, divide_i32_3); assert_eq!(divide_i32_4, DataValue::Float64(Some(1.0))); - let divide_i64_1 = binary_op(&DataValue::Int64(None), &DataValue::Int64(None), &BinaryOperator::Divide)?; - let divide_i64_2 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(None), &BinaryOperator::Divide)?; - let divide_i64_3 = binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::Divide)?; - let divide_i64_4 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::Divide)?; + let divide_i64_1 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(None), + &BinaryOperator::Divide, + )?; + let divide_i64_2 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(None), + &BinaryOperator::Divide, + )?; + let divide_i64_3 = binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::Divide, + )?; + let divide_i64_4 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::Divide, + )?; assert_eq!(divide_i64_1, divide_i64_2); assert_eq!(divide_i64_2, divide_i64_3); assert_eq!(divide_i64_4, DataValue::Float64(Some(1.0))); - let divide_f64_1 = binary_op(&DataValue::Float64(None), &DataValue::Float64(None), &BinaryOperator::Divide)?; - let divide_f64_2 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(None), &BinaryOperator::Divide)?; - let divide_f64_3 = binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::Divide)?; - let divide_f64_4 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::Divide)?; + let divide_f64_1 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(None), + &BinaryOperator::Divide, + )?; + let divide_f64_2 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(None), + &BinaryOperator::Divide, + )?; + let divide_f64_3 = binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Divide, + )?; + let divide_f64_4 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Divide, + )?; assert_eq!(divide_f64_1, divide_f64_2); assert_eq!(divide_f64_2, divide_f64_3); @@ -1064,19 +1197,43 @@ mod test { #[test] fn test_binary_op_cast() -> Result<(), TypeError> { - let i32_cast_1 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int8(Some(1)), &BinaryOperator::Plus)?; - let i32_cast_2 = binary_op(&DataValue::Int32(Some(1)), &DataValue::Int16(Some(1)), &BinaryOperator::Plus)?; + let i32_cast_1 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int8(Some(1)), + &BinaryOperator::Plus, + )?; + let i32_cast_2 = binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int16(Some(1)), + &BinaryOperator::Plus, + )?; assert_eq!(i32_cast_1, i32_cast_2); - let i64_cast_1 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int8(Some(1)), &BinaryOperator::Plus)?; - let i64_cast_2 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int16(Some(1)), &BinaryOperator::Plus)?; - let i64_cast_3 = binary_op(&DataValue::Int64(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Plus)?; + let i64_cast_1 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int8(Some(1)), + &BinaryOperator::Plus, + )?; + let i64_cast_2 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int16(Some(1)), + &BinaryOperator::Plus, + )?; + let i64_cast_3 = binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Plus, + )?; assert_eq!(i64_cast_1, i64_cast_2); assert_eq!(i64_cast_2, i64_cast_3); - let f64_cast_1 = binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float32(Some(1.0)), &BinaryOperator::Plus)?; + let f64_cast_1 = binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::Plus, + )?; assert_eq!(f64_cast_1, DataValue::Float64(Some(2.0))); Ok(()) @@ -1084,119 +1241,616 @@ mod test { #[test] fn test_binary_op_i32_compare() -> Result<(), TypeError> { - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(0)), &BinaryOperator::Gt)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(0)), &BinaryOperator::Lt)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Int32(Some(1)), &DataValue::Int32(Some(1)), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); - - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(0)), &BinaryOperator::Gt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(0)), &BinaryOperator::Lt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::GtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::LtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::NotEq)?, DataValue::Boolean(None)); - - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(Some(1)), &BinaryOperator::Eq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int32(None), &DataValue::Int32(None), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Int32(Some(1)), + &DataValue::Int32(Some(1)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); + + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(None) + ); + + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(Some(1)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int32(None), + &DataValue::Int32(None), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); Ok(()) } #[test] fn test_binary_op_i64_compare() -> Result<(), TypeError> { - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(0)), &BinaryOperator::Gt)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(0)), &BinaryOperator::Lt)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Int64(Some(1)), &DataValue::Int64(Some(1)), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); - - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(0)), &BinaryOperator::Gt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(0)), &BinaryOperator::Lt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::GtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::LtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::NotEq)?, DataValue::Boolean(None)); - - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(Some(1)), &BinaryOperator::Eq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Int64(None), &DataValue::Int64(None), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Int64(Some(1)), + &DataValue::Int64(Some(1)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); + + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(None) + ); + + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(Some(1)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Int64(None), + &DataValue::Int64(None), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); Ok(()) } #[test] fn test_binary_op_f64_compare() -> Result<(), TypeError> { - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(0.0)), &BinaryOperator::Gt)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(0.0)), &BinaryOperator::Lt)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Float64(Some(1.0)), &DataValue::Float64(Some(1.0)), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); - - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(0.0)), &BinaryOperator::Gt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(0.0)), &BinaryOperator::Lt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::GtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::LtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::NotEq)?, DataValue::Boolean(None)); - - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(Some(1.0)), &BinaryOperator::Eq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float64(None), &DataValue::Float64(None), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(0.0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(0.0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Float64(Some(1.0)), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); + + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(0.0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(0.0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(None) + ); + + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(Some(1.0)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float64(None), + &DataValue::Float64(None), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); Ok(()) } #[test] fn test_binary_op_f32_compare() -> Result<(), TypeError> { - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(0.0)), &BinaryOperator::Gt)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(0.0)), &BinaryOperator::Lt)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(1.0)), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(1.0)), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(1.0)), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Float32(Some(1.0)), &DataValue::Float32(Some(1.0)), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); - - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(0.0)), &BinaryOperator::Gt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(0.0)), &BinaryOperator::Lt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(1.0)), &BinaryOperator::GtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(1.0)), &BinaryOperator::LtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(1.0)), &BinaryOperator::NotEq)?, DataValue::Boolean(None)); - - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(Some(1.0)), &BinaryOperator::Eq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Float32(None), &DataValue::Float32(None), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(0.0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(0.0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Float32(Some(1.0)), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); + + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(0.0)), + &BinaryOperator::Gt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(0.0)), + &BinaryOperator::Lt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(None) + ); + + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(Some(1.0)), + &BinaryOperator::Eq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Float32(None), + &DataValue::Float32(None), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); Ok(()) } #[test] fn test_binary_op_bool_compare() -> Result<(), TypeError> { - assert_eq!(binary_op(&DataValue::Boolean(Some(true)), &DataValue::Boolean(Some(true)), &BinaryOperator::And)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Boolean(Some(false)), &DataValue::Boolean(Some(true)), &BinaryOperator::And)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Boolean(Some(false)), &DataValue::Boolean(Some(false)), &BinaryOperator::And)?, DataValue::Boolean(Some(false))); - - assert_eq!(binary_op(&DataValue::Boolean(None), &DataValue::Boolean(Some(true)), &BinaryOperator::And)?, DataValue::Boolean(None)); - - assert_eq!(binary_op(&DataValue::Boolean(Some(true)), &DataValue::Boolean(Some(true)), &BinaryOperator::Or)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Boolean(Some(false)), &DataValue::Boolean(Some(true)), &BinaryOperator::Or)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Boolean(Some(false)), &DataValue::Boolean(Some(false)), &BinaryOperator::Or)?, DataValue::Boolean(Some(false))); - - assert_eq!(binary_op(&DataValue::Boolean(None), &DataValue::Boolean(Some(true)), &BinaryOperator::Or)?, DataValue::Boolean(None)); + assert_eq!( + binary_op( + &DataValue::Boolean(Some(true)), + &DataValue::Boolean(Some(true)), + &BinaryOperator::And + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(true)), + &BinaryOperator::And + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(false)), + &BinaryOperator::And + )?, + DataValue::Boolean(Some(false)) + ); + + assert_eq!( + binary_op( + &DataValue::Boolean(None), + &DataValue::Boolean(Some(true)), + &BinaryOperator::And + )?, + DataValue::Boolean(None) + ); + + assert_eq!( + binary_op( + &DataValue::Boolean(Some(true)), + &DataValue::Boolean(Some(true)), + &BinaryOperator::Or + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(true)), + &BinaryOperator::Or + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Boolean(Some(false)), + &DataValue::Boolean(Some(false)), + &BinaryOperator::Or + )?, + DataValue::Boolean(Some(false)) + ); + + assert_eq!( + binary_op( + &DataValue::Boolean(None), + &DataValue::Boolean(Some(true)), + &BinaryOperator::Or + )?, + DataValue::Boolean(None) + ); Ok(()) } - + #[test] - fn test_binary_op_utf8_compare()->Result<(),TypeError>{ - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Gt)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Lt)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::LtEq)?, DataValue::Boolean(Some(true))); - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::NotEq)?, DataValue::Boolean(Some(false))); - assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Eq)?, DataValue::Boolean(Some(true))); - - assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Gt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::Lt)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::GtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::LtEq)?, DataValue::Boolean(None)); - assert_eq!(binary_op(&DataValue::Utf8(None), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::NotEq)?, DataValue::Boolean(None)); + fn test_binary_op_utf8_compare() -> Result<(), TypeError> { + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("b".to_string())), + &BinaryOperator::Gt + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("b".to_string())), + &BinaryOperator::Lt + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(Some(true)) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(Some(false)) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(Some("a".to_string())), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::Eq + )?, + DataValue::Boolean(Some(true)) + ); + + assert_eq!( + binary_op( + &DataValue::Utf8(None), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::Gt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(None), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::Lt + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(None), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::GtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(None), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::LtEq + )?, + DataValue::Boolean(None) + ); + assert_eq!( + binary_op( + &DataValue::Utf8(None), + &DataValue::Utf8(Some("a".to_string())), + &BinaryOperator::NotEq + )?, + DataValue::Boolean(None) + ); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 8635388f..ae2e8dd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,14 +6,14 @@ #![feature(slice_pattern)] #![feature(bound_map)] extern crate core; - pub mod binder; pub mod catalog; pub mod db; +pub mod execution; pub mod expression; +pub mod marco; +mod optimizer; pub mod parser; pub mod planner; -pub mod types; -mod optimizer; -pub mod execution; pub mod storage; +pub mod types; diff --git a/src/main.rs b/src/main.rs index a62c3050..f35f4c9a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ async fn server_run() -> Result<(), Box> { if input.len() >= 4 && input.to_lowercase()[..4].eq("quit") { println!("{}", BLOOM); - break + break; } match db.run(&input).await { @@ -65,4 +65,4 @@ async fn server_run() -> Result<(), Box> { } Ok(()) -} \ No newline at end of file +} diff --git a/src/marco/mod.rs b/src/marco/mod.rs new file mode 100644 index 00000000..f9ca70e9 --- /dev/null +++ b/src/marco/mod.rs @@ -0,0 +1,121 @@ +/// # Examples +/// +/// ``` +///struct MyStruct { +/// c1: i32, +/// c2: String, +///} +/// +///implement_from_tuple!( +/// MyStruct, ( +/// c1: i32 => |inner: &mut MyStruct, value| { +/// if let DataValue::Int32(Some(val)) = value { +/// inner.c1 = val; +/// } +/// }, +/// c2: String => |inner: &mut MyStruct, value| { +/// if let DataValue::Utf8(Some(val)) = value { +/// inner.c2 = val; +/// } +/// } +/// ) +/// ); +/// ``` +#[macro_export] +macro_rules! implement_from_tuple { + ($struct_name:ident, ($($field_name:ident : $field_type:ty => $closure:expr),+)) => { + impl From for $struct_name { + fn from(tuple: Tuple) -> Self { + fn try_get(tuple: &Tuple, field_name: &str) -> Option { + let ty = LogicalType::type_trans::()?; + let (idx, _) = tuple.columns + .iter() + .enumerate() + .find(|(_, col)| &col.name == field_name)?; + + DataValue::clone(&tuple.values[idx]) + .cast(&ty) + .ok() + } + + let mut struct_instance = $struct_name::default(); + $( + if let Some(value) = try_get::<$field_type>(&tuple, stringify!($field_name)) { + $closure( + &mut struct_instance, + value + ); + } + )+ + struct_instance + } + } + }; +} + +#[cfg(test)] +mod test { + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::types::tuple::Tuple; + use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::sync::Arc; + + fn build_tuple() -> Tuple { + let columns = vec![ + Arc::new(ColumnCatalog::new( + "c1".to_string(), + false, + ColumnDesc::new(LogicalType::Integer, true, false), + None, + )), + Arc::new(ColumnCatalog::new( + "c2".to_string(), + false, + ColumnDesc::new(LogicalType::Varchar(None), false, false), + None, + )), + ]; + let values = vec![ + Arc::new(DataValue::Int32(Some(9))), + Arc::new(DataValue::Utf8(Some("LOL".to_string()))), + ]; + + Tuple { + id: None, + columns, + values, + } + } + + #[derive(Default, Debug, PartialEq)] + struct MyStruct { + c1: i32, + c2: String, + } + + implement_from_tuple!( + MyStruct, ( + c1: i32 => |inner: &mut MyStruct, value| { + if let DataValue::Int32(Some(val)) = value { + inner.c1 = val; + } + }, + c2: String => |inner: &mut MyStruct, value| { + if let DataValue::Utf8(Some(val)) = value { + inner.c2 = val; + } + } + ) + ); + + #[test] + fn test_from_tuple() { + let my_struct = MyStruct::from(build_tuple()); + + println!("{:?}", my_struct); + + assert_eq!(my_struct.c1, 9); + assert_eq!(my_struct.c2, "LOL"); + } +} diff --git a/src/optimizer/core/mod.rs b/src/optimizer/core/mod.rs index a91a2e4a..8dd57c35 100644 --- a/src/optimizer/core/mod.rs +++ b/src/optimizer/core/mod.rs @@ -1,3 +1,3 @@ pub(crate) mod opt_expr; pub(crate) mod pattern; -pub(crate) mod rule; \ No newline at end of file +pub(crate) mod rule; diff --git a/src/optimizer/core/opt_expr.rs b/src/optimizer/core/opt_expr.rs index 4f0fba78..0b7136fc 100644 --- a/src/optimizer/core/opt_expr.rs +++ b/src/optimizer/core/opt_expr.rs @@ -1,6 +1,6 @@ -use std::fmt::Debug; -use crate::planner::LogicalPlan; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; +use std::fmt::Debug; pub type OptExprNodeId = usize; @@ -30,7 +30,6 @@ pub struct OptExpr { pub childrens: Vec, } - impl OptExpr { #[allow(dead_code)] pub fn new(root: OptExprNode, childrens: Vec) -> Self { diff --git a/src/optimizer/core/pattern.rs b/src/optimizer/core/pattern.rs index 21a5fa29..bd49c197 100644 --- a/src/optimizer/core/pattern.rs +++ b/src/optimizer/core/pattern.rs @@ -20,4 +20,4 @@ pub struct Pattern { pub trait PatternMatcher { fn match_opt_expr(&self) -> bool; -} \ No newline at end of file +} diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index 37bf04cd..9a790ace 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -8,4 +8,4 @@ pub trait Rule { fn pattern(&self) -> &Pattern; fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError>; -} \ No newline at end of file +} diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs index 23cad49c..81f6ed78 100644 --- a/src/optimizer/heuristic/batch.rs +++ b/src/optimizer/heuristic/batch.rs @@ -55,4 +55,4 @@ pub enum HepMatchOrder { /// ancestors. #[allow(dead_code)] BottomUp, -} \ No newline at end of file +} diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 7d09e7a3..4f4b9e97 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -1,11 +1,11 @@ -use std::mem; -use itertools::Itertools; -use petgraph::stable_graph::{NodeIndex, StableDiGraph}; -use petgraph::visit::{Bfs, EdgeRef}; use crate::optimizer::core::opt_expr::{OptExprNode, OptExprNodeId}; use crate::optimizer::heuristic::batch::HepMatchOrder; -use crate::planner::LogicalPlan; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; +use itertools::Itertools; +use petgraph::stable_graph::{NodeIndex, StableDiGraph}; +use petgraph::visit::{Bfs, EdgeRef}; +use std::mem; /// HepNodeId is used in optimizer to identify a node. pub type HepNodeId = NodeIndex; @@ -21,7 +21,10 @@ impl HepGraph { pub fn new(root: LogicalPlan) -> Self { fn graph_filling( graph: &mut StableDiGraph, - LogicalPlan{ operator, childrens }: LogicalPlan, + LogicalPlan { + operator, + childrens, + }: LogicalPlan, ) -> HepNodeId { let index = graph.add_node(OptExprNode::OperatorRef(operator)); @@ -35,10 +38,7 @@ impl HepGraph { let mut graph = StableDiGraph::::default(); - let root_index = graph_filling( - &mut graph, - root, - ); + let root_index = graph_filling(&mut graph, root); HepGraph { graph, @@ -55,28 +55,27 @@ impl HepGraph { #[allow(dead_code)] pub fn add_root(&mut self, new_node: OptExprNode) { - let old_root_id = mem::replace( - &mut self.root_index, - self.graph.add_node(new_node) - ); + let old_root_id = mem::replace(&mut self.root_index, self.graph.add_node(new_node)); self.graph.add_edge(self.root_index, old_root_id, 0); self.version += 1; } - pub fn add_node(&mut self, source_id: HepNodeId, children_option: Option, new_node: OptExprNode) { + pub fn add_node( + &mut self, + source_id: HepNodeId, + children_option: Option, + new_node: OptExprNode, + ) { let new_index = self.graph.add_node(new_node); - let mut order = self.graph - .edges(source_id) - .count(); + let mut order = self.graph.edges(source_id).count(); if let Some(children_id) = children_option { - self.graph.find_edge(source_id, children_id) + self.graph + .find_edge(source_id, children_id) .map(|old_edge_id| { - order = self.graph - .remove_edge(old_edge_id) - .unwrap_or(0); + order = self.graph.remove_edge(old_edge_id).unwrap_or(0); self.graph.add_edge(new_index, children_id, 0); }); @@ -94,24 +93,26 @@ impl HepGraph { pub fn swap_node(&mut self, a: HepNodeId, b: HepNodeId) { let tmp = self.graph[a].clone(); - self.graph[a] = mem::replace( - &mut self.graph[b], - tmp - ); + self.graph[a] = mem::replace(&mut self.graph[b], tmp); self.version += 1; } - pub fn remove_node(&mut self, source_id: HepNodeId, with_childrens: bool) -> Option { + pub fn remove_node( + &mut self, + source_id: HepNodeId, + with_childrens: bool, + ) -> Option { if !with_childrens { - let children_ids = self.graph.edges(source_id) + let children_ids = self + .graph + .edges(source_id) .sorted_by_key(|edge_ref| edge_ref.weight()) .map(|edge_ref| edge_ref.target()) .collect_vec(); if let Some(parent_id) = self.parent_id(source_id) { if let Some(edge) = self.graph.find_edge(parent_id, source_id) { - let weight = *self.graph.edge_weight(edge) - .unwrap_or(&0); + let weight = *self.graph.edge_weight(edge).unwrap_or(&0); for (order, children_id) in children_ids.into_iter().enumerate() { let _ = self.graph.add_edge(parent_id, children_id, weight + order); @@ -138,7 +139,11 @@ impl HepGraph { } /// Use bfs to traverse the graph and return node ids - pub fn nodes_iter(&self, order: HepMatchOrder, start_option: Option) -> Box> { + pub fn nodes_iter( + &self, + order: HepMatchOrder, + start_option: Option, + ) -> Box> { let ids = self.bfs(start_option.unwrap_or(self.root_index)); match order { HepMatchOrder::TopDown => Box::new(ids.into_iter()), @@ -164,8 +169,7 @@ impl HepGraph { /// If input node is join, we use the edge weight to control the join chilren order. pub fn children_at(&self, id: HepNodeId) -> Vec { - self - .graph + self.graph .edges(id) .sorted_by_key(|edge| edge.weight()) .map(|edge| edge.target()) @@ -183,11 +187,7 @@ impl HepGraph { root_plan } - fn build_childrens( - &self, - plan: &mut LogicalPlan, - start: HepNodeId, - ) { + fn build_childrens(&self, plan: &mut LogicalPlan, start: HepNodeId) { for child_id in self.children_at(start) { let mut child_plan = LogicalPlan { operator: self.operator(child_id).clone(), @@ -202,21 +202,27 @@ impl HepGraph { #[cfg(test)] mod tests { - use petgraph::stable_graph::{EdgeIndex, NodeIndex}; use crate::binder::test::select_sql_run; use crate::execution::ExecutorError; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::planner::operator::Operator; + use petgraph::stable_graph::{EdgeIndex, NodeIndex}; #[tokio::test] async fn test_graph_for_plan() -> Result<(), ExecutorError> { let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?; let graph = HepGraph::new(plan); - assert!(graph.graph.contains_edge(NodeIndex::new(1), NodeIndex::new(2))); - assert!(graph.graph.contains_edge(NodeIndex::new(1), NodeIndex::new(3))); - assert!(graph.graph.contains_edge(NodeIndex::new(0), NodeIndex::new(1))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(1), NodeIndex::new(2))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(1), NodeIndex::new(3))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(0), NodeIndex::new(1))); assert_eq!(graph.graph.edge_weight(EdgeIndex::new(0)), Some(&0)); assert_eq!(graph.graph.edge_weight(EdgeIndex::new(1)), Some(&1)); @@ -233,24 +239,30 @@ mod tests { graph.add_node( HepNodeId::new(1), None, - OptExprNode::OperatorRef(Operator::Dummy) + OptExprNode::OperatorRef(Operator::Dummy), ); graph.add_node( HepNodeId::new(1), Some(HepNodeId::new(4)), - OptExprNode::OperatorRef(Operator::Dummy) + OptExprNode::OperatorRef(Operator::Dummy), ); graph.add_node( HepNodeId::new(5), None, - OptExprNode::OperatorRef(Operator::Dummy) + OptExprNode::OperatorRef(Operator::Dummy), ); - assert!(graph.graph.contains_edge(NodeIndex::new(5), NodeIndex::new(4))); - assert!(graph.graph.contains_edge(NodeIndex::new(1), NodeIndex::new(5))); - assert!(graph.graph.contains_edge(NodeIndex::new(5), NodeIndex::new(6))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(5), NodeIndex::new(4))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(1), NodeIndex::new(5))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(5), NodeIndex::new(6))); assert_eq!(graph.graph.edge_weight(EdgeIndex::new(3)), Some(&0)); assert_eq!(graph.graph.edge_weight(EdgeIndex::new(4)), Some(&2)); @@ -280,8 +292,12 @@ mod tests { assert_eq!(graph.graph.edge_count(), 2); - assert!(graph.graph.contains_edge(NodeIndex::new(0), NodeIndex::new(2))); - assert!(graph.graph.contains_edge(NodeIndex::new(0), NodeIndex::new(3))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(0), NodeIndex::new(2))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(0), NodeIndex::new(3))); Ok(()) } @@ -325,7 +341,9 @@ mod tests { graph.add_root(OptExprNode::OperatorRef(Operator::Dummy)); assert_eq!(graph.graph.edge_count(), 4); - assert!(graph.graph.contains_edge(NodeIndex::new(4), NodeIndex::new(0))); + assert!(graph + .graph + .contains_edge(NodeIndex::new(4), NodeIndex::new(0))); assert_eq!(graph.graph.edge_weight(EdgeIndex::new(3)), Some(&0)); Ok(()) @@ -346,4 +364,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index c88f9bf7..2911a7c7 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -30,7 +30,10 @@ impl PatternMatcher for HepMatcher<'_, '_> { match &self.pattern.children { PatternChildrenPredicate::Recursive => { // check - for node_id in self.graph.nodes_iter(HepMatchOrder::TopDown, Some(self.start_id)) { + for node_id in self + .graph + .nodes_iter(HepMatchOrder::TopDown, Some(self.start_id)) + { if !(self.pattern.predicate)(&self.graph.operator(node_id)) { return false; } @@ -45,7 +48,7 @@ impl PatternMatcher for HepMatcher<'_, '_> { } } } - PatternChildrenPredicate::None => () + PatternChildrenPredicate::None => (), } true @@ -59,8 +62,8 @@ mod tests { use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate, PatternMatcher}; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::heuristic::matcher::HepMatcher; - use crate::planner::LogicalPlan; use crate::planner::operator::Operator; + use crate::planner::LogicalPlan; #[tokio::test] async fn test_predicate() -> Result<(), ExecutorError> { @@ -96,17 +99,15 @@ mod tests { childrens: vec![ LogicalPlan { operator: Operator::Dummy, - childrens: vec![ - LogicalPlan { - operator: Operator::Dummy, - childrens: vec![], - } - ], + childrens: vec![LogicalPlan { + operator: Operator::Dummy, + childrens: vec![], + }], }, LogicalPlan { operator: Operator::Dummy, childrens: vec![], - } + }, ], }; let graph = HepGraph::new(all_dummy_plan.clone()); @@ -119,9 +120,6 @@ mod tests { children: PatternChildrenPredicate::Recursive, }; - assert!( - HepMatcher::new(&only_dummy_pattern, HepNodeId::new(0), &graph) - .match_opt_expr() - ); + assert!(HepMatcher::new(&only_dummy_pattern, HepNodeId::new(0), &graph).match_opt_expr()); } -} \ No newline at end of file +} diff --git a/src/optimizer/heuristic/mod.rs b/src/optimizer/heuristic/mod.rs index e23d90d8..b9fd5abb 100644 --- a/src/optimizer/heuristic/mod.rs +++ b/src/optimizer/heuristic/mod.rs @@ -1,4 +1,4 @@ pub(crate) mod batch; pub(crate) mod graph; pub(crate) mod matcher; -pub mod optimizer; \ No newline at end of file +pub mod optimizer; diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 081c313d..ee266b2b 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -3,8 +3,8 @@ use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::batch::{HepBatch, HepBatchStrategy}; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::heuristic::matcher::HepMatcher; -use crate::optimizer::OptimizerError; use crate::optimizer::rule::RuleImpl; +use crate::optimizer::OptimizerError; use crate::planner::LogicalPlan; pub struct HepOptimizer { @@ -44,7 +44,12 @@ impl HepOptimizer { Ok(self.graph.to_plan()) } - fn apply_batch(&mut self, HepBatch{ rules, strategy, .. }: &HepBatch) -> Result { + fn apply_batch( + &mut self, + HepBatch { + rules, strategy, .. + }: &HepBatch, + ) -> Result { let start_ver = self.graph.version; for rule in rules { @@ -67,5 +72,4 @@ impl HepOptimizer { Ok(after_version != self.graph.version) } - -} \ No newline at end of file +} diff --git a/src/optimizer/mod.rs b/src/optimizer/mod.rs index 8f937373..28dcdaa0 100644 --- a/src/optimizer/mod.rs +++ b/src/optimizer/mod.rs @@ -2,7 +2,6 @@ use crate::types::errors::TypeError; /// The architecture and some components, /// such as (/core) are referenced from sqlrs - mod core; pub mod heuristic; pub mod rule; @@ -13,6 +12,6 @@ pub enum OptimizerError { TypeError( #[source] #[from] - TypeError - ) -} \ No newline at end of file + TypeError, + ), +} diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs index df9328eb..e07e83e8 100644 --- a/src/optimizer/rule/column_pruning.rs +++ b/src/optimizer/rule/column_pruning.rs @@ -1,5 +1,3 @@ -use itertools::Itertools; -use lazy_static::lazy_static; use crate::catalog::ColumnRef; use crate::expression::ScalarExpression; use crate::optimizer::core::opt_expr::OptExprNode; @@ -8,8 +6,10 @@ use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::OptimizerError; use crate::planner::operator::aggregate::AggregateOperator; -use crate::planner::operator::Operator; use crate::planner::operator::project::ProjectOperator; +use crate::planner::operator::Operator; +use itertools::Itertools; +use lazy_static::lazy_static; lazy_static! { static ref PUSH_PROJECT_INTO_SCAN_RULE: Pattern = { @@ -21,7 +21,6 @@ lazy_static! { }]), } }; - static ref PUSH_PROJECT_THROUGH_CHILD_RULE: Pattern = { Pattern { predicate: |op| matches!(op, Operator::Project(_)), @@ -50,16 +49,17 @@ impl Rule for PushProjectIntoScan { if let Operator::Scan(scan_op) = graph.operator(child_index) { let mut new_scan_op = scan_op.clone(); - new_scan_op.columns = project_op.columns + new_scan_op.columns = project_op + .columns .iter() - .filter(|expr| matches!(expr.unpack_alias(),ScalarExpression::ColumnRef(_))) + .filter(|expr| matches!(expr.unpack_alias(), ScalarExpression::ColumnRef(_))) .cloned() .collect_vec(); graph.remove_node(node_id, false); graph.replace_node( child_index, - OptExprNode::OperatorRef(Operator::Scan(new_scan_op)) + OptExprNode::OperatorRef(Operator::Scan(new_scan_op)), ); } } @@ -125,18 +125,16 @@ impl Rule for PushProjectThroughChild { .collect_vec(); for grandson_id in graph.children_at(child_index) { - let grandson_referenced_column = graph - .operator(grandson_id) - .referenced_columns(); + let grandson_referenced_column = + graph.operator(grandson_id).referenced_columns(); // for PushLimitThroughJoin if grandson_referenced_column.is_empty() { - return Ok(()) + return Ok(()); } - let grandson_table_name = grandson_referenced_column[0] - .table_name - .clone(); - let columns = parent_referenced_columns.iter() + let grandson_table_name = grandson_referenced_column[0].table_name.clone(); + let columns = parent_referenced_columns + .iter() .filter(|col| col.table_name == grandson_table_name) .cloned() .map(|col| ScalarExpression::ColumnRef(col)) @@ -149,7 +147,7 @@ impl Rule for PushProjectThroughChild { let grandson_ids = graph.children_at(child_index); if grandson_ids.is_empty() { - return Ok(()) + return Ok(()); } let grandson_id = grandson_ids[0]; let mut columns = node_operator.project_input_refs(); @@ -172,14 +170,17 @@ impl Rule for PushProjectThroughChild { } impl PushProjectThroughChild { - fn add_project_node(graph: &mut HepGraph, child_index: HepNodeId, columns: Vec, grandson_id: HepNodeId) { + fn add_project_node( + graph: &mut HepGraph, + child_index: HepNodeId, + columns: Vec, + grandson_id: HepNodeId, + ) { if !columns.is_empty() { graph.add_node( child_index, Some(grandson_id), - OptExprNode::OperatorRef( - Operator::Project(ProjectOperator { columns }) - ) + OptExprNode::OperatorRef(Operator::Project(ProjectOperator { columns })), ); } } @@ -189,7 +190,7 @@ impl PushProjectThroughChild { mod tests { use crate::binder::test::select_sql_run; use crate::db::DatabaseError; - use crate::optimizer::heuristic::batch::{HepBatchStrategy}; + use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; use crate::planner::operator::join::JoinCondition; @@ -203,7 +204,7 @@ mod tests { .batch( "test_project_into_table_scan".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::PushProjectIntoScan] + vec![RuleImpl::PushProjectIntoScan], ) .find_best()?; @@ -211,7 +212,7 @@ mod tests { match best_plan.operator { Operator::Scan(op) => { assert_eq!(op.columns.len(), 2); - }, + } _ => unreachable!("Should be a scan operator"), } @@ -228,26 +229,25 @@ mod tests { HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::PushProjectThroughChild, - RuleImpl::PushProjectIntoScan - ] - ).find_best()?; + RuleImpl::PushProjectIntoScan, + ], + ) + .find_best()?; assert_eq!(best_plan.childrens.len(), 1); match best_plan.operator { Operator::Project(op) => { assert_eq!(op.columns.len(), 2); - }, + } _ => unreachable!("Should be a project operator"), } match &best_plan.childrens[0].operator { - Operator::Join(op) => { - match &op.on { - JoinCondition::On { on, filter } => { - assert_eq!(on.len(), 1); - assert!(filter.is_none()); - } - _ => unreachable!("Should be a on condition"), + Operator::Join(op) => match &op.on { + JoinCondition::On { on, filter } => { + assert_eq!(on.len(), 1); + assert!(filter.is_none()); } + _ => unreachable!("Should be a on condition"), }, _ => unreachable!("Should be a join operator"), } @@ -258,11 +258,11 @@ mod tests { match &grandson_plan.operator { Operator::Scan(op) => { assert_eq!(op.columns.len(), 1); - }, + } _ => unreachable!("Should be a scan operator"), } } Ok(()) } -} \ No newline at end of file +} diff --git a/src/optimizer/rule/combine_operators.rs b/src/optimizer/rule/combine_operators.rs index bf22a66c..b2903b3e 100644 --- a/src/optimizer/rule/combine_operators.rs +++ b/src/optimizer/rule/combine_operators.rs @@ -1,14 +1,14 @@ -use lazy_static::lazy_static; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::OptimizerError; use crate::optimizer::rule::is_subset_exprs; +use crate::optimizer::OptimizerError; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::Operator; use crate::types::LogicalType; +use lazy_static::lazy_static; lazy_static! { static ref COLLAPSE_PROJECT_RULE: Pattern = { @@ -76,7 +76,7 @@ impl Rule for CombineFilter { }; graph.replace_node( node_id, - OptExprNode::OperatorRef(Operator::Filter(new_filter_op)) + OptExprNode::OperatorRef(Operator::Filter(new_filter_op)), ); graph.remove_node(child_id, false); } @@ -88,34 +88,31 @@ impl Rule for CombineFilter { #[cfg(test)] mod tests { - use std::sync::Arc; use crate::binder::test::select_sql_run; use crate::db::DatabaseError; - use crate::expression::{BinaryOperator, ScalarExpression}; use crate::expression::ScalarExpression::Constant; + use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::graph::HepNodeId; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; use crate::planner::operator::Operator; - use crate::types::LogicalType; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::sync::Arc; #[tokio::test] async fn test_collapse_project() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c2 from t1").await?; - let mut optimizer = HepOptimizer::new(plan.clone()) - .batch( - "test_collapse_project".to_string(), - HepBatchStrategy::once_topdown(), - vec![RuleImpl::CollapseProject] - ); + let mut optimizer = HepOptimizer::new(plan.clone()).batch( + "test_collapse_project".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::CollapseProject], + ); - let mut new_project_op = optimizer.graph - .operator(HepNodeId::new(0)) - .clone(); + let mut new_project_op = optimizer.graph.operator(HepNodeId::new(0)).clone(); if let Operator::Project(op) = &mut new_project_op { op.columns.remove(0); @@ -123,19 +120,21 @@ mod tests { unreachable!("Should be a project operator") } - optimizer.graph.add_root(OptExprNode::OperatorRef(new_project_op)); + optimizer + .graph + .add_root(OptExprNode::OperatorRef(new_project_op)); let best_plan = optimizer.find_best()?; if let Operator::Project(op) = &best_plan.operator { assert_eq!(op.columns.len(), 1); - } else { + } else { unreachable!("Should be a project operator") } if let Operator::Scan(_) = &best_plan.childrens[0].operator { assert_eq!(best_plan.childrens[0].childrens.len(), 0) - } else { + } else { unreachable!("Should be a scan operator") } @@ -146,16 +145,13 @@ mod tests { async fn test_combine_filter() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 where c1 > 1").await?; - let mut optimizer = HepOptimizer::new(plan.clone()) - .batch( - "test_combine_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![RuleImpl::CombineFilter] - ); + let mut optimizer = HepOptimizer::new(plan.clone()).batch( + "test_combine_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::CombineFilter], + ); - let mut new_filter_op = optimizer.graph - .operator(HepNodeId::new(1)) - .clone(); + let mut new_filter_op = optimizer.graph.operator(HepNodeId::new(1)).clone(); if let Operator::Filter(op) = &mut new_filter_op { op.predicate = ScalarExpression::Binary { @@ -171,7 +167,7 @@ mod tests { optimizer.graph.add_node( HepNodeId::new(0), Some(HepNodeId::new(1)), - OptExprNode::OperatorRef(new_filter_op) + OptExprNode::OperatorRef(new_filter_op), ); let best_plan = optimizer.find_best()?; @@ -188,4 +184,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs index bca3f4c6..846aa27c 100644 --- a/src/optimizer/rule/mod.rs +++ b/src/optimizer/rule/mod.rs @@ -2,13 +2,15 @@ use crate::expression::ScalarExpression; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::OptimizerError; use crate::optimizer::rule::column_pruning::{PushProjectIntoScan, PushProjectThroughChild}; use crate::optimizer::rule::combine_operators::{CollapseProject, CombineFilter}; -use crate::optimizer::rule::pushdown_limit::{LimitProjectTranspose, EliminateLimits, PushLimitThroughJoin, PushLimitIntoScan}; -use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; +use crate::optimizer::rule::pushdown_limit::{ + EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, +}; use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan; +use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; use crate::optimizer::rule::simplification::SimplifyFilter; +use crate::optimizer::OptimizerError; mod column_pruning; mod combine_operators; @@ -34,7 +36,7 @@ pub enum RuleImpl { // Tips: need to be used with `SimplifyFilter` PushPredicateIntoScan, // Simplification - SimplifyFilter + SimplifyFilter, } impl Rule for RuleImpl { @@ -54,7 +56,7 @@ impl Rule for RuleImpl { } } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError>{ + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { match self { RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.apply(node_id, graph), RuleImpl::PushProjectThroughChild => PushProjectThroughChild {}.apply(node_id, graph), @@ -66,7 +68,7 @@ impl Rule for RuleImpl { RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.apply(node_id, graph), RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.apply(node_id, graph), RuleImpl::SimplifyFilter => SimplifyFilter {}.apply(node_id, graph), - RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.apply(node_id, graph) + RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.apply(node_id, graph), } } } @@ -74,4 +76,4 @@ impl Rule for RuleImpl { /// Return true when left is subset of right pub fn is_subset_exprs(left: &[ScalarExpression], right: &[ScalarExpression]) -> bool { left.iter().all(|l| right.contains(l)) -} \ No newline at end of file +} diff --git a/src/optimizer/rule/pushdown_limit.rs b/src/optimizer/rule/pushdown_limit.rs index 8fd161f0..3175c654 100644 --- a/src/optimizer/rule/pushdown_limit.rs +++ b/src/optimizer/rule/pushdown_limit.rs @@ -1,14 +1,14 @@ -use std::cmp; -use lazy_static::lazy_static; use crate::optimizer::core::opt_expr::OptExprNode; -use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::pattern::Pattern; +use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::OptimizerError; use crate::planner::operator::join::JoinType; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::Operator; +use lazy_static::lazy_static; +use std::cmp; lazy_static! { static ref LIMIT_PROJECT_TRANSPOSE_RULE: Pattern = { Pattern { @@ -56,10 +56,7 @@ impl Rule for LimitProjectTranspose { } fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { - graph.swap_node( - node_id, - graph.children_at(node_id)[0] - ); + graph.swap_node(node_id, graph.children_at(node_id)[0]); Ok(()) } @@ -86,9 +83,7 @@ impl Rule for EliminateLimits { graph.remove_node(child_id, false); graph.replace_node( node_id, - OptExprNode::OperatorRef( - Operator::Limit(new_limit_op) - ) + OptExprNode::OperatorRef(Operator::Limit(new_limit_op)), ); } } @@ -123,12 +118,12 @@ impl Rule for PushLimitThroughJoin { if let Some(grandson_id) = match ty { JoinType::Left => Some(graph.children_at(child_id)[0]), JoinType::Right => Some(graph.children_at(child_id)[1]), - _ => None + _ => None, } { graph.add_node( child_id, Some(grandson_id), - OptExprNode::OperatorRef(Operator::Limit(op.clone())) + OptExprNode::OperatorRef(Operator::Limit(op.clone())), ); } } @@ -157,7 +152,7 @@ impl Rule for PushLimitIntoScan { graph.remove_node(node_id, false); graph.replace_node( child_index, - OptExprNode::OperatorRef(Operator::Scan(new_scan_op)) + OptExprNode::OperatorRef(Operator::Scan(new_scan_op)), ); } } @@ -185,18 +180,16 @@ mod tests { .batch( "test_limit_project_transpose".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::LimitProjectTranspose] + vec![RuleImpl::LimitProjectTranspose], ) - .find_best()?; + .find_best()?; if let Operator::Project(_) = &best_plan.operator { - } else { unreachable!("Should be a project operator") } if let Operator::Limit(_) = &best_plan.childrens[0].operator { - } else { unreachable!("Should be a limit operator") } @@ -208,21 +201,20 @@ mod tests { async fn test_eliminate_limits() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c2 from t1 limit 1 offset 1").await?; - let mut optimizer = HepOptimizer::new(plan.clone()) - .batch( - "test_eliminate_limits".to_string(), - HepBatchStrategy::once_topdown(), - vec![RuleImpl::EliminateLimits] - ); + let mut optimizer = HepOptimizer::new(plan.clone()).batch( + "test_eliminate_limits".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::EliminateLimits], + ); let new_limit_op = LimitOperator { offset: 2, limit: 1, }; - optimizer.graph.add_root( - OptExprNode::OperatorRef(Operator::Limit(new_limit_op)) - ); + optimizer + .graph + .add_root(OptExprNode::OperatorRef(Operator::Limit(new_limit_op))); let best_plan = optimizer.find_best()?; @@ -250,8 +242,8 @@ mod tests { HepBatchStrategy::once_topdown(), vec![ RuleImpl::LimitProjectTranspose, - RuleImpl::PushLimitThroughJoin - ] + RuleImpl::PushLimitThroughJoin, + ], ) .find_best()?; @@ -279,8 +271,8 @@ mod tests { HepBatchStrategy::once_topdown(), vec![ RuleImpl::LimitProjectTranspose, - RuleImpl::PushLimitIntoTableScan - ] + RuleImpl::PushLimitIntoTableScan, + ], ) .find_best()?; diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs index 1cef28e8..c3268eeb 100644 --- a/src/optimizer/rule/pushdown_predicates.rs +++ b/src/optimizer/rule/pushdown_predicates.rs @@ -1,17 +1,17 @@ -use itertools::Itertools; -use lazy_static::lazy_static; use crate::catalog::ColumnRef; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::Pattern; +use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::OptimizerError; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; use crate::types::LogicalType; +use itertools::Itertools; +use lazy_static::lazy_static; lazy_static! { static ref PUSH_PREDICATE_THROUGH_JOIN: Pattern = { @@ -57,7 +57,7 @@ fn split_conjunctive_predicates(expr: &ScalarExpression) -> Vec vec![expr.clone()] + _ => vec![expr.clone()], } } @@ -66,17 +66,15 @@ fn split_conjunctive_predicates(expr: &ScalarExpression) -> Vec, having: bool) -> Option { filters .into_iter() - .reduce(|a, b| { - ScalarExpression::Binary { - op: BinaryOperator::And, - left_expr: Box::new(a), - right_expr: Box::new(b), - ty: LogicalType::Boolean - } + .reduce(|a, b| ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(a), + right_expr: Box::new(b), + ty: LogicalType::Boolean, }) .map(|f| FilterOperator { predicate: f, - having + having, }) } @@ -106,17 +104,16 @@ impl Rule for PushPredicateThroughJoin { fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { let child_id = graph.children_at(node_id)[0]; if let Operator::Join(child_op) = graph.operator(child_id) { - if !matches!(child_op.join_type, JoinType::Inner | JoinType::Left | JoinType::Right) { + if !matches!( + child_op.join_type, + JoinType::Inner | JoinType::Left | JoinType::Right + ) { return Ok(()); } let join_childs = graph.children_at(child_id); - let left_columns = graph - .operator(join_childs[0]) - .referenced_columns(); - let right_columns = graph - .operator(join_childs[1]) - .referenced_columns(); + let left_columns = graph.operator(join_childs[0]).referenced_columns(); + let right_columns = graph.operator(join_childs[1]).referenced_columns(); let mut new_ops = (None, None, None); @@ -139,7 +136,8 @@ impl Rule for PushPredicateThroughJoin { } if !right_filters.is_empty() { - if let Some(right_filter_op) = reduce_filters(right_filters, op.having) { + if let Some(right_filter_op) = reduce_filters(right_filters, op.having) + { new_ops.1 = Some(Operator::Filter(right_filter_op)); } } @@ -160,17 +158,15 @@ impl Rule for PushPredicateThroughJoin { } JoinType::Right => { if !right_filters.is_empty() { - if let Some(right_filter_op) = reduce_filters(right_filters, op.having) { + if let Some(right_filter_op) = reduce_filters(right_filters, op.having) + { new_ops.1 = Some(Operator::Filter(right_filter_op)); } } - common_filters - .into_iter() - .chain(left_filters) - .collect_vec() + common_filters.into_iter().chain(left_filters).collect_vec() } - _ => vec![] + _ => vec![], }; if !replace_filters.is_empty() { @@ -184,7 +180,7 @@ impl Rule for PushPredicateThroughJoin { graph.add_node( child_id, Some(join_childs[0]), - OptExprNode::OperatorRef(left_op) + OptExprNode::OperatorRef(left_op), ); } @@ -192,15 +188,12 @@ impl Rule for PushPredicateThroughJoin { graph.add_node( child_id, Some(join_childs[1]), - OptExprNode::OperatorRef(right_op) + OptExprNode::OperatorRef(right_op), ); } if let Some(common_op) = new_ops.2 { - graph.replace_node( - node_id, - OptExprNode::OperatorRef(common_op) - ); + graph.replace_node(node_id, OptExprNode::OperatorRef(common_op)); } else { graph.remove_node(node_id, false); } @@ -210,9 +203,7 @@ impl Rule for PushPredicateThroughJoin { } } -pub struct PushPredicateIntoScan { - -} +pub struct PushPredicateIntoScan {} impl Rule for PushPredicateIntoScan { fn pattern(&self) -> &Pattern { @@ -224,7 +215,7 @@ impl Rule for PushPredicateIntoScan { let child_id = graph.children_at(node_id)[0]; if let Operator::Scan(child_op) = graph.operator(child_id) { if child_op.index_by.is_some() { - return Ok(()) + return Ok(()); } //FIXME: now only support unique @@ -243,12 +234,10 @@ impl Rule for PushPredicateIntoScan { // reduce the data scanning range and cannot replace the role of Filter. graph.replace_node( child_id, - OptExprNode::OperatorRef( - Operator::Scan(scan_by_index) - ) + OptExprNode::OperatorRef(Operator::Scan(scan_by_index)), ); - return Ok(()) + return Ok(()); } } } @@ -261,18 +250,18 @@ impl Rule for PushPredicateIntoScan { #[cfg(test)] mod tests { - use std::collections::Bound; - use std::sync::Arc; use crate::binder::test::select_sql_run; use crate::db::DatabaseError; - use crate::expression::{BinaryOperator, ScalarExpression}; use crate::expression::simplify::ConstantBinary::Scope; + use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; use crate::planner::operator::Operator; - use crate::types::LogicalType; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::collections::Bound; + use std::sync::Arc; #[tokio::test] async fn test_push_predicate_into_scan() -> Result<(), DatabaseError> { @@ -283,19 +272,19 @@ mod tests { .batch( "simplify_filter".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .batch( "test_push_predicate_into_scan".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::PushPredicateIntoScan] + vec![RuleImpl::PushPredicateIntoScan], ) .find_best()?; if let Operator::Scan(op) = &best_plan.childrens[0].childrens[0].operator { let mock_binaries = vec![Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(1)))), - max: Bound::Unbounded + max: Bound::Unbounded, }]; assert_eq!(op.index_by.clone().unwrap().1, mock_binaries); @@ -308,13 +297,15 @@ mod tests { #[tokio::test] async fn test_push_predicate_through_join_in_left_join() -> Result<(), DatabaseError> { - let plan = select_sql_run("select * from t1 left join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; + let plan = + select_sql_run("select * from t1 left join t2 on c1 = c3 where c1 > 1 and c3 < 2") + .await?; let best_plan = HepOptimizer::new(plan) .batch( "test_push_predicate_through_join".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::PushPredicateThroughJoin] + vec![RuleImpl::PushPredicateThroughJoin], ) .find_best()?; @@ -325,7 +316,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -338,7 +329,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -349,13 +340,15 @@ mod tests { #[tokio::test] async fn test_push_predicate_through_join_in_right_join() -> Result<(), DatabaseError> { - let plan = select_sql_run("select * from t1 right join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; + let plan = + select_sql_run("select * from t1 right join t2 on c1 = c3 where c1 > 1 and c3 < 2") + .await?; let best_plan = HepOptimizer::new(plan) .batch( "test_push_predicate_through_join".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::PushPredicateThroughJoin] + vec![RuleImpl::PushPredicateThroughJoin], ) .find_best()?; @@ -366,7 +359,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -379,7 +372,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -390,18 +383,19 @@ mod tests { #[tokio::test] async fn test_push_predicate_through_join_in_inner_join() -> Result<(), DatabaseError> { - let plan = select_sql_run("select * from t1 inner join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; + let plan = + select_sql_run("select * from t1 inner join t2 on c1 = c3 where c1 > 1 and c3 < 2") + .await?; let best_plan = HepOptimizer::new(plan) .batch( "test_push_predicate_through_join".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::PushPredicateThroughJoin] + vec![RuleImpl::PushPredicateThroughJoin], ) .find_best()?; if let Operator::Join(_) = &best_plan.childrens[0].operator { - } else { unreachable!("Should be a filter operator") } @@ -413,7 +407,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -426,7 +420,7 @@ mod tests { ty: LogicalType::Boolean, .. } => (), - _ => unreachable!() + _ => unreachable!(), } } else { unreachable!("Should be a filter operator") @@ -434,4 +428,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs index 12267d31..fc5e324d 100644 --- a/src/optimizer/rule/simplification.rs +++ b/src/optimizer/rule/simplification.rs @@ -1,10 +1,10 @@ -use lazy_static::lazy_static; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::OptimizerError; use crate::planner::operator::Operator; +use lazy_static::lazy_static; lazy_static! { static ref SIMPLIFY_FILTER_RULE: Pattern = { Pattern { @@ -31,7 +31,7 @@ impl Rule for SimplifyFilter { graph.replace_node( node_id, - OptExprNode::OperatorRef(Operator::Filter(filter_op)) + OptExprNode::OperatorRef(Operator::Filter(filter_op)), ) } @@ -41,21 +41,21 @@ impl Rule for SimplifyFilter { #[cfg(test)] mod test { - use std::collections::Bound; - use std::sync::Arc; use crate::binder::test::select_sql_run; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::DatabaseError; - use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::expression::simplify::ConstantBinary; + use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; - use crate::planner::LogicalPlan; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::Operator; - use crate::types::LogicalType; + use crate::planner::LogicalPlan; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::collections::Bound; + use std::sync::Arc; #[tokio::test] async fn test_simplify_filter_single_column() -> Result<(), DatabaseError> { @@ -88,11 +88,14 @@ mod test { .batch( "test_simplify_filter".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .find_best()?; if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op.predicate.convert_binary(&0).unwrap()); + println!( + "{expr}: {:#?}", + filter_op.predicate.convert_binary(&0).unwrap() + ); Ok(filter_op.predicate.convert_binary(&0).unwrap()) } else { @@ -129,18 +132,14 @@ mod test { .batch( "test_simplify_filter".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .find_best()?; if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { let c1_col = ColumnCatalog { - id: Some( - 0, - ), + id: Some(0), name: "c1".to_string(), - table_name: Some( - Arc::new("t1".to_string()), - ), + table_name: Some(Arc::new("t1".to_string())), nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, @@ -150,13 +149,9 @@ mod test { ref_expr: None, }; let c2_col = ColumnCatalog { - id: Some( - 1, - ), + id: Some(1), name: "c2".to_string(), - table_name: Some( - Arc::new("t1".to_string()), - ), + table_name: Some(Arc::new("t1".to_string())), nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, @@ -176,7 +171,9 @@ mod test { expr: Box::new(ScalarExpression::Binary { op: BinaryOperator::Plus, left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c1_col))), - right_expr: Box::new(ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(1))))), + right_expr: Box::new(ScalarExpression::Constant(Arc::new( + DataValue::Int32(Some(1)) + ))), ty: LogicalType::Integer, }), ty: LogicalType::Integer, @@ -195,9 +192,11 @@ mod test { #[tokio::test] async fn test_simplify_filter_multiple_column() -> Result<(), DatabaseError> { // c1 + 1 < -1 => c1 < -2 - let plan_1 = select_sql_run("select * from t1 where -(c1 + 1) > 1 and -(1 - c2) > 1").await?; + let plan_1 = + select_sql_run("select * from t1 where -(c1 + 1) > 1 and -(1 - c2) > 1").await?; // 1 - c1 < -1 => c1 > 2 - let plan_2 = select_sql_run("select * from t1 where -(1 - c1) > 1 and -(c2 + 1) > 1").await?; + let plan_2 = + select_sql_run("select * from t1 where -(1 - c1) > 1 and -(c2 + 1) > 1").await?; // c1 < -1 let plan_3 = select_sql_run("select * from t1 where -c1 > 1 and c2 + 1 > 1").await?; // c1 > 0 @@ -208,7 +207,7 @@ mod test { .batch( "test_simplify_filter".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .find_best()?; if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { @@ -227,59 +226,83 @@ mod test { let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); - assert_eq!(cb_1_c1, Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) - })); + assert_eq!( + cb_1_c1, + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + }) + ); let cb_1_c2 = op_1.predicate.convert_binary(&1).unwrap(); println!("op_1 => c2: {:#?}", cb_1_c2); - assert_eq!(cb_1_c2, Some(ConstantBinary::Scope { - min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded - })); + assert_eq!( + cb_1_c2, + Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), + max: Bound::Unbounded + }) + ); let cb_2_c1 = op_2.predicate.convert_binary(&0).unwrap(); println!("op_2 => c1: {:#?}", cb_2_c1); - assert_eq!(cb_2_c1, Some(ConstantBinary::Scope { - min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), - max: Bound::Unbounded - })); + assert_eq!( + cb_2_c1, + Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), + max: Bound::Unbounded + }) + ); let cb_2_c2 = op_2.predicate.convert_binary(&1).unwrap(); println!("op_2 => c2: {:#?}", cb_2_c2); - assert_eq!(cb_1_c1, Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) - })); + assert_eq!( + cb_1_c1, + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + }) + ); let cb_3_c1 = op_3.predicate.convert_binary(&0).unwrap(); println!("op_3 => c1: {:#?}", cb_3_c1); - assert_eq!(cb_3_c1, Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) - })); + assert_eq!( + cb_3_c1, + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + }) + ); let cb_3_c2 = op_3.predicate.convert_binary(&1).unwrap(); println!("op_3 => c2: {:#?}", cb_3_c2); - assert_eq!(cb_3_c2, Some(ConstantBinary::Scope { - min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded - })); + assert_eq!( + cb_3_c2, + Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), + max: Bound::Unbounded + }) + ); let cb_4_c1 = op_4.predicate.convert_binary(&0).unwrap(); println!("op_4 => c1: {:#?}", cb_4_c1); - assert_eq!(cb_4_c1, Some(ConstantBinary::Scope { - min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded - })); + assert_eq!( + cb_4_c1, + Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), + max: Bound::Unbounded + }) + ); let cb_4_c2 = op_4.predicate.convert_binary(&1).unwrap(); println!("op_4 => c2: {:#?}", cb_4_c2); - assert_eq!(cb_4_c2, Some(ConstantBinary::Scope { - min: Bound::Unbounded, - max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) - })); + assert_eq!( + cb_4_c2, + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + }) + ); Ok(()) } @@ -294,7 +317,7 @@ mod test { .batch( "test_simplify_filter".to_string(), HepBatchStrategy::once_topdown(), - vec![RuleImpl::SimplifyFilter] + vec![RuleImpl::SimplifyFilter], ) .find_best()?; if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { @@ -314,4 +337,4 @@ mod test { Ok(()) } -} \ No newline at end of file +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1b0af142..deaa0fdf 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,5 +1,5 @@ -use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; use sqlparser::parser::ParserError; +use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; /// Parse a string to a collection of statements. /// diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 2a492aa8..4b572cd2 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -10,7 +10,6 @@ pub struct LogicalPlan { impl LogicalPlan { pub fn child(&self, index: usize) -> Option<&LogicalPlan> { - self.childrens - .get(index) + self.childrens.get(index) } } diff --git a/src/planner/operator/aggregate.rs b/src/planner/operator/aggregate.rs index 7a6e3b8a..acdeeb19 100644 --- a/src/planner/operator/aggregate.rs +++ b/src/planner/operator/aggregate.rs @@ -1,8 +1,5 @@ -use crate::{ - expression::ScalarExpression, - planner::operator::Operator, -}; use crate::planner::LogicalPlan; +use crate::{expression::ScalarExpression, planner::operator::Operator}; #[derive(Debug, PartialEq, Clone)] pub struct AggregateOperator { diff --git a/src/planner/operator/delete.rs b/src/planner/operator/delete.rs index 00d48967..09bb023e 100644 --- a/src/planner/operator/delete.rs +++ b/src/planner/operator/delete.rs @@ -3,4 +3,4 @@ use crate::catalog::TableName; #[derive(Debug, PartialEq, Clone)] pub struct DeleteOperator { pub table_name: TableName, -} \ No newline at end of file +} diff --git a/src/planner/operator/drop_table.rs b/src/planner/operator/drop_table.rs index 9e5aa833..b343a457 100644 --- a/src/planner/operator/drop_table.rs +++ b/src/planner/operator/drop_table.rs @@ -4,4 +4,4 @@ use crate::catalog::TableName; pub struct DropTableOperator { /// Table name to insert to pub table_name: TableName, -} \ No newline at end of file +} diff --git a/src/planner/operator/filter.rs b/src/planner/operator/filter.rs index 3b0b3d23..6a1f5a12 100644 --- a/src/planner/operator/filter.rs +++ b/src/planner/operator/filter.rs @@ -12,11 +12,7 @@ pub struct FilterOperator { } impl FilterOperator { - pub fn new( - predicate: ScalarExpression, - children: LogicalPlan, - having: bool, - ) -> LogicalPlan { + pub fn new(predicate: ScalarExpression, children: LogicalPlan, having: bool) -> LogicalPlan { LogicalPlan { operator: Operator::Filter(FilterOperator { predicate, having }), childrens: vec![children], diff --git a/src/planner/operator/insert.rs b/src/planner/operator/insert.rs index 8d324242..0c2fa242 100644 --- a/src/planner/operator/insert.rs +++ b/src/planner/operator/insert.rs @@ -4,4 +4,4 @@ use crate::catalog::TableName; pub struct InsertOperator { pub table_name: TableName, pub is_overwrite: bool, -} \ No newline at end of file +} diff --git a/src/planner/operator/limit.rs b/src/planner/operator/limit.rs index 0eae30f1..29d06b71 100644 --- a/src/planner/operator/limit.rs +++ b/src/planner/operator/limit.rs @@ -9,7 +9,7 @@ pub struct LimitOperator { } impl LimitOperator { -pub fn new(offset: usize, limit: usize, children: LogicalPlan) -> LogicalPlan { + pub fn new(offset: usize, limit: usize, children: LogicalPlan) -> LogicalPlan { LogicalPlan { operator: Operator::Limit(LimitOperator { offset, limit }), childrens: vec![children], diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 37f158f8..6525393f 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -1,22 +1,22 @@ pub mod aggregate; pub mod create_table; +pub mod delete; +pub mod drop_table; pub mod filter; +pub mod insert; pub mod join; pub mod limit; pub mod project; pub mod scan; +pub mod show; pub mod sort; -pub mod insert; -pub mod values; -pub mod update; -pub mod delete; -pub mod drop_table; pub mod truncate; pub mod show; pub mod copy_from_file; pub mod copy_to_file; +pub mod update; +pub mod values; -use itertools::Itertools; use crate::catalog::ColumnRef; use crate::expression::ScalarExpression; use crate::planner::operator::copy_from_file::CopyFromFileOperator; @@ -30,6 +30,7 @@ use crate::planner::operator::show::ShowTablesOperator; use crate::planner::operator::truncate::TruncateOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::values::ValuesOperator; +use itertools::Itertools; use self::{ aggregate::AggregateOperator, filter::FilterOperator, join::JoinOperator, limit::LimitOperator, @@ -66,53 +67,48 @@ pub enum Operator { impl Operator { pub fn project_input_refs(&self) -> Vec { match self { - Operator::Project(op) => { - op.columns - .iter() - .map(ScalarExpression::unpack_alias) - .filter(|expr| matches!(expr, ScalarExpression::InputRef { .. })) - .sorted_by_key(|expr| match expr { - ScalarExpression::InputRef { index, .. } => index, - _ => unreachable!() - }) - .cloned() - .collect_vec() - } + Operator::Project(op) => op + .columns + .iter() + .map(ScalarExpression::unpack_alias) + .filter(|expr| matches!(expr, ScalarExpression::InputRef { .. })) + .sorted_by_key(|expr| match expr { + ScalarExpression::InputRef { index, .. } => index, + _ => unreachable!(), + }) + .cloned() + .collect_vec(), _ => vec![], } } pub fn agg_mapping_col_refs(&self, input_refs: &[ScalarExpression]) -> Vec { match self { - Operator::Aggregate(AggregateOperator { agg_calls, .. }) => { - input_refs.iter() - .filter_map(|expr| { - if let ScalarExpression::InputRef { index, .. } = expr { - Some(agg_calls[*index].clone()) - } else { - None - } - }) - .map(|expr| expr.referenced_columns()) - .flatten() - .collect_vec() - } + Operator::Aggregate(AggregateOperator { agg_calls, .. }) => input_refs + .iter() + .filter_map(|expr| { + if let ScalarExpression::InputRef { index, .. } = expr { + Some(agg_calls[*index].clone()) + } else { + None + } + }) + .map(|expr| expr.referenced_columns()) + .flatten() + .collect_vec(), _ => vec![], } } pub fn referenced_columns(&self) -> Vec { match self { - Operator::Aggregate(op) => { - op.agg_calls - .iter() - .chain(op.groupby_exprs.iter()) - .flat_map(|expr| expr.referenced_columns()) - .collect_vec() - } - Operator::Filter(op) => { - op.predicate.referenced_columns() - } + Operator::Aggregate(op) => op + .agg_calls + .iter() + .chain(op.groupby_exprs.iter()) + .flat_map(|expr| expr.referenced_columns()) + .collect_vec(), + Operator::Filter(op) => op.predicate.referenced_columns(), Operator::Join(op) => { let mut exprs = Vec::new(); @@ -129,27 +125,23 @@ impl Operator { exprs } - Operator::Project(op) => { - op.columns - .iter() - .flat_map(|expr| expr.referenced_columns()) - .collect_vec() - } - Operator::Scan(op) => { - op.columns.iter() - .flat_map(|expr| expr.referenced_columns()) - .collect_vec() - } - Operator::Sort(op) => { - op.sort_fields - .iter() - .map(|field| &field.expr) - .flat_map(|expr| expr.referenced_columns()) - .collect_vec() - } - Operator::Values(op) => { - op.columns.clone() - } + Operator::Project(op) => op + .columns + .iter() + .flat_map(|expr| expr.referenced_columns()) + .collect_vec(), + Operator::Scan(op) => op + .columns + .iter() + .flat_map(|expr| expr.referenced_columns()) + .collect_vec(), + Operator::Sort(op) => op + .sort_fields + .iter() + .map(|field| &field.expr) + .flat_map(|expr| expr.referenced_columns()) + .collect_vec(), + Operator::Values(op) => op.columns.clone(), _ => vec![], } } diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index d40d2d1f..a16ddbc7 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -1,10 +1,10 @@ -use itertools::Itertools; use crate::catalog::{TableCatalog, TableName}; -use crate::expression::ScalarExpression; use crate::expression::simplify::ConstantBinary; +use crate::expression::ScalarExpression; use crate::planner::LogicalPlan; use crate::storage::Bounds; use crate::types::index::IndexMetaRef; +use itertools::Itertools; use super::Operator; diff --git a/src/planner/operator/show.rs b/src/planner/operator/show.rs index 5b65bb9a..9d726726 100644 --- a/src/planner/operator/show.rs +++ b/src/planner/operator/show.rs @@ -1,2 +1,2 @@ #[derive(Debug, PartialEq, Clone)] -pub struct ShowTablesOperator {} \ No newline at end of file +pub struct ShowTablesOperator {} diff --git a/src/planner/operator/truncate.rs b/src/planner/operator/truncate.rs index 01a183ce..1b63a5d0 100644 --- a/src/planner/operator/truncate.rs +++ b/src/planner/operator/truncate.rs @@ -4,4 +4,4 @@ use crate::catalog::TableName; pub struct TruncateOperator { /// Table name to insert to pub table_name: TableName, -} \ No newline at end of file +} diff --git a/src/planner/operator/update.rs b/src/planner/operator/update.rs index 829a1fdd..ed37c72d 100644 --- a/src/planner/operator/update.rs +++ b/src/planner/operator/update.rs @@ -3,4 +3,4 @@ use crate::catalog::TableName; #[derive(Debug, PartialEq, Clone)] pub struct UpdateOperator { pub table_name: TableName, -} \ No newline at end of file +} diff --git a/src/planner/operator/values.rs b/src/planner/operator/values.rs index fbbc8ce9..1a88c753 100644 --- a/src/planner/operator/values.rs +++ b/src/planner/operator/values.rs @@ -4,5 +4,5 @@ use crate::types::value::ValueRef; #[derive(Debug, PartialEq, Clone)] pub struct ValuesOperator { pub rows: Vec>, - pub columns: Vec -} \ No newline at end of file + pub columns: Vec, +} diff --git a/src/storage/kip.rs b/src/storage/kip.rs index 50d6008d..e9babe86 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -1,29 +1,31 @@ -use std::collections::{Bound, VecDeque}; -use std::collections::hash_map::RandomState; -use std::mem; -use std::ops::SubAssign; -use std::path::PathBuf; -use std::sync::Arc; -use async_trait::async_trait; -use kip_db::kernel::lsm::mvcc::TransactionIter; -use kip_db::kernel::lsm::{mvcc, storage}; -use kip_db::kernel::lsm::iterator::Iter as KipDBIter; -use kip_db::kernel::lsm::storage::Config; -use kip_db::kernel::Storage as KipDBStorage; -use kip_db::kernel::utils::lru_cache::ShardingLruCache; use crate::catalog::{ColumnCatalog, TableCatalog, TableName}; use crate::expression::simplify::ConstantBinary; -use crate::storage::{Bounds, Projections, Storage, StorageError, Transaction, Iter, tuple_projection, IndexIter}; use crate::storage::table_codec::TableCodec; +use crate::storage::{ + tuple_projection, Bounds, IndexIter, Iter, Projections, Storage, StorageError, Transaction, +}; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexMeta, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; use crate::types::value::ValueRef; +use async_trait::async_trait; +use kip_db::kernel::lsm::iterator::Iter as KipDBIter; +use kip_db::kernel::lsm::mvcc::TransactionIter; +use kip_db::kernel::lsm::storage::Config; +use kip_db::kernel::lsm::{mvcc, storage}; +use kip_db::kernel::utils::lru_cache::ShardingLruCache; +use kip_db::kernel::Storage as KipDBStorage; +use std::collections::hash_map::RandomState; +use std::collections::{Bound, VecDeque}; +use std::mem; +use std::ops::SubAssign; +use std::path::PathBuf; +use std::sync::Arc; #[derive(Clone)] pub struct KipStorage { cache: Arc>, - pub inner: Arc + pub inner: Arc, } impl KipStorage { @@ -32,18 +34,18 @@ impl KipStorage { let storage = storage::KipStorage::open_with_config(config).await?; Ok(KipStorage { - cache: Arc::new(ShardingLruCache::new( - 32, - 16, - RandomState::default(), - )?), + cache: Arc::new(ShardingLruCache::new(32, 16, RandomState::default())?), inner: Arc::new(storage), }) } - fn column_collect(name: &String, tx: &mvcc::Transaction) -> Result<(Vec, Option), StorageError> { + fn column_collect( + name: &String, + tx: &mvcc::Transaction, + ) -> Result<(Vec, Option), StorageError> { let (column_min, column_max) = TableCodec::columns_bound(name); - let mut column_iter = tx.iter(Bound::Included(&column_min), Bound::Included(&column_max))?; + let mut column_iter = + tx.iter(Bound::Included(&column_min), Bound::Included(&column_max))?; let mut columns = vec![]; let mut name_option = None; @@ -67,7 +69,9 @@ impl KipStorage { fn index_meta_collect(name: &String, tx: &mvcc::Transaction) -> Option> { let (index_min, index_max) = TableCodec::index_meta_bound(name); let mut index_metas = vec![]; - let mut index_iter = tx.iter(Bound::Included(&index_min), Bound::Included(&index_max)).ok()?; + let mut index_iter = tx + .iter(Bound::Included(&index_min), Bound::Included(&index_max)) + .ok()?; while let Some((_, value_option)) = index_iter.try_next().ok().flatten() { if let Some(value) = value_option { @@ -81,7 +85,9 @@ impl KipStorage { } fn _drop_data(table: &mut KipTransaction, min: &[u8], max: &[u8]) -> Result<(), StorageError> { - let mut iter = table.tx.iter(Bound::Included(&min), Bound::Included(&max))?; + let mut iter = table + .tx + .iter(Bound::Included(&min), Bound::Included(&max))?; let mut data_keys = vec![]; while let Some((key, value_option)) = iter.try_next()? { @@ -100,11 +106,12 @@ impl KipStorage { fn create_index_meta_for_table( tx: &mut mvcc::Transaction, - table: &mut TableCatalog + table: &mut TableCatalog, ) -> Result<(), StorageError> { let table_name = table.name.clone(); - for col in table.all_columns() + for col in table + .all_columns() .into_iter() .filter(|col| col.desc.is_unique) { @@ -129,7 +136,11 @@ impl KipStorage { impl Storage for KipStorage { type TransactionType = KipTransaction; - async fn create_table(&self, table_name: TableName, columns: Vec) -> Result { + async fn create_table( + &self, + table_name: TableName, + columns: Vec, + ) -> Result { let mut tx = self.inner.new_transaction().await; let mut table_catalog = TableCatalog::new(table_name.clone(), columns)?; @@ -140,7 +151,7 @@ impl Storage for KipStorage { tx.set(key, value); } - let (k, v)= TableCodec::encode_root_table(&table_name)?; + let (k, v) = TableCodec::encode_root_table(&table_name)?; self.inner.set(k, v).await?; tx.commit().await?; @@ -157,7 +168,7 @@ impl Storage for KipStorage { let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max))?; let mut col_keys = vec![]; - while let Some((key, value_option)) = iter.try_next()? { + while let Some((key, value_option)) = iter.try_next()? { if value_option.is_some() { col_keys.push(key); } @@ -177,7 +188,6 @@ impl Storage for KipStorage { async fn drop_data(&self, name: &String) -> Result<(), StorageError> { if let Some(mut transaction) = self.transaction(name).await { - let (tuple_min, tuple_max) = transaction.table_codec.tuple_bound(); Self::_drop_data(&mut transaction, &tuple_min, &tuple_max)?; @@ -191,12 +201,12 @@ impl Storage for KipStorage { } async fn transaction(&self, name: &String) -> Option { - let table_codec = self.table(name) - .await - .map(|catalog| TableCodec { table: catalog.clone() })?; + let table_codec = self.table(name).await.map(|catalog| TableCodec { + table: catalog.clone(), + })?; let tx = self.inner.new_transaction().await; - Some(KipTransaction { table_codec, tx, }) + Some(KipTransaction { table_codec, tx }) } async fn table(&self, name: &String) -> Option<&TableCatalog> { @@ -208,10 +218,13 @@ impl Storage for KipStorage { let (columns, name_option) = Self::column_collect(name, &tx).ok()?; let indexes = Self::index_meta_collect(name, &tx)?; - if let Some(catalog) = name_option - .and_then(|table_name| TableCatalog::new_with_indexes(table_name, columns, indexes).ok()) - { - option = self.cache.get_or_insert(name.to_string(), |_| Ok(catalog)).ok(); + if let Some(catalog) = name_option.and_then(|table_name| { + TableCatalog::new_with_indexes(table_name, columns, indexes).ok() + }) { + option = self + .cache + .get_or_insert(name.to_string(), |_| Ok(catalog)) + .ok(); } } @@ -225,7 +238,7 @@ impl Storage for KipStorage { let tx = self.inner.new_transaction().await; let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max))?; - while let Some((_, value_option)) = iter.try_next().ok().flatten() { + while let Some((_, value_option)) = iter.try_next().ok().flatten() { if let Some(value) = value_option { let table_name = TableCodec::decode_root_table(&value)?; @@ -239,14 +252,18 @@ impl Storage for KipStorage { pub struct KipTransaction { table_codec: TableCodec, - tx: mvcc::Transaction + tx: mvcc::Transaction, } #[async_trait] impl Transaction for KipTransaction { type IterType<'a> = KipIter<'a>; - fn read(&self, bounds: Bounds, projections: Projections) -> Result, StorageError> { + fn read( + &self, + bounds: Bounds, + projections: Projections, + ) -> Result, StorageError> { let (min, max) = self.table_codec.tuple_bound(); let iter = self.tx.iter(Bound::Included(&min), Bound::Included(&max))?; @@ -264,7 +281,7 @@ impl Transaction for KipTransaction { (offset_option, mut limit_option): Bounds, projections: Projections, index_meta: IndexMetaRef, - binaries: Vec + binaries: Vec, ) -> Result, StorageError> { let mut tuple_ids = Vec::new(); let mut offset = offset_option.unwrap_or(0); @@ -281,11 +298,15 @@ impl Transaction for KipTransaction { while let Some((_, value_option)) = iter.try_next()? { if let Some(value) = value_option { for id in TableCodec::decode_index(&value)? { - if Self::offset_move(&mut offset) { continue; } + if Self::offset_move(&mut offset) { + continue; + } tuple_ids.push(id); - if Self::limit_move(&mut limit_option) { break; } + if Self::limit_move(&mut limit_option) { + break; + } } } @@ -295,7 +316,9 @@ impl Transaction for KipTransaction { } } ConstantBinary::Eq(val) => { - if Self::offset_move(&mut offset) { continue; } + if Self::offset_move(&mut offset) { + continue; + } let key = self.val_to_key(&index_meta, val)?; @@ -305,7 +328,7 @@ impl Transaction for KipTransaction { let _ = Self::limit_move(&mut limit_option); } - _ => () + _ => (), } } @@ -317,7 +340,12 @@ impl Transaction for KipTransaction { }) } - fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError> { + fn add_index( + &mut self, + index: Index, + tuple_ids: Vec, + is_unique: bool, + ) -> Result<(), StorageError> { let (key, value) = self.table_codec.encode_index(&index, &tuple_ids)?; if let Some(bytes) = self.tx.get(&key)? { @@ -383,17 +411,13 @@ impl KipTransaction { &self, index_meta: &IndexMetaRef, min: Bound, - max: Bound + max: Bound, ) -> Result { let bound_encode = |bound: Bound| -> Result<_, StorageError> { match bound { - Bound::Included(val) => { - Ok(Bound::Included(self.val_to_key(&index_meta, val)?)) - }, - Bound::Excluded(val) => { - Ok(Bound::Excluded(self.val_to_key(&index_meta, val)?)) - } - Bound::Unbounded => Ok(Bound::Unbounded) + Bound::Included(val) => Ok(Bound::Included(self.val_to_key(&index_meta, val)?)), + Bound::Excluded(val) => Ok(Bound::Excluded(self.val_to_key(&index_meta, val)?)), + Bound::Unbounded => Ok(Bound::Unbounded), } }; let check_bound = |value: &mut Bound>, bound: Vec| { @@ -416,7 +440,7 @@ impl KipTransaction { } fn offset_move(offset: &mut usize) -> bool { - if *offset > 0 { + if *offset > 0 { offset.sub_assign(1); true @@ -441,7 +465,7 @@ pub struct KipIter<'a> { limit: Option, projections: Projections, table_codec: &'a TableCodec, - iter: TransactionIter<'a> + iter: TransactionIter<'a>, } impl Iter for KipIter<'_> { @@ -462,10 +486,10 @@ impl Iter for KipIter<'_> { let tuple = tuple_projection( &mut self.limit, &self.projections, - self.table_codec.decode_tuple(&value) + self.table_codec.decode_tuple(&value), )?; - return Ok(Some(tuple)) + return Ok(Some(tuple)); } } @@ -475,20 +499,20 @@ impl Iter for KipIter<'_> { #[cfg(test)] mod test { - use std::collections::{Bound, VecDeque}; - use std::sync::Arc; - use itertools::Itertools; - use tempfile::TempDir; use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::{Database, DatabaseError}; - use crate::expression::ScalarExpression; use crate::expression::simplify::ConstantBinary; + use crate::expression::ScalarExpression; use crate::storage::kip::KipStorage; - use crate::storage::{Storage, StorageError, Iter, Transaction, IndexIter}; use crate::storage::memory::test::data_filling; use crate::storage::table_codec::TableCodec; - use crate::types::LogicalType; + use crate::storage::{IndexIter, Iter, Storage, StorageError, Transaction}; use crate::types::value::DataValue; + use crate::types::LogicalType; + use itertools::Itertools; + use std::collections::{Bound, VecDeque}; + use std::sync::Arc; + use tempfile::TempDir; #[tokio::test] async fn test_in_kipdb_storage_works_with_data() -> Result<(), StorageError> { @@ -499,35 +523,47 @@ mod test { "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, true, false), - None + None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), - None + None, )), ]; - let source_columns = columns.iter() + let source_columns = columns + .iter() .map(|col_ref| ColumnCatalog::clone(&col_ref)) .collect_vec(); - let table_id = storage.create_table(Arc::new("test".to_string()), source_columns).await?; + let table_id = storage + .create_table(Arc::new("test".to_string()), source_columns) + .await?; let table_catalog = storage.table(&"test".to_string()).await; assert!(table_catalog.is_some()); - assert!(table_catalog.unwrap().get_column_id_by_name(&"c1".to_string()).is_some()); + assert!(table_catalog + .unwrap() + .get_column_id_by_name(&"c1".to_string()) + .is_some()); let mut transaction = storage.transaction(&table_id).await.unwrap(); data_filling(columns, &mut transaction)?; let mut iter = transaction.read( (Some(1), Some(1)), - vec![ScalarExpression::InputRef { index: 0, ty: LogicalType::Integer }] + vec![ScalarExpression::InputRef { + index: 0, + ty: LogicalType::Integer, + }], )?; let option_1 = iter.next_tuple()?; - assert_eq!(option_1.unwrap().id, Some(Arc::new(DataValue::Int32(Some(2))))); + assert_eq!( + option_1.unwrap().id, + Some(Arc::new(DataValue::Int32(Some(2)))) + ); let option_2 = iter.next_tuple()?; assert_eq!(option_2, None); @@ -541,16 +577,22 @@ mod test { let kipsql = Database::with_kipdb(temp_dir.path()).await?; let _ = kipsql.run("create table t1 (a int primary key)").await?; - let _ = kipsql.run("insert into t1 (a) values (0), (1), (2)").await?; + let _ = kipsql + .run("insert into t1 (a) values (0), (1), (2)") + .await?; - let table = kipsql.storage.table(&"t1".to_string()).await.unwrap().clone(); - let projections = table.all_columns() + let table = kipsql + .storage + .table(&"t1".to_string()) + .await + .unwrap() + .clone(); + let projections = table + .all_columns() .into_iter() .map(|col| ScalarExpression::ColumnRef(col)) .collect_vec(); - let codec = TableCodec { - table, - }; + let codec = TableCodec { table }; let tx = kipsql.storage.transaction(&"t1".to_string()).await.unwrap(); let tuple_ids = vec![ Arc::new(DataValue::Int32(Some(0))), @@ -579,35 +621,48 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let kipsql = Database::with_kipdb(temp_dir.path()).await?; - let _ = kipsql.run("create table t1 (a int primary key, b int unique)").await?; - let _ = kipsql.run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2)").await?; + let _ = kipsql + .run("create table t1 (a int primary key, b int unique)") + .await?; + let _ = kipsql + .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2)") + .await?; - let table = kipsql.storage.table(&"t1".to_string()).await.unwrap().clone(); - let projections = table.all_columns() + let table = kipsql + .storage + .table(&"t1".to_string()) + .await + .unwrap() + .clone(); + let projections = table + .all_columns() .into_iter() .map(|col| ScalarExpression::ColumnRef(col)) .collect_vec(); let transaction = kipsql.storage.transaction(&"t1".to_string()).await.unwrap(); - let mut iter = transaction.read_by_index( - (Some(0), Some(1)), - projections, - table.indexes[0].clone(), - vec![ - ConstantBinary::Scope { + let mut iter = transaction + .read_by_index( + (Some(0), Some(1)), + projections, + table.indexes[0].clone(), + vec![ConstantBinary::Scope { min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), - max: Bound::Unbounded - } - ] - ).unwrap(); + max: Bound::Unbounded, + }], + ) + .unwrap(); while let Some(tuple) = iter.next_tuple()? { assert_eq!(tuple.id, Some(Arc::new(DataValue::Int32(Some(1))))); - assert_eq!(tuple.values, vec![ - Arc::new(DataValue::Int32(Some(1))), - Arc::new(DataValue::Int32(Some(1))) - ]) + assert_eq!( + tuple.values, + vec![ + Arc::new(DataValue::Int32(Some(1))), + Arc::new(DataValue::Int32(Some(1))) + ] + ) } Ok(()) } -} \ No newline at end of file +} diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 76af970a..b1053523 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -1,39 +1,33 @@ -use std::cell::Cell; -use std::fmt::{Debug, Formatter}; -use std::slice; -use std::sync::Arc; -use async_trait::async_trait; use crate::catalog::{ColumnCatalog, RootCatalog, TableCatalog, TableName}; use crate::expression::simplify::ConstantBinary; -use crate::storage::{Bounds, Projections, Storage, StorageError, Transaction, Iter, tuple_projection, IndexIter}; +use crate::storage::{ + tuple_projection, Bounds, IndexIter, Iter, Projections, Storage, StorageError, Transaction, +}; use crate::types::index::{Index, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; +use async_trait::async_trait; +use std::cell::Cell; +use std::fmt::{Debug, Formatter}; +use std::slice; +use std::sync::Arc; // WARRING: Only single-threaded and tested using #[derive(Clone)] pub struct MemStorage { - inner: Arc> -} - -unsafe impl Send for MemStorage { - + inner: Arc>, } -unsafe impl Sync for MemStorage { +unsafe impl Send for MemStorage {} -} +unsafe impl Sync for MemStorage {} impl MemStorage { pub fn new() -> MemStorage { Self { - inner: Arc::new( - Cell::new( - StorageInner { - root: Default::default(), - tables: Default::default(), - } - ) - ), + inner: Arc::new(Cell::new(StorageInner { + root: Default::default(), + tables: Default::default(), + })), } } @@ -48,14 +42,18 @@ impl MemStorage { #[derive(Debug)] struct StorageInner { root: RootCatalog, - tables: Vec<(TableName, MemTable)> + tables: Vec<(TableName, MemTable)>, } #[async_trait] impl Storage for MemStorage { type TransactionType = MemTable; - async fn create_table(&self, table_name: TableName, columns: Vec) -> Result { + async fn create_table( + &self, + table_name: TableName, + columns: Vec, + ) -> Result { let new_table = MemTable { tuples: Arc::new(Cell::new(vec![])), }; @@ -68,12 +66,7 @@ impl Storage for MemStorage { } async fn drop_table(&self, name: &String) -> Result<(), StorageError> { - let inner = unsafe { - self.inner - .as_ptr() - .as_mut() - .unwrap() - }; + let inner = unsafe { self.inner.as_ptr().as_mut().unwrap() }; inner.root.drop_table(&name)?; @@ -81,12 +74,7 @@ impl Storage for MemStorage { } async fn drop_data(&self, name: &String) -> Result<(), StorageError> { - let inner = unsafe { - self.inner - .as_ptr() - .as_mut() - .unwrap() - }; + let inner = unsafe { self.inner.as_ptr().as_mut().unwrap() }; inner.tables.retain(|(t_name, _)| t_name.as_str() != name); @@ -107,14 +95,7 @@ impl Storage for MemStorage { } async fn table(&self, name: &String) -> Option<&TableCatalog> { - unsafe { - self.inner - .as_ptr() - .as_ref() - .unwrap() - .root - .get_table(name) - } + unsafe { self.inner.as_ptr().as_ref().unwrap().root.get_table(name) } } async fn show_tables(&self) -> Result, StorageError> { @@ -122,17 +103,13 @@ impl Storage for MemStorage { } } -unsafe impl Send for MemTable { +unsafe impl Send for MemTable {} -} - -unsafe impl Sync for MemTable { - -} +unsafe impl Sync for MemTable {} #[derive(Clone)] pub struct MemTable { - tuples: Arc>> + tuples: Arc>>, } impl Debug for MemTable { @@ -149,26 +126,39 @@ impl Debug for MemTable { impl Transaction for MemTable { type IterType<'a> = MemTraction<'a>; - fn read(&self, bounds: Bounds, projection: Projections) -> Result, StorageError> { + fn read( + &self, + bounds: Bounds, + projection: Projections, + ) -> Result, StorageError> { unsafe { - Ok( - MemTraction { - offset: bounds.0.unwrap_or(0), - limit: bounds.1, - projections: projection, - iter: self.tuples.as_ptr().as_ref().unwrap().iter(), - } - ) + Ok(MemTraction { + offset: bounds.0.unwrap_or(0), + limit: bounds.1, + projections: projection, + iter: self.tuples.as_ptr().as_ref().unwrap().iter(), + }) } } #[allow(unused_variables)] - fn read_by_index(&self, bounds: Bounds, projection: Projections, index_meta: IndexMetaRef, binaries: Vec) -> Result, StorageError> { + fn read_by_index( + &self, + bounds: Bounds, + projection: Projections, + index_meta: IndexMetaRef, + binaries: Vec, + ) -> Result, StorageError> { todo!() } #[allow(unused_variables)] - fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError> { + fn add_index( + &mut self, + index: Index, + tuple_ids: Vec, + is_unique: bool, + ) -> Result<(), StorageError> { todo!() } @@ -177,11 +167,7 @@ impl Transaction for MemTable { } fn append(&mut self, tuple: Tuple, is_overwrite: bool) -> Result<(), StorageError> { - let tuples = unsafe { - self.tuples - .as_ptr() - .as_mut() - }.unwrap(); + let tuples = unsafe { self.tuples.as_ptr().as_mut() }.unwrap(); if let Some(original_tuple) = tuples.iter_mut().find(|t| t.id == tuple.id) { if !is_overwrite { @@ -196,11 +182,7 @@ impl Transaction for MemTable { } fn delete(&mut self, tuple_id: TupleId) -> Result<(), StorageError> { - let tuples = unsafe { - self.tuples - .as_ptr() - .as_mut() - }.unwrap(); + let tuples = unsafe { self.tuples.as_ptr().as_mut() }.unwrap(); tuples.retain(|tuple| tuple.id.clone().unwrap() != tuple_id); @@ -216,7 +198,7 @@ pub struct MemTraction<'a> { offset: usize, limit: Option, projections: Projections, - iter: slice::Iter<'a, Tuple> + iter: slice::Iter<'a, Tuple>, } impl Iter for MemTraction<'_> { @@ -242,33 +224,42 @@ impl Iter for MemTraction<'_> { #[cfg(test)] pub(crate) mod test { - use std::sync::Arc; - use itertools::Itertools; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::expression::ScalarExpression; use crate::storage::memory::MemStorage; - use crate::storage::{Storage, StorageError, Transaction, Iter}; - use crate::types::LogicalType; + use crate::storage::{Iter, Storage, StorageError, Transaction}; use crate::types::tuple::Tuple; use crate::types::value::DataValue; + use crate::types::LogicalType; + use itertools::Itertools; + use std::sync::Arc; - pub fn data_filling(columns: Vec, table: &mut impl Transaction) -> Result<(), StorageError> { - table.append(Tuple { - id: Some(Arc::new(DataValue::Int32(Some(1)))), - columns: columns.clone(), - values: vec![ - Arc::new(DataValue::Int32(Some(1))), - Arc::new(DataValue::Boolean(Some(true))) - ], - }, false)?; - table.append(Tuple { - id: Some(Arc::new(DataValue::Int32(Some(2)))), - columns: columns.clone(), - values: vec![ - Arc::new(DataValue::Int32(Some(2))), - Arc::new(DataValue::Boolean(Some(false))) - ], - }, false)?; + pub fn data_filling( + columns: Vec, + table: &mut impl Transaction, + ) -> Result<(), StorageError> { + table.append( + Tuple { + id: Some(Arc::new(DataValue::Int32(Some(1)))), + columns: columns.clone(), + values: vec![ + Arc::new(DataValue::Int32(Some(1))), + Arc::new(DataValue::Boolean(Some(true))), + ], + }, + false, + )?; + table.append( + Tuple { + id: Some(Arc::new(DataValue::Int32(Some(2)))), + columns: columns.clone(), + values: vec![ + Arc::new(DataValue::Int32(Some(2))), + Arc::new(DataValue::Boolean(Some(false))), + ], + }, + false, + )?; Ok(()) } @@ -281,40 +272,52 @@ pub(crate) mod test { "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, true, false), - None + None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), - None + None, )), ]; - let source_columns = columns.iter() + let source_columns = columns + .iter() .map(|col_ref| ColumnCatalog::clone(&col_ref)) .collect_vec(); - let table_id = storage.create_table(Arc::new("test".to_string()), source_columns).await?; + let table_id = storage + .create_table(Arc::new("test".to_string()), source_columns) + .await?; let table_catalog = storage.table(&"test".to_string()).await; assert!(table_catalog.is_some()); - assert!(table_catalog.unwrap().get_column_id_by_name(&"c1".to_string()).is_some()); + assert!(table_catalog + .unwrap() + .get_column_id_by_name(&"c1".to_string()) + .is_some()); let mut transaction = storage.transaction(&table_id).await.unwrap(); data_filling(columns, &mut transaction)?; let mut iter = transaction.read( (Some(1), Some(1)), - vec![ScalarExpression::InputRef { index: 0, ty: LogicalType::Integer }] + vec![ScalarExpression::InputRef { + index: 0, + ty: LogicalType::Integer, + }], )?; let option_1 = iter.next_tuple()?; - assert_eq!(option_1.unwrap().id, Some(Arc::new(DataValue::Int32(Some(2))))); + assert_eq!( + option_1.unwrap().id, + Some(Arc::new(DataValue::Int32(Some(2)))) + ); let option_2 = iter.next_tuple()?; assert_eq!(option_2, None); Ok(()) } -} \ No newline at end of file +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 553e9662..a87eef1e 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,20 +1,20 @@ +pub mod kip; pub mod memory; mod table_codec; -pub mod kip; -use std::collections::VecDeque; -use std::ops::SubAssign; -use async_trait::async_trait; -use kip_db::error::CacheError; -use kip_db::kernel::lsm::mvcc; -use kip_db::KernelError; use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog, TableName}; -use crate::expression::ScalarExpression; use crate::expression::simplify::ConstantBinary; +use crate::expression::ScalarExpression; use crate::storage::table_codec::TableCodec; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; +use async_trait::async_trait; +use kip_db::error::CacheError; +use kip_db::kernel::lsm::mvcc; +use kip_db::KernelError; +use std::collections::VecDeque; +use std::ops::SubAssign; #[async_trait] pub trait Storage: Sync + Send + Clone + 'static { @@ -23,7 +23,7 @@ pub trait Storage: Sync + Send + Clone + 'static { async fn create_table( &self, table_name: TableName, - columns: Vec + columns: Vec, ) -> Result; async fn drop_table(&self, name: &String) -> Result<(), StorageError>; @@ -57,10 +57,15 @@ pub trait Transaction: Sync + Send + 'static { bounds: Bounds, projection: Projections, index_meta: IndexMetaRef, - binaries: Vec + binaries: Vec, ) -> Result, StorageError>; - fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError>; + fn add_index( + &mut self, + index: Index, + tuple_ids: Vec, + is_unique: bool, + ) -> Result<(), StorageError>; fn del_index(&mut self, index: &Index) -> Result<(), StorageError>; @@ -76,7 +81,7 @@ pub struct IndexIter<'a> { projections: Projections, table_codec: &'a TableCodec, tuple_ids: VecDeque, - tx: &'a mvcc::Transaction + tx: &'a mvcc::Transaction, } impl Iter for IndexIter<'_> { @@ -84,12 +89,16 @@ impl Iter for IndexIter<'_> { if let Some(tuple_id) = self.tuple_ids.pop_front() { let key = self.table_codec.encode_tuple_key(&tuple_id)?; - Ok(self.tx.get(&key)? - .map(|bytes| tuple_projection( - &mut None, - &self.projections, - self.table_codec.decode_tuple(&bytes) - )) + Ok(self + .tx + .get(&key)? + .map(|bytes| { + tuple_projection( + &mut None, + &self.projections, + self.table_codec.decode_tuple(&bytes), + ) + }) .transpose()?) } else { Ok(None) @@ -104,7 +113,7 @@ pub trait Iter: Sync + Send { pub(crate) fn tuple_projection( limit: &mut Option, projections: &Projections, - tuple: Tuple + tuple: Tuple, ) -> Result { let projection_len = projections.len(); let mut columns = Vec::with_capacity(projection_len); @@ -157,4 +166,4 @@ impl From for StorageError { fn from(value: CacheError) -> Self { StorageError::CacheError(value) } -} \ No newline at end of file +} diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index a038cb8f..a4b5098b 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -1,21 +1,19 @@ -use bytes::Bytes; -use lazy_static::lazy_static; use crate::catalog::{ColumnCatalog, TableCatalog, TableName}; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexId, IndexMeta}; use crate::types::tuple::{Tuple, TupleId}; +use bytes::Bytes; +use lazy_static::lazy_static; const BOUND_MIN_TAG: u8 = 0; const BOUND_MAX_TAG: u8 = 1; lazy_static! { - static ref ROOT_BYTES: Vec = { - b"Root".to_vec() - }; + static ref ROOT_BYTES: Vec = { b"Root".to_vec() }; } #[derive(Clone)] pub struct TableCodec { - pub table: TableCatalog + pub table: TableCatalog, } #[derive(Copy, Clone)] @@ -32,9 +30,7 @@ impl TableCodec { /// /// Tips: Root full key = key_prefix fn key_prefix(ty: CodecType, table_name: &String) -> Vec { - let mut table_bytes = table_name - .clone() - .into_bytes(); + let mut table_bytes = table_name.clone().into_bytes(); match ty { CodecType::Column => { @@ -132,10 +128,7 @@ impl TableCodec { /// Key: TableName_Tuple_0_RowID(Sorted) /// Value: Tuple pub fn encode_tuple(&self, tuple: &Tuple) -> Result<(Bytes, Bytes), TypeError> { - let tuple_id = tuple - .id - .clone() - .ok_or(TypeError::NotNull)?; + let tuple_id = tuple.id.clone().ok_or(TypeError::NotNull)?; let key = self.encode_tuple_key(&tuple_id)?; Ok((Bytes::from(key), Bytes::from(tuple.serialize_to()))) @@ -156,12 +149,18 @@ impl TableCodec { /// Key: TableName_IndexMeta_0_IndexID /// Value: IndexMeta - pub fn encode_index_meta(name: &String, index_meta: &IndexMeta) -> Result<(Bytes, Bytes), TypeError> { + pub fn encode_index_meta( + name: &String, + index_meta: &IndexMeta, + ) -> Result<(Bytes, Bytes), TypeError> { let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, &name); key_prefix.push(BOUND_MIN_TAG); key_prefix.append(&mut index_meta.id.to_be_bytes().to_vec()); - Ok((Bytes::from(key_prefix), Bytes::from(bincode::serialize(&index_meta)?))) + Ok(( + Bytes::from(key_prefix), + Bytes::from(bincode::serialize(&index_meta)?), + )) } pub fn decode_index_meta(bytes: &[u8]) -> Result { @@ -178,10 +177,17 @@ impl TableCodec { /// /// Tips: The unique index has only one ColumnID and one corresponding DataValue, /// so it can be positioned directly. - pub fn encode_index(&self, index: &Index, tuple_ids: &[TupleId]) -> Result<(Bytes, Bytes), TypeError> { + pub fn encode_index( + &self, + index: &Index, + tuple_ids: &[TupleId], + ) -> Result<(Bytes, Bytes), TypeError> { let key = self.encode_index_key(index)?; - Ok((Bytes::from(key), Bytes::from(bincode::serialize(tuple_ids)?))) + Ok(( + Bytes::from(key), + Bytes::from(bincode::serialize(tuple_ids)?), + )) } pub fn encode_index_key(&self, index: &Index) -> Result, TypeError> { @@ -221,13 +227,15 @@ impl TableCodec { Ok((column.table_name.clone().unwrap(), column)) } - /// Key: RootCatalog_0_TableName /// Value: TableName pub fn encode_root_table(table_name: &String) -> Result<(Bytes, Bytes), TypeError> { let key = Self::encode_root_table_key(table_name); - Ok((Bytes::from(key), Bytes::from(table_name.clone().into_bytes()))) + Ok(( + Bytes::from(key), + Bytes::from(table_name.clone().into_bytes()), + )) } pub fn encode_root_table_key(table_name: &String) -> Vec { @@ -241,19 +249,19 @@ impl TableCodec { #[cfg(test)] mod tests { - use std::collections::BTreeSet; - use std::ops::Bound; - use std::sync::Arc; - use bytes::Bytes; - use itertools::Itertools; - use rust_decimal::Decimal; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; use crate::storage::table_codec::TableCodec; use crate::types::errors::TypeError; use crate::types::index::{Index, IndexMeta}; - use crate::types::LogicalType; use crate::types::tuple::Tuple; use crate::types::value::DataValue; + use crate::types::LogicalType; + use bytes::Bytes; + use itertools::Itertools; + use rust_decimal::Decimal; + use std::collections::BTreeSet; + use std::ops::Bound; + use std::sync::Arc; fn build_table_codec() -> (TableCatalog, TableCodec) { let columns = vec![ @@ -261,17 +269,19 @@ mod tests { "c1".into(), false, ColumnDesc::new(LogicalType::Integer, true, false), - None + None, ), ColumnCatalog::new( "c2".into(), false, - ColumnDesc::new(LogicalType::Decimal(None,None), false, false), - None + ColumnDesc::new(LogicalType::Decimal(None, None), false, false), + None, ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap(); - let codec = TableCodec { table: table_catalog.clone() }; + let codec = TableCodec { + table: table_catalog.clone(), + }; (table_catalog, codec) } @@ -285,7 +295,7 @@ mod tests { values: vec![ Arc::new(DataValue::Int32(Some(0))), Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), - ] + ], }; let (_, bytes) = codec.encode_tuple(&tuple)?; @@ -359,7 +369,7 @@ mod tests { is_primary: false, is_unique: false, }, - None + None, ); col.table_name = Some(Arc::new(table_name.to_string())); @@ -381,14 +391,12 @@ mod tests { set.insert(op(0, "T2")); set.insert(op(0, "T2")); - let (min, max) = TableCodec::columns_bound( - &Arc::new("T1".to_string()) - ); + let (min, max) = TableCodec::columns_bound(&Arc::new("T1".to_string())); let vec = set .range::, Bound<&Bytes>)>(( Bound::Included(&Bytes::from(min)), - Bound::Included(&Bytes::from(max)) + Bound::Included(&Bytes::from(max)), )) .collect_vec(); @@ -410,7 +418,8 @@ mod tests { is_unique: false, }; - let (key, _) = TableCodec::encode_index_meta(&table_name.to_string(), &index_meta).unwrap(); + let (key, _) = + TableCodec::encode_index_meta(&table_name.to_string(), &index_meta).unwrap(); key }; @@ -431,7 +440,7 @@ mod tests { let vec = set .range::, Bound<&Bytes>)>(( Bound::Included(&Bytes::from(min)), - Bound::Included(&Bytes::from(max)) + Bound::Included(&Bytes::from(max)), )) .collect_vec(); @@ -478,7 +487,10 @@ mod tests { println!("{:?}", max); let vec = set - .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) + .range::, (Bound<&Vec>, Bound<&Vec>)>(( + Bound::Included(&min), + Bound::Included(&max), + )) .collect_vec(); assert_eq!(vec.len(), 3); @@ -498,8 +510,10 @@ mod tests { }; TableCodec { - table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap() - }.encode_index_key(&index).unwrap() + table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap(), + } + .encode_index_key(&index) + .unwrap() }; set.insert(op(DataValue::Int32(Some(0)), 0, "T0")); @@ -520,7 +534,10 @@ mod tests { let (min, max) = table_codec.all_index_bound(); let vec = set - .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) + .range::, (Bound<&Vec>, Bound<&Vec>)>(( + Bound::Included(&min), + Bound::Included(&max), + )) .collect_vec(); assert_eq!(vec.len(), 3); @@ -535,8 +552,10 @@ mod tests { let mut set = BTreeSet::new(); let op = |tuple_id: DataValue, table_name: &str| { TableCodec { - table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap() - }.encode_tuple_key(&Arc::new(tuple_id)).unwrap() + table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap(), + } + .encode_tuple_key(&Arc::new(tuple_id)) + .unwrap() }; set.insert(op(DataValue::Int32(Some(0)), "T0")); @@ -557,7 +576,10 @@ mod tests { let (min, max) = table_codec.tuple_bound(); let vec = set - .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) + .range::, (Bound<&Vec>, Bound<&Vec>)>(( + Bound::Included(&min), + Bound::Included(&max), + )) .collect_vec(); assert_eq!(vec.len(), 3); @@ -568,11 +590,9 @@ mod tests { } #[test] - fn test_root_codec_name_bound(){ + fn test_root_codec_name_bound() { let mut set = BTreeSet::new(); - let op = |table_name: &str| { - TableCodec::encode_root_table_key(&table_name.to_string()) - }; + let op = |table_name: &str| TableCodec::encode_root_table_key(&table_name.to_string()); set.insert(b"A".to_vec()); @@ -585,12 +605,14 @@ mod tests { let (min, max) = TableCodec::root_table_bound(); let vec = set - .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) + .range::, (Bound<&Vec>, Bound<&Vec>)>(( + Bound::Included(&min), + Bound::Included(&max), + )) .collect_vec(); assert_eq!(vec[0], &op("T0")); assert_eq!(vec[1], &op("T1")); assert_eq!(vec[2], &op("T2")); - } -} \ No newline at end of file +} diff --git a/src/types/errors.rs b/src/types/errors.rs index 9ec307a1..b897138d 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -1,7 +1,7 @@ +use chrono::ParseError; use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; use std::str::ParseBoolError; use std::string::FromUtf8Error; -use chrono::ParseError; #[derive(thiserror::Error, Debug)] pub enum TypeError { @@ -51,7 +51,7 @@ pub enum TypeError { Bincode( #[source] #[from] - Box + Box, ), #[error("try from decimal")] TryFromDecimal( @@ -64,5 +64,5 @@ pub enum TypeError { #[source] #[from] FromUtf8Error, - ) + ), } diff --git a/src/types/index.rs b/src/types/index.rs index 59c89a05..bd0f7e3d 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -1,7 +1,7 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; -use crate::types::ColumnId; use crate::types::value::ValueRef; +use crate::types::ColumnId; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; pub type IndexId = u32; pub type IndexMetaRef = Arc; @@ -11,7 +11,7 @@ pub struct IndexMeta { pub id: IndexId, pub column_ids: Vec, pub name: String, - pub is_unique:bool + pub is_unique: bool, } pub struct Index { @@ -21,9 +21,6 @@ pub struct Index { impl Index { pub fn new(id: IndexId, column_values: Vec) -> Self { - Index { - id, - column_values, - } + Index { id, column_values } } -} \ No newline at end of file +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b3f82ad4..f4f3218f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,10 +1,13 @@ pub mod errors; -pub mod value; -pub mod tuple; pub mod index; pub mod tuple_builder; +pub mod tuple; +pub mod value; +use chrono::{NaiveDate, NaiveDateTime}; +use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; +use std::any::TypeId; use sqlparser::ast::ExactNumberInfo; use strum_macros::AsRefStr; @@ -15,7 +18,9 @@ pub type ColumnId = u32; /// Sqlrs type conversion: /// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, AsRefStr, Serialize, Deserialize)] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, AsRefStr, Serialize, Deserialize, +)] pub enum LogicalType { Invalid, SqlNull, @@ -38,6 +43,42 @@ pub enum LogicalType { } impl LogicalType { + pub fn type_trans() -> Option { + let type_id = TypeId::of::(); + + if type_id == TypeId::of::() { + Some(LogicalType::Tinyint) + } else if type_id == TypeId::of::() { + Some(LogicalType::Smallint) + } else if type_id == TypeId::of::() { + Some(LogicalType::Integer) + } else if type_id == TypeId::of::() { + Some(LogicalType::Bigint) + } else if type_id == TypeId::of::() { + Some(LogicalType::UTinyint) + } else if type_id == TypeId::of::() { + Some(LogicalType::USmallint) + } else if type_id == TypeId::of::() { + Some(LogicalType::UInteger) + } else if type_id == TypeId::of::() { + Some(LogicalType::UBigint) + } else if type_id == TypeId::of::() { + Some(LogicalType::Float) + } else if type_id == TypeId::of::() { + Some(LogicalType::Double) + } else if type_id == TypeId::of::() { + Some(LogicalType::Date) + } else if type_id == TypeId::of::() { + Some(LogicalType::DateTime) + } else if type_id == TypeId::of::() { + Some(LogicalType::Decimal(None, None)) + } else if type_id == TypeId::of::() { + Some(LogicalType::Varchar(None)) + } else { + None + } + } + pub fn raw_len(&self) -> Option { match self { LogicalType::Invalid => Some(0), @@ -113,11 +154,7 @@ impl LogicalType { } pub fn is_floating_point_numeric(&self) -> bool { - matches!( - self, - LogicalType::Float - | LogicalType::Double - ) + matches!(self, LogicalType::Float | LogicalType::Double) } pub fn max_logical_type( @@ -136,13 +173,24 @@ impl LogicalType { if left.is_numeric() && right.is_numeric() { return LogicalType::combine_numeric_types(left, right); } - if matches!((left, right), (LogicalType::Date, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::Date)) { + if matches!( + (left, right), + (LogicalType::Date, LogicalType::Varchar(_)) + | (LogicalType::Varchar(_), LogicalType::Date) + ) { return Ok(LogicalType::Date); } - if matches!((left, right), (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date)) { + if matches!( + (left, right), + (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date) + ) { return Ok(LogicalType::DateTime); } - if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::DateTime)) { + if matches!( + (left, right), + (LogicalType::DateTime, LogicalType::Varchar(_)) + | (LogicalType::Varchar(_), LogicalType::DateTime) + ) { return Ok(LogicalType::DateTime); } Err(TypeError::InternalError(format!( @@ -259,8 +307,9 @@ impl TryFrom for LogicalType { fn try_from(value: sqlparser::ast::DataType) -> Result { match value { - sqlparser::ast::DataType::Char(len) - | sqlparser::ast::DataType::Varchar(len)=> Ok(LogicalType::Varchar(len.map(|len| len.length as u32))), + sqlparser::ast::DataType::Char(len) | sqlparser::ast::DataType::Varchar(len) => { + Ok(LogicalType::Varchar(len.map(|len| len.length as u32))) + } sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float), sqlparser::ast::DataType::Double => Ok(LogicalType::Double), sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint), @@ -276,12 +325,12 @@ impl TryFrom for LogicalType { sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime), - sqlparser::ast::DataType::Decimal(info) => match info { - ExactNumberInfo::None => Ok(Self::Decimal(None, None)), - ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), - ExactNumberInfo::PrecisionAndScale(p, s) => { - Ok(Self::Decimal(Some(p as u8), Some(s as u8))) - } + sqlparser::ast::DataType::Decimal(info) => match info { + ExactNumberInfo::None => Ok(Self::Decimal(None, None)), + ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), + ExactNumberInfo::PrecisionAndScale(p, s) => { + Ok(Self::Decimal(Some(p as u8), Some(s as u8))) + } }, other => Err(TypeError::NotImplementedSqlparserDataType( other.to_string(), diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 267fd5cd..0359d569 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; +use crate::catalog::ColumnRef; +use crate::types::value::{DataValue, ValueRef}; use comfy_table::{Cell, Table}; use integer_encoding::FixedInt; use itertools::Itertools; -use crate::catalog::ColumnRef; -use crate::types::value::{DataValue, ValueRef}; +use std::sync::Arc; const BITS_MAX_INDEX: usize = 8; @@ -36,13 +36,19 @@ impl Tuple { values.push(Arc::new(DataValue::none(logic_type))); } else if let Some(len) = logic_type.raw_len() { /// fixed length (e.g.: int) - values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type))); + values.push(Arc::new(DataValue::from_raw( + &bytes[pos..pos + len], + logic_type, + ))); pos += len; } else { /// variable length (e.g.: varchar) let len = u32::decode_fixed(&bytes[pos..pos + 4]) as usize; pos += 4; - values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type))); + values.push(Arc::new(DataValue::from_raw( + &bytes[pos..pos + len], + logic_type, + ))); pos += len; } @@ -100,7 +106,8 @@ pub fn create_table(tuples: &[Tuple]) -> Table { table.set_header(header); for tuple in tuples { - let cells = tuple.values + let cells = tuple + .values .iter() .map(|value| Cell::new(format!("{value}"))) .collect_vec(); @@ -113,11 +120,11 @@ pub fn create_table(tuples: &[Tuple]) -> Table { #[cfg(test)] mod tests { - use std::sync::Arc; use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::types::LogicalType; use crate::types::tuple::Tuple; use crate::types::value::DataValue; + use crate::types::LogicalType; + use std::sync::Arc; #[test] fn test_tuple_serialize_to_and_deserialize_from() { @@ -126,73 +133,73 @@ mod tests { "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, true, false), - None + None, )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::UInteger, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c3".to_string(), false, ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c4".to_string(), false, ColumnDesc::new(LogicalType::Smallint, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c5".to_string(), false, ColumnDesc::new(LogicalType::USmallint, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c6".to_string(), false, ColumnDesc::new(LogicalType::Float, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c7".to_string(), false, ColumnDesc::new(LogicalType::Double, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c8".to_string(), false, ColumnDesc::new(LogicalType::Tinyint, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c9".to_string(), false, ColumnDesc::new(LogicalType::UTinyint, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c10".to_string(), false, ColumnDesc::new(LogicalType::Boolean, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c11".to_string(), false, ColumnDesc::new(LogicalType::DateTime, false, false), - None + None, )), Arc::new(ColumnCatalog::new( "c12".to_string(), false, ColumnDesc::new(LogicalType::Date, false, false), - None + None, )), ]; @@ -213,7 +220,7 @@ mod tests { Arc::new(DataValue::Boolean(Some(true))), Arc::new(DataValue::Date64(Some(0))), Arc::new(DataValue::Date32(Some(0))), - ] + ], }, Tuple { id: Some(Arc::new(DataValue::Int32(Some(1)))), @@ -230,21 +237,15 @@ mod tests { Arc::new(DataValue::UInt8(None)), Arc::new(DataValue::Boolean(None)), Arc::new(DataValue::Date64(None)), - Arc::new(DataValue::Date32(None)) + Arc::new(DataValue::Date32(None)), ], - } + }, ]; - let tuple_0 = Tuple::deserialize_from( - columns.clone(), - &tuples[0].serialize_to() - ); - let tuple_1 = Tuple::deserialize_from( - columns.clone(), - &tuples[1].serialize_to() - ); + let tuple_0 = Tuple::deserialize_from(columns.clone(), &tuples[0].serialize_to()); + let tuple_1 = Tuple::deserialize_from(columns.clone(), &tuples[1].serialize_to()); assert_eq!(tuples[0], tuple_0); assert_eq!(tuples[1], tuple_1); } -} \ No newline at end of file +} diff --git a/src/types/value.rs b/src/types/value.rs index 5f097d68..a5f0d962 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,26 +1,24 @@ +use chrono::format::{DelayedFormat, StrftimeItems}; +use chrono::{Datelike, NaiveDate, NaiveDateTime}; +use integer_encoding::FixedInt; +use lazy_static::lazy_static; +use rust_decimal::Decimal; use std::cmp::Ordering; -use std::{fmt, mem}; use std::fmt::Formatter; use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; -use chrono::{NaiveDateTime, Datelike, NaiveDate}; -use chrono::format::{DelayedFormat, StrftimeItems}; -use integer_encoding::FixedInt; -use lazy_static::lazy_static; -use rust_decimal::Decimal; +use std::{fmt, mem}; +use crate::types::errors::TypeError; use ordered_float::OrderedFloat; -use serde::{Deserialize, Serialize}; use rust_decimal::prelude::FromPrimitive; -use crate::types::errors::TypeError; +use serde::{Deserialize, Serialize}; use super::LogicalType; lazy_static! { - static ref UNIX_DATETIME: NaiveDateTime = { - NaiveDateTime::from_timestamp_opt(0, 0).unwrap() - }; + static ref UNIX_DATETIME: NaiveDateTime = { NaiveDateTime::from_timestamp_opt(0, 0).unwrap() }; } pub const DATE_FMT: &str = "%Y-%m-%d"; @@ -53,6 +51,38 @@ pub enum DataValue { Decimal(Option), } +macro_rules! generate_get_option { + ($data_value:ident, $($prefix:ident : $variant:ident($field:ty)),*) => { + impl $data_value { + $( + pub fn $prefix(&self) -> $field { + if let $data_value::$variant(Some(val)) = self { + Some(val.clone()) + } else { + None + } + } + )* + } + }; +} + +generate_get_option!(DataValue, + bool : Boolean(Option), + float : Float32(Option), + double : Float64(Option), + i8 : Int8(Option), + i16 : Int16(Option), + i32 : Int32(Option), + i64 : Int64(Option), + u8 : UInt8(Option), + u16 : UInt16(Option), + u32 : UInt32(Option), + u64 : UInt64(Option), + utf8 : Utf8(Option), + decimal : Decimal(Option) +); + impl PartialEq for DataValue { fn eq(&self, other: &Self) -> bool { use DataValue::*; @@ -188,20 +218,37 @@ impl Hash for DataValue { } macro_rules! varchar_cast { ($value:expr, $len:expr) => { - $value.map(|v| { - let string_value = format!("{}", v); - if let Some(len) = $len { - if string_value.len() > *len as usize { - return Err(TypeError::TooLong); + $value + .map(|v| { + let string_value = format!("{}", v); + if let Some(len) = $len { + if string_value.len() > *len as usize { + return Err(TypeError::TooLong); + } } - } - Ok(DataValue::Utf8(Some(string_value))) - - }).unwrap_or(Ok(DataValue::Utf8(None))) + Ok(DataValue::Utf8(Some(string_value))) + }) + .unwrap_or(Ok(DataValue::Utf8(None))) }; } impl DataValue { + pub fn date(&self) -> Option { + if let DataValue::Date32(Some(val)) = self { + NaiveDate::from_num_days_from_ce_opt(*val) + } else { + None + } + } + + pub fn datetime(&self) -> Option { + if let DataValue::Date64(Some(val)) = self { + NaiveDateTime::from_timestamp_opt(*val, 0) + } else { + None + } + } + pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> { let is_over_len = match (logic_type, self) { (LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => { @@ -210,42 +257,38 @@ impl DataValue { (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { if let Some(len) = full_len { if val.mantissa().ilog10() + 1 > *len as u32 { - return Err(TypeError::TooLong) + return Err(TypeError::TooLong); } } if let Some(len) = scale_len { if val.scale() > *len as u32 { - return Err(TypeError::TooLong) + return Err(TypeError::TooLong); } } false } - _ => false + _ => false, }; if is_over_len { - return Err(TypeError::TooLong) + return Err(TypeError::TooLong); } Ok(()) } fn format_date(value: Option) -> Option { - value.and_then(|v| { - Self::date_format(v).map(|fmt| format!("{}", fmt)) - }) + value.and_then(|v| Self::date_format(v).map(|fmt| format!("{}", fmt))) } fn format_datetime(value: Option) -> Option { - value.and_then(|v| { - Self::date_time_format(v).map(|fmt| format!("{}", fmt)) - }) + value.and_then(|v| Self::date_time_format(v).map(|fmt| format!("{}", fmt))) } pub fn is_variable(&self) -> bool { match self { DataValue::Utf8(_) => true, - _ => false + _ => false, } } @@ -332,7 +375,8 @@ impl DataValue { DataValue::Date32(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Date64(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Decimal(v) => v.clone().map(|v| v.serialize().to_vec()), - }.unwrap_or(vec![]) + } + .unwrap_or(vec![]) } pub fn from_raw(bytes: &[u8], ty: &LogicalType) -> Self { @@ -340,14 +384,30 @@ impl DataValue { LogicalType::Invalid => panic!("invalid logical type"), LogicalType::SqlNull => DataValue::Null, LogicalType::Boolean => DataValue::Boolean(bytes.get(0).map(|v| *v != 0)), - LogicalType::Tinyint => DataValue::Int8((!bytes.is_empty()).then(|| i8::decode_fixed(bytes))), - LogicalType::UTinyint => DataValue::UInt8((!bytes.is_empty()).then(|| u8::decode_fixed(bytes))), - LogicalType::Smallint => DataValue::Int16((!bytes.is_empty()).then(|| i16::decode_fixed(bytes))), - LogicalType::USmallint => DataValue::UInt16((!bytes.is_empty()).then(|| u16::decode_fixed(bytes))), - LogicalType::Integer => DataValue::Int32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))), - LogicalType::UInteger => DataValue::UInt32((!bytes.is_empty()).then(|| u32::decode_fixed(bytes))), - LogicalType::Bigint => DataValue::Int64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))), - LogicalType::UBigint => DataValue::UInt64((!bytes.is_empty()).then(|| u64::decode_fixed(bytes))), + LogicalType::Tinyint => { + DataValue::Int8((!bytes.is_empty()).then(|| i8::decode_fixed(bytes))) + } + LogicalType::UTinyint => { + DataValue::UInt8((!bytes.is_empty()).then(|| u8::decode_fixed(bytes))) + } + LogicalType::Smallint => { + DataValue::Int16((!bytes.is_empty()).then(|| i16::decode_fixed(bytes))) + } + LogicalType::USmallint => { + DataValue::UInt16((!bytes.is_empty()).then(|| u16::decode_fixed(bytes))) + } + LogicalType::Integer => { + DataValue::Int32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))) + } + LogicalType::UInteger => { + DataValue::UInt32((!bytes.is_empty()).then(|| u32::decode_fixed(bytes))) + } + LogicalType::Bigint => { + DataValue::Int64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))) + } + LogicalType::UBigint => { + DataValue::UInt64((!bytes.is_empty()).then(|| u64::decode_fixed(bytes))) + } LogicalType::Float => DataValue::Float32((!bytes.is_empty()).then(|| { let mut buf = [0; 4]; buf.copy_from_slice(bytes); @@ -358,10 +418,19 @@ impl DataValue { buf.copy_from_slice(bytes); f64::from_ne_bytes(buf) })), - LogicalType::Varchar(_) => DataValue::Utf8((!bytes.is_empty()).then(|| String::from_utf8(bytes.to_owned()).unwrap())), - LogicalType::Date => DataValue::Date32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))), - LogicalType::DateTime => DataValue::Date64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))), - LogicalType::Decimal(_, _) => DataValue::Decimal((!bytes.is_empty()).then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap()))), + LogicalType::Varchar(_) => DataValue::Utf8( + (!bytes.is_empty()).then(|| String::from_utf8(bytes.to_owned()).unwrap()), + ), + LogicalType::Date => { + DataValue::Date32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))) + } + LogicalType::DateTime => { + DataValue::Date64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))) + } + LogicalType::Decimal(_, _) => DataValue::Decimal( + (!bytes.is_empty()) + .then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap())), + ), } } @@ -463,10 +532,10 @@ impl DataValue { DataValue::Int16(Some(v)) => encode_u!(b, *v as u16 ^ 0x8000_u16), DataValue::Int32(Some(v)) | DataValue::Date32(Some(v)) => { encode_u!(b, *v as u32 ^ 0x80000000_u32) - }, + } DataValue::Int64(Some(v)) | DataValue::Date64(Some(v)) => { encode_u!(b, *v as u64 ^ 0x8000000000000000_u64) - }, + } DataValue::UInt8(Some(v)) => encode_u!(b, v), DataValue::UInt16(Some(v)) => encode_u!(b, v), DataValue::UInt32(Some(v)) => encode_u!(b, v), @@ -483,7 +552,7 @@ impl DataValue { } encode_u!(b, u); - }, + } DataValue::Float64(Some(f)) => { let mut u = f.to_bits(); @@ -494,7 +563,7 @@ impl DataValue { } encode_u!(b, u); - }, + } DataValue::Decimal(Some(_v)) => todo!(), value => { return if value.is_null() { @@ -502,7 +571,7 @@ impl DataValue { } else { Err(TypeError::InvalidType) } - }, + } } Ok(()) @@ -510,316 +579,346 @@ impl DataValue { pub fn cast(self, to: &LogicalType) -> Result { match self { - DataValue::Null => { - match to { - LogicalType::Invalid => Err(TypeError::CastFail), - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Boolean => Ok(DataValue::Boolean(None)), - LogicalType::Tinyint => Ok(DataValue::Int8(None)), - LogicalType::UTinyint => Ok(DataValue::UInt8(None)), - LogicalType::Smallint => Ok(DataValue::Int16(None)), - LogicalType::USmallint => Ok(DataValue::UInt16(None)), - LogicalType::Integer => Ok(DataValue::Int32(None)), - LogicalType::UInteger => Ok(DataValue::UInt32(None)), - LogicalType::Bigint => Ok(DataValue::Int64(None)), - LogicalType::UBigint => Ok(DataValue::UInt64(None)), - LogicalType::Float => Ok(DataValue::Float32(None)), - LogicalType::Double => Ok(DataValue::Float64(None)), - LogicalType::Varchar(_) => Ok(DataValue::Utf8(None)), - LogicalType::Date => Ok(DataValue::Date32(None)), - LogicalType::DateTime => Ok(DataValue::Date64(None)), - LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), - } - } - DataValue::Boolean(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Boolean => Ok(DataValue::Boolean(value)), - LogicalType::Tinyint => Ok(DataValue::Int8(value.map(|v| v.into()))), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| v.into()))), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - _ => Err(TypeError::CastFail), - } - } - DataValue::Float32(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Float => Ok(DataValue::Float32(value)), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) =>{ - Ok(DataValue::Decimal(value.map(|v| { + DataValue::Null => match to { + LogicalType::Invalid => Err(TypeError::CastFail), + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Boolean => Ok(DataValue::Boolean(None)), + LogicalType::Tinyint => Ok(DataValue::Int8(None)), + LogicalType::UTinyint => Ok(DataValue::UInt8(None)), + LogicalType::Smallint => Ok(DataValue::Int16(None)), + LogicalType::USmallint => Ok(DataValue::UInt16(None)), + LogicalType::Integer => Ok(DataValue::Int32(None)), + LogicalType::UInteger => Ok(DataValue::UInt32(None)), + LogicalType::Bigint => Ok(DataValue::Int64(None)), + LogicalType::UBigint => Ok(DataValue::UInt64(None)), + LogicalType::Float => Ok(DataValue::Float32(None)), + LogicalType::Double => Ok(DataValue::Float64(None)), + LogicalType::Varchar(_) => Ok(DataValue::Utf8(None)), + LogicalType::Date => Ok(DataValue::Date32(None)), + LogicalType::DateTime => Ok(DataValue::Date64(None)), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), + }, + DataValue::Boolean(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Boolean => Ok(DataValue::Boolean(value)), + LogicalType::Tinyint => Ok(DataValue::Int8(value.map(|v| v.into()))), + LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| v.into()))), + LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), + LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + _ => Err(TypeError::CastFail), + }, + DataValue::Float32(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Float => Ok(DataValue::Float32(value)), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal( + value + .map(|v| { let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?; Self::decimal_round_f(option, &mut decimal); Ok::(decimal) - }).transpose()?)) - } - _ => Err(TypeError::CastFail), - } - } - DataValue::Float64(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Double => Ok(DataValue::Float64(value)), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => { - Ok(DataValue::Decimal(value.map(|v| { + }) + .transpose()?, + )), + _ => Err(TypeError::CastFail), + }, + DataValue::Float64(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Double => Ok(DataValue::Float64(value)), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal( + value + .map(|v| { let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?; Self::decimal_round_f(option, &mut decimal); Ok::(decimal) - }).transpose()?)) - } - _ => Err(TypeError::CastFail), - } - } - DataValue::Int8(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Tinyint => Ok(DataValue::Int8(value)), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| u8::try_from(v)).transpose()?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| u16::try_from(v)).transpose()?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| u32::try_from(v)).transpose()?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::Int16(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| u8::try_from(v)).transpose()?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| u16::try_from(v)).transpose()?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| u32::try_from(v)).transpose()?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::Int32(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| u8::try_from(v)).transpose()?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| u16::try_from(v)).transpose()?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| u32::try_from(v)).transpose()?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::Int64(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| u8::try_from(v)).transpose()?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| u16::try_from(v)).transpose()?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| u32::try_from(v)).transpose()?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::UInt8(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UTinyint => Ok(DataValue::UInt8(value)), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::UInt16(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::UInt32(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::UInt64(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { - let mut decimal = Decimal::from(v); - Self::decimal_round_i(option, &mut decimal); - - decimal - }))), - _ => Err(TypeError::CastFail), - } - } - DataValue::Utf8(value) => { - match to { - LogicalType::Invalid => Err(TypeError::CastFail), - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Boolean => Ok(DataValue::Boolean(value.map(|v| bool::from_str(&v)).transpose()?)), - LogicalType::Tinyint => Ok(DataValue::Int8(value.map(|v| i8::from_str(&v)).transpose()?)), - LogicalType::UTinyint => Ok(DataValue::UInt8(value.map(|v| u8::from_str(&v)).transpose()?)), - LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| i16::from_str(&v)).transpose()?)), - LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| u16::from_str(&v)).transpose()?)), - LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| i32::from_str(&v)).transpose()?)), - LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| u32::from_str(&v)).transpose()?)), - LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| i64::from_str(&v)).transpose()?)), - LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::from_str(&v)).transpose()?)), - LogicalType::Float => Ok(DataValue::Float32(value.map(|v| f32::from_str(&v)).transpose()?)), - LogicalType::Double => Ok(DataValue::Float64(value.map(|v| f64::from_str(&v)).transpose()?)), - LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Date => { - let option = value.map(|v| { + }) + .transpose()?, + )), + _ => Err(TypeError::CastFail), + }, + DataValue::Int8(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Tinyint => Ok(DataValue::Int8(value)), + LogicalType::UTinyint => Ok(DataValue::UInt8( + value.map(|v| u8::try_from(v)).transpose()?, + )), + LogicalType::USmallint => Ok(DataValue::UInt16( + value.map(|v| u16::try_from(v)).transpose()?, + )), + LogicalType::UInteger => Ok(DataValue::UInt32( + value.map(|v| u32::try_from(v)).transpose()?, + )), + LogicalType::UBigint => Ok(DataValue::UInt64( + value.map(|v| u64::try_from(v)).transpose()?, + )), + LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::Int16(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UTinyint => Ok(DataValue::UInt8( + value.map(|v| u8::try_from(v)).transpose()?, + )), + LogicalType::USmallint => Ok(DataValue::UInt16( + value.map(|v| u16::try_from(v)).transpose()?, + )), + LogicalType::UInteger => Ok(DataValue::UInt32( + value.map(|v| u32::try_from(v)).transpose()?, + )), + LogicalType::UBigint => Ok(DataValue::UInt64( + value.map(|v| u64::try_from(v)).transpose()?, + )), + LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::Int32(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UTinyint => Ok(DataValue::UInt8( + value.map(|v| u8::try_from(v)).transpose()?, + )), + LogicalType::USmallint => Ok(DataValue::UInt16( + value.map(|v| u16::try_from(v)).transpose()?, + )), + LogicalType::UInteger => Ok(DataValue::UInt32( + value.map(|v| u32::try_from(v)).transpose()?, + )), + LogicalType::UBigint => Ok(DataValue::UInt64( + value.map(|v| u64::try_from(v)).transpose()?, + )), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::Int64(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UTinyint => Ok(DataValue::UInt8( + value.map(|v| u8::try_from(v)).transpose()?, + )), + LogicalType::USmallint => Ok(DataValue::UInt16( + value.map(|v| u16::try_from(v)).transpose()?, + )), + LogicalType::UInteger => Ok(DataValue::UInt32( + value.map(|v| u32::try_from(v)).transpose()?, + )), + LogicalType::UBigint => Ok(DataValue::UInt64( + value.map(|v| u64::try_from(v)).transpose()?, + )), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::UInt8(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UTinyint => Ok(DataValue::UInt8(value)), + LogicalType::Smallint => Ok(DataValue::Int16(value.map(|v| v.into()))), + LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::UInt16(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::USmallint => Ok(DataValue::UInt16(value.map(|v| v.into()))), + LogicalType::Integer => Ok(DataValue::Int32(value.map(|v| v.into()))), + LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::UInt32(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UInteger => Ok(DataValue::UInt32(value.map(|v| v.into()))), + LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), + LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::UInt64(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), + _ => Err(TypeError::CastFail), + }, + DataValue::Utf8(value) => match to { + LogicalType::Invalid => Err(TypeError::CastFail), + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Boolean => Ok(DataValue::Boolean( + value.map(|v| bool::from_str(&v)).transpose()?, + )), + LogicalType::Tinyint => Ok(DataValue::Int8( + value.map(|v| i8::from_str(&v)).transpose()?, + )), + LogicalType::UTinyint => Ok(DataValue::UInt8( + value.map(|v| u8::from_str(&v)).transpose()?, + )), + LogicalType::Smallint => Ok(DataValue::Int16( + value.map(|v| i16::from_str(&v)).transpose()?, + )), + LogicalType::USmallint => Ok(DataValue::UInt16( + value.map(|v| u16::from_str(&v)).transpose()?, + )), + LogicalType::Integer => Ok(DataValue::Int32( + value.map(|v| i32::from_str(&v)).transpose()?, + )), + LogicalType::UInteger => Ok(DataValue::UInt32( + value.map(|v| u32::from_str(&v)).transpose()?, + )), + LogicalType::Bigint => Ok(DataValue::Int64( + value.map(|v| i64::from_str(&v)).transpose()?, + )), + LogicalType::UBigint => Ok(DataValue::UInt64( + value.map(|v| u64::from_str(&v)).transpose()?, + )), + LogicalType::Float => Ok(DataValue::Float32( + value.map(|v| f32::from_str(&v)).transpose()?, + )), + LogicalType::Double => Ok(DataValue::Float64( + value.map(|v| f64::from_str(&v)).transpose()?, + )), + LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Date => { + let option = value + .map(|v| { NaiveDate::parse_from_str(&v, DATE_FMT) .map(|date| date.num_days_from_ce()) - }).transpose()?; + }) + .transpose()?; - Ok(DataValue::Date32(option)) - } - LogicalType::DateTime => { - let option = value.map(|v| { + Ok(DataValue::Date32(option)) + } + LogicalType::DateTime => { + let option = value + .map(|v| { NaiveDateTime::parse_from_str(&v, DATE_TIME_FMT) .or_else(|_| { NaiveDate::parse_from_str(&v, DATE_FMT) .map(|date| date.and_hms_opt(0, 0, 0).unwrap()) }) .map(|date_time| date_time.timestamp()) - }).transpose()?; + }) + .transpose()?; - Ok(DataValue::Date64(option)) - }, - LogicalType::Decimal(_, _) => { - Ok(DataValue::Decimal(value.map(|v| Decimal::from_str(&v)).transpose()?)) - } + Ok(DataValue::Date64(option)) } - } - DataValue::Date32(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Varchar(len) => varchar_cast!(Self::format_date(value), len), - LogicalType::Date => Ok(DataValue::Date32(value)), - LogicalType::DateTime => { - let option = value.and_then(|v| { - NaiveDate::from_num_days_from_ce_opt(v) - .and_then(|date| date.and_hms_opt(0, 0, 0)) - .map(|date_time| date_time.timestamp()) - }); - - Ok(DataValue::Date64(option)) - } - _ => Err(TypeError::CastFail) - } - } - DataValue::Date64(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Varchar(len) => varchar_cast!(Self::format_datetime(value), len), - LogicalType::Date => { - let option = value.and_then(|v| { - NaiveDateTime::from_timestamp_opt(v, 0) - .map(|date_time| date_time.date().num_days_from_ce()) - }); - - Ok(DataValue::Date32(option)) - } - LogicalType::DateTime => Ok(DataValue::Date64(value)), - _ => Err(TypeError::CastFail), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal( + value.map(|v| Decimal::from_str(&v)).transpose()?, + )), + }, + DataValue::Date32(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Varchar(len) => varchar_cast!(Self::format_date(value), len), + LogicalType::Date => Ok(DataValue::Date32(value)), + LogicalType::DateTime => { + let option = value.and_then(|v| { + NaiveDate::from_num_days_from_ce_opt(v) + .and_then(|date| date.and_hms_opt(0, 0, 0)) + .map(|date_time| date_time.timestamp()) + }); + + Ok(DataValue::Date64(option)) } - } - DataValue::Decimal(value) => { - match to { - LogicalType::SqlNull => Ok(DataValue::Null), - LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), - LogicalType::Varchar(len) => varchar_cast!(value, len), - _ => Err(TypeError::CastFail), + _ => Err(TypeError::CastFail), + }, + DataValue::Date64(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Varchar(len) => varchar_cast!(Self::format_datetime(value), len), + LogicalType::Date => { + let option = value.and_then(|v| { + NaiveDateTime::from_timestamp_opt(v, 0) + .map(|date_time| date_time.date().num_days_from_ce()) + }); + + Ok(DataValue::Date32(option)) } - } + LogicalType::DateTime => Ok(DataValue::Date64(value)), + _ => Err(TypeError::CastFail), + }, + DataValue::Decimal(value) => match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), + LogicalType::Varchar(len) => varchar_cast!(value, len), + _ => Err(TypeError::CastFail), + }, } } @@ -834,25 +933,22 @@ impl DataValue { if let Some(scale) = option { let new_decimal = decimal.round_dp_with_strategy( *scale as u32, - rust_decimal::RoundingStrategy::MidpointAwayFromZero + rust_decimal::RoundingStrategy::MidpointAwayFromZero, ); let _ = mem::replace(decimal, new_decimal); } } fn date_format<'a>(v: i32) -> Option>> { - NaiveDate::from_num_days_from_ce_opt(v) - .map(|date| date.format(DATE_FMT)) + NaiveDate::from_num_days_from_ce_opt(v).map(|date| date.format(DATE_FMT)) } fn date_time_format<'a>(v: i64) -> Option>> { - NaiveDateTime::from_timestamp_opt(v, 0) - .map(|date_time| date_time.format(DATE_TIME_FMT)) + NaiveDateTime::from_timestamp_opt(v, 0).map(|date_time| date_time.format(DATE_TIME_FMT)) } fn decimal_format(v: &Decimal) -> String { v.to_string() - } } @@ -936,9 +1032,7 @@ impl fmt::Display for DataValue { DataValue::UInt64(e) => format_option!(f, e)?, DataValue::Utf8(e) => format_option!(f, e)?, DataValue::Null => write!(f, "null")?, - DataValue::Date32(e) => { - format_option!(f, e.and_then(|s| DataValue::date_format(s)))? - } + DataValue::Date32(e) => format_option!(f, e.and_then(|s| DataValue::date_format(s)))?, DataValue::Date64(e) => { format_option!(f, e.and_then(|s| DataValue::date_time_format(s)))? } diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 8e1eaff6..aca86e83 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -1,7 +1,7 @@ -use std::time::Instant; -use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; use kip_sql::db::{Database, DatabaseError}; use kip_sql::storage::kip::KipStorage; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; +use std::time::Instant; pub struct KipSQL { pub db: Database, diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs index 7438db19..58776d61 100644 --- a/tests/sqllogictest/src/main.rs +++ b/tests/sqllogictest/src/main.rs @@ -1,8 +1,8 @@ -use std::path::Path; -use sqllogictest::Runner; -use tempfile::TempDir; use kip_sql::db::Database; +use sqllogictest::Runner; use sqllogictest_test::KipSQL; +use std::path::Path; +use tempfile::TempDir; #[tokio::main] async fn main() { @@ -23,7 +23,8 @@ async fn main() { .to_string(); println!("-> Now the test file is: {}", filepath); - let db = Database::with_kipdb(temp_dir.path()).await + let db = Database::with_kipdb(temp_dir.path()) + .await .expect("init db error"); let mut tester = Runner::new(KipSQL { db }); @@ -32,4 +33,4 @@ async fn main() { } println!("-> Pass!\n\n") } -} \ No newline at end of file +}