From 66910377e333864cedde2ccea5d694bd9f9780f7 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sat, 30 Sep 2023 15:12:01 +0800 Subject: [PATCH] Implement indexes (currently only unique indexes are supported) (#65) * feat: implement `DataValue::to_index_key` * feat: implement `Index` and encoding in `TableCodec` * feat: implementing when create table, detect the unique field and create a unique index for it * feat: when inserting data, it will check whether the field has a corresponding unique index, insert the index and check whether it already exists * feat: Processing of unique indexes when adding update and delete * feat: Processing of unique indexes when truncate * style: rename Table -> Transaction, Transaction -> Iter * feat: added `ScalarExpression::convert_binary` used to extract the constant binary expression information corresponding to Column in the condition of the where clause * style: code optimization * feat: Implement RBO rule -> `SimplifyFilter` * test: add test case for simplification.rs * feat: add RBO Rule`SimplifyFilter` * perf: `ConstantBinary::scope_aggregation` Fusion of marginal values with values within a range * feat: implement `ConstantBinary::rearrange` constant folding -> expression extraction -> aggregation (and) -> rearrangement (or) * fix: `ConstantBinary::scope_aggregation` selection of Eq/NotEq and Scope * fix: `ConstantBinary::scope_aggregation` the eq condition only aggregates one * feat: add RBO Rule`PushPredicateIntoScan` * feat: implement `IndexScan` * feat: implement offset and limit for `IndexScan` * fix: many bugs - RBO Rule: `PushProjectThroughChild`: fixed the problem of missing fields when pushing down - RBO Rule: `PushLimitThroughJoin`: fixed the problem that when the on condition in Join generates multiple same number of connection rows, the limit limit is exceeded. - * fix: resolve merge conflicts * style: code format * fix: check or in `c1 > c2 or c1 > 1` * docs: supplementary index related * perf: `ScalarExpression::check_or` optimize implementation * feat: implemented for Decimal type `DataValue::to_index_key` * perf: Optimized `DataValue` conversion to bitwise sequence * perf: optimized `DataValue::Utf8` convert to encoding of primary/unique key * refactor: reconstruct the Key encoding of each structure of TableCodec https://github.com/KipData/KipSQL/issues/68 --- README.md | 8 + src/binder/create_table.rs | 15 +- src/binder/expr.rs | 2 +- src/binder/insert.rs | 2 +- src/binder/mod.rs | 9 +- src/binder/select.rs | 8 +- src/catalog/column.rs | 28 +- src/catalog/root.rs | 4 +- src/catalog/table.rs | 38 +- src/db.rs | 66 +- src/execution/executor/dml/delete.rs | 38 +- src/execution/executor/dml/insert.rs | 64 +- src/execution/executor/dml/update.rs | 34 +- .../executor/dql/aggregate/hash_agg.rs | 2 +- src/execution/executor/dql/index_scan.rs | 46 + src/execution/executor/dql/join/hash_join.rs | 2 +- src/execution/executor/dql/mod.rs | 1 + src/execution/executor/dql/seq_scan.rs | 8 +- src/execution/executor/mod.rs | 7 +- src/execution/executor/show/show_table.rs | 34 +- src/expression/evaluator.rs | 6 +- src/expression/mod.rs | 13 +- src/expression/simplify.rs | 935 ++++++++++++++++++ src/expression/value_compute.rs | 2 +- src/lib.rs | 1 + src/optimizer/core/rule.rs | 3 +- src/optimizer/heuristic/optimizer.rs | 19 +- src/optimizer/mod.rs | 14 +- src/optimizer/rule/column_pruning.rs | 148 +-- src/optimizer/rule/combine_operators.rs | 19 +- src/optimizer/rule/mod.rs | 16 +- src/optimizer/rule/pushdown_limit.rs | 83 +- src/optimizer/rule/pushdown_predicates.rs | 116 ++- src/optimizer/rule/simplification.rs | 289 ++++++ src/planner/operator/mod.rs | 6 +- src/planner/operator/scan.rs | 15 +- src/storage/kip.rs | 480 +++++++-- src/storage/memory.rs | 72 +- src/storage/mod.rs | 90 +- src/storage/table_codec.rs | 563 ++++++++--- src/types/errors.rs | 13 + src/types/index.rs | 29 + src/types/mod.rs | 59 +- src/types/tuple.rs | 28 +- src/types/value.rs | 250 ++++- 45 files changed, 3020 insertions(+), 665 deletions(-) create mode 100644 src/execution/executor/dql/index_scan.rs create mode 100644 src/expression/simplify.rs create mode 100644 src/optimizer/rule/simplification.rs create mode 100644 src/types/index.rs diff --git a/README.md b/README.md index 4c0c9800..c0d543a1 100755 --- a/README.md +++ b/README.md @@ -43,6 +43,12 @@ Storage Support: ![demo](./static/images/demo.png) ### Features +- SQL field options + - not null + - null + - unique +- Supports index type + - Unique Index - Supports multiple primary key types - Tinyint - UTinyint @@ -63,6 +69,8 @@ Storage Support: - [x] Truncate - DQL - [x] Select + - SeqScan + - IndexScan - [x] Where - [x] Distinct - [x] Alias diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 9bfae8e3..315b32f0 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -34,8 +34,15 @@ impl Binder { .map(|col| ColumnCatalog::from(col.clone())) .collect_vec(); - if columns.iter().find(|col| col.desc.is_primary).is_none() { - return Err(BindError::InvalidTable("At least one primary key field exists".to_string())); + let primary_key_count = columns + .iter() + .filter(|col| col.desc.is_primary) + .count(); + + if primary_key_count != 1 { + return Err(BindError::InvalidTable( + "The primary key field must exist and have at least one".to_string() + )); } let plan = LogicalPlan { @@ -75,10 +82,10 @@ mod tests { assert_eq!(op.table_name, Arc::new("t1".to_string())); assert_eq!(op.columns[0].name, "id".to_string()); assert_eq!(op.columns[0].nullable, false); - assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true)); + assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true, false)); assert_eq!(op.columns[1].name, "name".to_string()); assert_eq!(op.columns[1].nullable, true); - assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar(Some(10)), false)); + assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false)); } _ => unreachable!() } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 65ed83f3..d84c7c57 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -64,7 +64,7 @@ impl Binder { let table_catalog = self .context .storage - .table_catalog(table) + .table(table) .await .ok_or_else(|| BindError::InvalidTable(table.to_string()))?; diff --git a/src/binder/insert.rs b/src/binder/insert.rs index e5e21c93..4b82755c 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -24,7 +24,7 @@ impl Binder { let (_, name) = split_name(&name)?; let table_name = Arc::new(name.to_string()); - if let Some(table) = self.context.storage.table_catalog(&table_name).await { + if let Some(table) = self.context.storage.table(&table_name).await { let mut columns = Vec::new(); if idents.is_empty() { diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 007b626f..6ac8e187 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -193,16 +193,16 @@ pub mod test { let _ = storage.create_table( Arc::new("t1".to_string()), vec![ - ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true)), - ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false)), + ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true, false)), + ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false, true)), ] ).await?; let _ = storage.create_table( Arc::new("t2".to_string()), vec![ - ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true)), - ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false)), + ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true, false)), + ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false, false)), ] ).await?; @@ -211,7 +211,6 @@ pub mod test { pub async fn select_sql_run(sql: &str) -> Result { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let storage = build_test_catalog(temp_dir.path()).await?; let binder = Binder::new(BinderContext::new(storage)); let stmt = crate::parser::parse_sql(sql)?; diff --git a/src/binder/select.rs b/src/binder/select.rs index d9a4e09e..8e37a775 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -167,7 +167,7 @@ impl Binder { let table_catalog = self .context .storage - .table_catalog(&table_name) + .table(&table_name) .await .ok_or_else(|| BindError::InvalidTable(format!("bind table {}", table)))?; @@ -215,7 +215,7 @@ impl Binder { for table_name in self.context.bind_table.keys().cloned() { let table = self.context .storage - .table_catalog(&table_name) + .table(&table_name) .await .ok_or_else(|| BindError::InvalidTable(table_name.to_string()))?; for col in table.all_columns() { @@ -244,12 +244,12 @@ impl Binder { let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type)).await?; let left_table = self.context.storage - .table_catalog(&left_table) + .table(&left_table) .await .cloned() .ok_or_else(|| BindError::InvalidTable(format!("Left: {} not found", left_table)))?; let right_table = self.context.storage - .table_catalog(&right_table) + .table(&right_table) .await .cloned() .ok_or_else(|| BindError::InvalidTable(format!("Right: {} not found", right_table)))?; diff --git a/src/catalog/column.rs b/src/catalog/column.rs index cf4f076b..4dcea4be 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize}; use sqlparser::ast::{ColumnDef, ColumnOption}; use crate::catalog::TableName; -use crate::types::{ColumnId, IdGenerator, LogicalType}; +use crate::types::{ColumnId, LogicalType}; pub type ColumnRef = Arc; #[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)] pub struct ColumnCatalog { - pub id: ColumnId, + pub id: Option, pub name: String, pub table_name: Option, pub nullable: bool, @@ -19,7 +19,7 @@ pub struct ColumnCatalog { impl ColumnCatalog { pub(crate) fn new(column_name: String, nullable: bool, column_desc: ColumnDesc) -> ColumnCatalog { ColumnCatalog { - id: IdGenerator::build(), + id: None, name: column_name, table_name: None, nullable, @@ -29,11 +29,11 @@ impl ColumnCatalog { pub(crate) fn new_dummy(column_name: String)-> ColumnCatalog { ColumnCatalog { - id: 0, + id: Some(0), name: column_name, table_name: None, nullable: false, - desc: ColumnDesc::new(LogicalType::Varchar(None), false), + desc: ColumnDesc::new(LogicalType::Varchar(None), false, false), } } @@ -51,6 +51,7 @@ impl From for ColumnCatalog { let column_name = column_def.name.to_string(); let mut column_desc = ColumnDesc::new( LogicalType::try_from(column_def.data_type).unwrap(), + false, false ); let mut nullable = false; @@ -60,10 +61,15 @@ impl From for ColumnCatalog { match option_def.option { ColumnOption::Null => nullable = true, ColumnOption::NotNull => (), - ColumnOption::Unique { is_primary: true } => { - column_desc.is_primary = true; - // Skip other options when using primary key - break; + ColumnOption::Unique { is_primary } => { + if is_primary { + column_desc.is_primary = true; + nullable = false; + // Skip other options when using primary key + break; + } else { + column_desc.is_unique = true; + } }, _ => todo!() } @@ -78,13 +84,15 @@ impl From for ColumnCatalog { pub struct ColumnDesc { pub(crate) column_datatype: LogicalType, pub(crate) is_primary: bool, + pub(crate) is_unique: bool, } impl ColumnDesc { - pub(crate) const fn new(column_datatype: LogicalType, is_primary: bool) -> ColumnDesc { + pub(crate) const fn new(column_datatype: LogicalType, is_primary: bool, is_unique: bool) -> ColumnDesc { ColumnDesc { column_datatype, is_primary, + is_unique, } } } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 16cdf6f3..b047b0dd 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -67,12 +67,12 @@ mod tests { let col0 = ColumnCatalog::new( "a".to_string(), false, - ColumnDesc::new(LogicalType::Integer, false), + ColumnDesc::new(LogicalType::Integer, false, false), ); let col1 = ColumnCatalog::new( "b".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false), + ColumnDesc::new(LogicalType::Boolean, false, false), ); let col_catalogs = vec![col0, col1]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index e0ea628d..76a19a24 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::catalog::{CatalogError, ColumnCatalog, ColumnRef}; use crate::types::ColumnId; +use crate::types::index::{IndexMeta, IndexMetaRef}; pub type TableName = Arc; @@ -12,9 +13,16 @@ pub struct TableCatalog { /// Mapping from column names to column ids column_idxs: BTreeMap, pub(crate) columns: BTreeMap, + pub indexes: Vec } impl TableCatalog { + pub(crate) fn get_unique_index(&self, col_id: &ColumnId) -> Option<&IndexMetaRef> { + self.indexes + .iter() + .find(|meta| meta.is_unique && &meta.column_ids[0] == col_id) + } + pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option<&ColumnRef> { self.columns.get(id) } @@ -54,8 +62,9 @@ impl TableCatalog { return Err(CatalogError::Duplicated("column", col.name.clone())); } - let col_id = col.id; + let col_id = self.columns.len() as u32; + col.id = Some(col_id); col.table_name = Some(self.name.clone()); self.column_idxs.insert(col.name.clone(), col_id); self.columns.insert(col_id, Arc::new(col)); @@ -63,14 +72,24 @@ impl TableCatalog { Ok(col_id) } + pub(crate) fn add_index_meta(&mut self, mut index: IndexMeta) -> &IndexMeta { + let index_id = self.indexes.len(); + + index.id = index_id as u32; + self.indexes.push(Arc::new(index)); + + &self.indexes[index_id] + } + pub(crate) fn new( name: TableName, - columns: Vec, + columns: Vec ) -> Result { let mut table_catalog = TableCatalog { name, column_idxs: BTreeMap::new(), columns: BTreeMap::new(), + indexes: vec![], }; for col_catalog in columns.into_iter() { @@ -79,6 +98,17 @@ impl TableCatalog { Ok(table_catalog) } + + pub(crate) fn new_with_indexes( + name: TableName, + columns: Vec, + indexes: Vec + ) -> Result { + let mut catalog = TableCatalog::new(name, columns)?; + catalog.indexes = indexes; + + Ok(catalog) + } } #[cfg(test)] @@ -93,8 +123,8 @@ mod tests { // | 1 | true | // | 2 | false | fn test_table_catalog() { - let col0 = ColumnCatalog::new("a".into(), false, ColumnDesc::new(LogicalType::Integer, false)); - let col1 = ColumnCatalog::new("b".into(), false, ColumnDesc::new(LogicalType::Boolean, false)); + let col0 = ColumnCatalog::new("a".into(), false, ColumnDesc::new(LogicalType::Integer, false, false)); + let col1 = ColumnCatalog::new("b".into(), false, ColumnDesc::new(LogicalType::Boolean, false, false)); let col_catalogs = vec![col0, col1]; let table_catalog = TableCatalog::new(Arc::new("test".to_string()), col_catalogs).unwrap(); diff --git a/src/db.rs b/src/db.rs index a8accc6e..ae68e98d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -6,6 +6,7 @@ use crate::execution::ExecutorError; use crate::execution::executor::{build, try_collect}; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; +use crate::optimizer::OptimizerError; use crate::optimizer::rule::RuleImpl; use crate::parser::parse_sql; use crate::planner::LogicalPlan; @@ -64,7 +65,7 @@ impl Database { // println!("source_plan plan: {:#?}", source_plan); let best_plan = Self::default_optimizer(source_plan) - .find_best(); + .find_best()?; // println!("best_plan plan: {:#?}", best_plan); let mut stream = build(best_plan, &self.storage); @@ -75,24 +76,20 @@ impl Database { fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { HepOptimizer::new(source_plan) .batch( - "Predicate pushdown".to_string(), + "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![ - RuleImpl::PushPredicateThroughJoin - ] + vec![RuleImpl::SimplifyFilter] ) .batch( - "Limit pushdown".to_string(), + "Predicate Pushdown".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ - RuleImpl::LimitProjectTranspose, - RuleImpl::PushLimitThroughJoin, - RuleImpl::PushLimitIntoTableScan, - RuleImpl::EliminateLimits, - ], + RuleImpl::PushPredicateThroughJoin, + RuleImpl::PushPredicateIntoScan + ] ) .batch( - "Column pruning".to_string(), + "Column Pruning".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::PushProjectThroughChild, @@ -100,7 +97,17 @@ impl Database { ] ) .batch( - "Combine operators".to_string(), + "Limit Pushdown".to_string(), + HepBatchStrategy::fix_point_topdown(10), + vec![ + RuleImpl::LimitProjectTranspose, + RuleImpl::PushLimitThroughJoin, + RuleImpl::PushLimitIntoTableScan, + RuleImpl::EliminateLimits, + ], + ) + .batch( + "Combine Operators".to_string(), HepBatchStrategy::fix_point_topdown(10), vec![ RuleImpl::CollapseProject, @@ -138,6 +145,12 @@ pub enum DatabaseError { ), #[error("Internal error: {0}")] InternalError(String), + #[error("optimizer error: {0}")] + OptimizerError( + #[source] + #[from] + OptimizerError + ) } #[cfg(test)] @@ -155,12 +168,12 @@ mod test { ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true) + ColumnDesc::new(LogicalType::Integer, true, false) ), ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false) + ColumnDesc::new(LogicalType::Boolean, false, false) ), ]; @@ -182,14 +195,19 @@ mod test { async fn test_crud_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let kipsql = Database::with_kipdb(temp_dir.path()).await?; - let _ = kipsql.run("create table t1 (a int primary key, b int, k int)").await?; + + let _ = kipsql.run("create table t1 (a int primary key, b int unique null, k int, z varchar unique null)").await?; let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?; - let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 2, 2)").await?; + let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p')").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?; let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?; let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?; + println!("show tables:"); + let tuples_show_tables = kipsql.run("show tables").await?; + println!("{}", create_table(&tuples_show_tables)); + println!("full t1:"); let tuples_full_fields_t1 = kipsql.run("select * from t1").await?; println!("{}", create_table(&tuples_full_fields_t1)); @@ -199,7 +217,7 @@ mod test { println!("{}", create_table(&tuples_full_fields_t2)); println!("projection_and_filter:"); - let tuples_projection_and_filter = kipsql.run("select a from t1 where a <= 1").await?; + let tuples_projection_and_filter = kipsql.run("select a from t1 where b > 1").await?; println!("{}", create_table(&tuples_projection_and_filter)); println!("projection_and_sort:"); @@ -281,19 +299,21 @@ mod test { println!("{}", create_table(&tuples_distinct_t1)); println!("update t1 with filter:"); - let _ = kipsql.run("update t1 set b = 0 where b > 1").await?; + let _ = kipsql.run("update t1 set b = 0 where b = 1").await?; println!("after t1:"); let update_after_full_t1 = kipsql.run("select * from t1").await?; println!("{}", create_table(&update_after_full_t1)); println!("insert overwrite t1:"); - let _ = kipsql.run("insert overwrite t1 (a, b, k) values (-1, 1, 1)").await?; + let _ = kipsql.run("insert overwrite t1 (a, b, k) values (-99, 1, 0)").await?; println!("after t1:"); let insert_overwrite_after_full_t1 = kipsql.run("select * from t1").await?; println!("{}", create_table(&insert_overwrite_after_full_t1)); + assert!(kipsql.run("insert overwrite t1 (a, b, k) values (-1, 1, 0)").await.is_err()); + println!("delete t1 with filter:"); - let _ = kipsql.run("delete from t1 where b > 1").await?; + let _ = kipsql.run("delete from t1 where b = 0").await?; println!("after t1:"); let delete_after_full_t1 = kipsql.run("select * from t1").await?; println!("{}", create_table(&delete_after_full_t1)); @@ -304,10 +324,6 @@ mod test { println!("drop t1:"); let _ = kipsql.run("drop table t1").await?; - println!("show tables:"); - let tuples_show_tables = kipsql.run("show tables").await?; - println!("{}", create_table(&tuples_show_tables)); - println!("decimal:"); let tuples_decimal = kipsql.run("select * from t3").await?; println!("{}", create_table(&tuples_decimal)); diff --git a/src/execution/executor/dml/delete.rs b/src/execution/executor/dml/delete.rs index 731f1c18..e6ef8cf8 100644 --- a/src/execution/executor/dml/delete.rs +++ b/src/execution/executor/dml/delete.rs @@ -1,9 +1,11 @@ use futures_async_stream::try_stream; +use itertools::Itertools; use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::delete::DeleteOperator; -use crate::storage::{Storage, Table}; +use crate::storage::{Storage, Transaction}; +use crate::types::index::Index; use crate::types::tuple::Tuple; pub struct Delete { @@ -31,16 +33,44 @@ impl Delete { pub async fn _execute(self, storage: S) { let Delete { table_name, input } = self; - if let Some(mut table) = storage.table(&table_name).await { + if let Some(mut transaction) = storage.transaction(&table_name).await { + let table_catalog = storage.table(&table_name).await.unwrap(); + + let vec = table_catalog + .all_columns() + .into_iter() + .enumerate() + .filter_map(|(i, col)| col.desc.is_unique + .then(|| col.id.and_then(|col_id| { + table_catalog.get_unique_index(&col_id) + .map(|index_meta| (i, index_meta)) + })) + .flatten()) + .collect_vec(); + + #[for_await] for tuple in input { let tuple: Tuple = tuple?; + for (i, index_meta) in vec.iter() { + let value = &tuple.values[*i]; + + if !value.is_null() { + let index = Index { + id: index_meta.id, + column_values: vec![value.clone()], + }; + + transaction.del_index(&index)?; + } + } + if let Some(tuple_id) = tuple.id { - table.delete(tuple_id)?; + transaction.delete(tuple_id)?; } } - table.commit().await?; + transaction.commit().await?; } } } \ No newline at end of file diff --git a/src/execution/executor/dml/insert.rs b/src/execution/executor/dml/insert.rs index b31c7fe4..ae6a876b 100644 --- a/src/execution/executor/dml/insert.rs +++ b/src/execution/executor/dml/insert.rs @@ -1,15 +1,14 @@ use std::collections::HashMap; use std::sync::Arc; use futures_async_stream::try_stream; -use itertools::Itertools; use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::insert::InsertOperator; -use crate::storage::{Storage, Table}; -use crate::types::ColumnId; +use crate::storage::{Storage, Transaction}; +use crate::types::index::Index; use crate::types::tuple::Tuple; -use crate::types::value::{DataValue, ValueRef}; +use crate::types::value::DataValue; pub struct Insert { table_name: TableName, @@ -38,37 +37,48 @@ impl Insert { pub async fn _execute(self, storage: S) { let Insert { table_name, input, is_overwrite } = self; let mut primary_key_index = None; + let mut unique_values = HashMap::new(); - if let (Some(table_catalog), Some(mut table)) = - (storage.table_catalog(&table_name).await, storage.table(&table_name).await) + if let (Some(table_catalog), Some(mut transaction)) = + (storage.table(&table_name).await, storage.transaction(&table_name).await) { #[for_await] for tuple in input { let Tuple { columns, values, .. } = tuple?; - let primary_idx = primary_key_index.get_or_insert_with(|| { + let mut tuple_map = HashMap::new(); + for (i, value) in values.into_iter().enumerate() { + let col = &columns[i]; + let cast_val = DataValue::clone(&value).cast(&col.datatype())?; + + if let Some(col_id) = col.id { + tuple_map.insert(col_id, Arc::new(cast_val)); + } + } + let primary_col_id = primary_key_index.get_or_insert_with(|| { columns.iter() - .find_position(|col| col.desc.is_primary) - .map(|(i, _)| i) + .find(|col| col.desc.is_primary) + .map(|col| col.id.unwrap()) .unwrap() }); - let id = Some(values[*primary_idx].clone()); - let mut tuple_map: HashMap = values - .into_iter() - .enumerate() - .map(|(i, value)| (columns[i].id, value)) - .collect(); let all_columns = table_catalog.all_columns_with_id(); - + let tuple_id = tuple_map.get(primary_col_id) + .cloned() + .unwrap(); let mut tuple = Tuple { - id, + id: Some(tuple_id.clone()), columns: Vec::with_capacity(all_columns.len()), values: Vec::with_capacity(all_columns.len()), }; - for (col_id, col) in all_columns { let value = tuple_map.remove(col_id) .unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))); + if col.desc.is_unique && !value.is_null() { + unique_values + .entry(col.id) + .or_insert_with(|| vec![]) + .push((tuple_id.clone(), value.clone())) + } if value.is_null() && !col.nullable { return Err(ExecutorError::InternalError(format!("Non-null fields do not allow null values to be passed in: {:?}", col))); } @@ -77,9 +87,23 @@ impl Insert { tuple.values.push(value) } - table.append(tuple, is_overwrite)?; + transaction.append(tuple, is_overwrite)?; } - table.commit().await?; + // Unique Index + for (col_id, values) in unique_values { + if let Some(index_meta) = table_catalog.get_unique_index(&col_id.unwrap()) { + for (tuple_id, value) in values { + let index = Index { + id: index_meta.id, + column_values: vec![value], + }; + + transaction.add_index(index, vec![tuple_id], true)?; + } + } + } + + transaction.commit().await?; } } } \ No newline at end of file diff --git a/src/execution/executor/dml/update.rs b/src/execution/executor/dml/update.rs index d6618693..0c1320a2 100644 --- a/src/execution/executor/dml/update.rs +++ b/src/execution/executor/dml/update.rs @@ -4,7 +4,8 @@ use crate::catalog::TableName; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::update::UpdateOperator; -use crate::storage::{Storage, Table}; +use crate::storage::{Storage, Transaction}; +use crate::types::index::Index; use crate::types::tuple::Tuple; pub struct Update { @@ -34,7 +35,8 @@ impl Update { pub async fn _execute(self, storage: S) { let Update { table_name, input, values } = self; - if let Some(mut table) = storage.table(&table_name).await { + if let Some(mut transaction) = storage.transaction(&table_name).await { + let table_catalog = storage.table(&table_name).await.unwrap(); let mut value_map = HashMap::new(); // only once @@ -47,24 +49,40 @@ impl Update { } #[for_await] for tuple in input { - let mut tuple = tuple?; + let mut tuple: Tuple = tuple?; let mut is_overwrite = true; for (i, column) in tuple.columns.iter().enumerate() { if let Some(value) = value_map.get(&column.id) { if column.desc.is_primary { - if let Some(old_key) = tuple.id.replace(value.clone()) { - table.delete(old_key)?; - is_overwrite = false; + let old_key = tuple.id.replace(value.clone()).unwrap(); + + transaction.delete(old_key)?; + is_overwrite = false; + } + if column.desc.is_unique && value != &tuple.values[i] { + if let Some(index_meta) = table_catalog.get_unique_index(&column.id.unwrap()) { + let mut index = Index { + id: index_meta.id, + column_values: vec![tuple.values[i].clone()], + }; + transaction.del_index(&index)?; + + if !value.is_null() { + index.column_values[0] = value.clone(); + transaction.add_index(index, vec![tuple.id.clone().unwrap()], true)?; + } } } + tuple.values[i] = value.clone(); } } - table.append(tuple, is_overwrite)?; + transaction.append(tuple, is_overwrite)?; } - table.commit().await?; + + transaction.commit().await?; } } } \ No newline at end of file diff --git a/src/execution/executor/dql/aggregate/hash_agg.rs b/src/execution/executor/dql/aggregate/hash_agg.rs index 2db90a27..bc42fa3c 100644 --- a/src/execution/executor/dql/aggregate/hash_agg.rs +++ b/src/execution/executor/dql/aggregate/hash_agg.rs @@ -122,7 +122,7 @@ mod test { #[tokio::test] async fn test_hash_agg() -> Result<(), ExecutorError> { let mem_storage = MemStorage::new(); - let desc = ColumnDesc::new(LogicalType::Integer, false); + let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone())), diff --git a/src/execution/executor/dql/index_scan.rs b/src/execution/executor/dql/index_scan.rs new file mode 100644 index 00000000..cf826cd4 --- /dev/null +++ b/src/execution/executor/dql/index_scan.rs @@ -0,0 +1,46 @@ +use futures_async_stream::try_stream; +use crate::execution::executor::{BoxedExecutor, Executor}; +use crate::execution::ExecutorError; +use crate::planner::operator::scan::ScanOperator; +use crate::storage::{Iter, Storage, Transaction}; +use crate::types::errors::TypeError; +use crate::types::tuple::Tuple; + +pub(crate) struct IndexScan { + op: ScanOperator +} + +impl From for IndexScan { + fn from(op: ScanOperator) -> Self { + IndexScan { + op + } + } +} + +impl Executor for IndexScan { + fn execute(self, storage: &S) -> BoxedExecutor { + self._execute(storage.clone()) + } +} + +impl IndexScan { + #[try_stream(boxed, ok = Tuple, error = ExecutorError)] + pub async fn _execute(self, storage: S) { + let ScanOperator { table_name, columns, limit, index_by, .. } = self.op; + let (index_meta, binaries) = index_by.ok_or(TypeError::InvalidType)?; + + if let Some(transaction) = storage.transaction(&table_name).await { + let mut iter = transaction.read_by_index( + limit, + columns, + index_meta, + binaries + )?; + + while let Some(tuple) = iter.next_tuple()? { + yield tuple; + } + } + } +} \ No newline at end of file diff --git a/src/execution/executor/dql/join/hash_join.rs b/src/execution/executor/dql/join/hash_join.rs index 917b7623..e5b401c6 100644 --- a/src/execution/executor/dql/join/hash_join.rs +++ b/src/execution/executor/dql/join/hash_join.rs @@ -231,7 +231,7 @@ mod test { use crate::types::value::DataValue; fn build_join_values(_s: &S) -> (Vec<(ScalarExpression, ScalarExpression)>, BoxedExecutor, BoxedExecutor) { - let desc = ColumnDesc::new(LogicalType::Integer, false); + let desc = ColumnDesc::new(LogicalType::Integer, false, false); let t1_columns = vec![ Arc::new(ColumnCatalog::new("c1".to_string(), true, desc.clone())), diff --git a/src/execution/executor/dql/mod.rs b/src/execution/executor/dql/mod.rs index 1c7378e6..b42e5b74 100644 --- a/src/execution/executor/dql/mod.rs +++ b/src/execution/executor/dql/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod limit; pub(crate) mod join; pub(crate) mod dummy; pub(crate) mod aggregate; +pub(crate) mod index_scan; #[cfg(test)] pub(crate) mod test { diff --git a/src/execution/executor/dql/seq_scan.rs b/src/execution/executor/dql/seq_scan.rs index fc3e5bd8..5df79669 100644 --- a/src/execution/executor/dql/seq_scan.rs +++ b/src/execution/executor/dql/seq_scan.rs @@ -2,7 +2,7 @@ use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::scan::ScanOperator; -use crate::storage::{Table, Transaction, Storage}; +use crate::storage::{Transaction, Iter, Storage}; use crate::types::tuple::Tuple; pub(crate) struct SeqScan { @@ -28,13 +28,13 @@ impl SeqScan { pub async fn _execute(self, storage: S) { let ScanOperator { table_name, columns, limit, .. } = self.op; - if let Some(table) = storage.table(&table_name).await { - let mut transaction = table.read( + if let Some(transaction) = storage.transaction(&table_name).await { + let mut iter = transaction.read( limit, columns )?; - while let Some(tuple) = transaction.next_tuple()? { + while let Some(tuple) = iter.next_tuple()? { yield tuple; } } diff --git a/src/execution/executor/mod.rs b/src/execution/executor/mod.rs index bdc816a9..ab5eab18 100644 --- a/src/execution/executor/mod.rs +++ b/src/execution/executor/mod.rs @@ -15,6 +15,7 @@ use crate::execution::executor::dql::aggregate::hash_agg::HashAggExecutor; use crate::execution::executor::dql::aggregate::simple_agg::SimpleAggExecutor; use crate::execution::executor::dql::dummy::Dummy; use crate::execution::executor::dql::filter::Filter; +use crate::execution::executor::dql::index_scan::IndexScan; use crate::execution::executor::dql::join::hash_join::HashJoin; use crate::execution::executor::dql::limit::Limit; use crate::execution::executor::dql::projection::Projection; @@ -65,7 +66,11 @@ pub fn build(plan: LogicalPlan, storage: &S) -> BoxedExecutor { Projection::from((op, input)).execute(storage) } Operator::Scan(op) => { - SeqScan::from(op).execute(storage) + if op.index_by.is_some() { + IndexScan::from(op).execute(storage) + } else { + SeqScan::from(op).execute(storage) + } } Operator::Sort(op) => { let input = build(childrens.remove(0), storage); diff --git a/src/execution/executor/show/show_table.rs b/src/execution/executor/show/show_table.rs index cf424a01..8641cc0f 100644 --- a/src/execution/executor/show/show_table.rs +++ b/src/execution/executor/show/show_table.rs @@ -10,13 +10,13 @@ use std::sync::Arc; use crate::types::value::{DataValue, ValueRef}; pub struct ShowTables { - op: ShowTablesOperator, + _op: ShowTablesOperator, } impl From for ShowTables { fn from(op: ShowTablesOperator) -> Self { ShowTables { - op + _op: op } } } @@ -30,23 +30,21 @@ impl Executor for ShowTables { impl ShowTables { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { - if let Some(tables) = storage.show_tables().await { - for (table,column_count) in tables { - let columns: Vec = vec![ - Arc::new(ColumnCatalog::new_dummy("TABLES".to_string())), - Arc::new(ColumnCatalog::new_dummy("COLUMN_COUNT".to_string())), - ]; - let values: Vec = vec![ - Arc::new(DataValue::Utf8(Some(table))), - Arc::new(DataValue::UInt32(Some(column_count as u32))), - ]; + let tables = storage.show_tables().await?; - yield Tuple { - id: None, - columns, - values, - }; - } + for table in tables { + let columns: Vec = vec![ + Arc::new(ColumnCatalog::new_dummy("TABLES".to_string())), + ]; + let values: Vec = vec![ + Arc::new(DataValue::Utf8(Some(table))), + ]; + + yield Tuple { + id: None, + columns, + values, + }; } } } \ No newline at end of file diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 0c004dd5..51aa2346 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -33,8 +33,10 @@ impl ScalarExpression { Ok(Arc::new(binary_op(&left, &right, op)?)) } ScalarExpression::IsNull{ expr } => { - Ok(Arc::new(DataValue::Boolean(Some(expr.nullable())))) - } + let value = expr.eval_column(tuple)?; + + Ok(Arc::new(DataValue::Boolean(Some(value.is_null())))) + }, ScalarExpression::Unary{ expr, op, .. } => { let value = expr.eval_column(tuple)?; diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 30678993..35f55e5f 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -16,6 +16,7 @@ use crate::types::tuple::Tuple; pub mod agg; mod evaluator; pub mod value_compute; +pub mod simplify; /// ScalarExpression represnet all scalar expression in SQL. /// SELECT a+1, b FROM t1. @@ -36,7 +37,6 @@ pub enum ScalarExpression { TypeCast { expr: Box, ty: LogicalType, - is_try: bool, }, IsNull { expr: Box, @@ -173,14 +173,14 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( format!("{}", value), true, - ColumnDesc::new(value.logical_type(), false) + ColumnDesc::new(value.logical_type(), false, false) )) } ScalarExpression::Alias { expr, alias } => { Arc::new(ColumnCatalog::new( alias.to_string(), true, - ColumnDesc::new(expr.return_type(), false) + ColumnDesc::new(expr.return_type(), false, false) )) } ScalarExpression::AggCall { kind, args, ty, distinct } => { @@ -204,7 +204,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false) + ColumnDesc::new(ty.clone(), false, false) )) } ScalarExpression::InputRef { index, .. } => { @@ -226,7 +226,7 @@ impl ScalarExpression { Arc::new(ColumnCatalog::new( column_name, true, - ColumnDesc::new(ty.clone(), false) + ColumnDesc::new(ty.clone(), false, false) )) } _ => unreachable!() @@ -258,8 +258,10 @@ pub enum BinaryOperator { Minus, Multiply, Divide, + Modulo, StringConcat, + Gt, Lt, GtEq, @@ -267,6 +269,7 @@ pub enum BinaryOperator { Spaceship, Eq, NotEq, + And, Or, Xor, diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs new file mode 100644 index 00000000..7bf2beab --- /dev/null +++ b/src/expression/simplify.rs @@ -0,0 +1,935 @@ +use std::cmp::Ordering; +use std::collections::{Bound, HashSet}; +use std::mem; +use std::sync::Arc; +use ahash::RandomState; +use itertools::Itertools; +use crate::catalog::ColumnRef; +use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; +use crate::expression::value_compute::{binary_op, unary_op}; +use crate::types::{ColumnId, LogicalType}; +use crate::types::errors::TypeError; +use crate::types::value::{DataValue, ValueRef}; + +#[derive(Debug, PartialEq, Clone)] +pub enum ConstantBinary { + Scope { + min: Bound, + max: Bound + }, + Eq(ValueRef), + NotEq(ValueRef), + + // ConstantBinary in And can only be Scope\Eq\NotEq + And(Vec), + // ConstantBinary in Or can only be Scope\Eq\NotEq\And + Or(Vec) +} + +impl ConstantBinary { + #[allow(dead_code)] + fn is_null(&self) -> Result { + match self { + ConstantBinary::Scope { min, max } => { + let op = |bound: &Bound| { + if let Bound::Included(val) | Bound::Excluded(val) = bound { + val.is_null() + } else { + false + } + }; + if op(min) || op(max) { + return Ok(true); + } + + Ok(matches!((min, max), (Bound::Unbounded, Bound::Unbounded))) + }, + ConstantBinary::Eq(val) | ConstantBinary::NotEq(val) => Ok(val.is_null()), + _ => Err(TypeError::InvalidType), + } + } + + pub fn rearrange(self) -> Result, TypeError> { + match self { + ConstantBinary::Or(binaries) => { + let mut condition_binaries = Vec::new(); + + for binary in binaries { + match binary { + ConstantBinary::Or(_) => return Err(TypeError::InvalidType), + ConstantBinary::And(mut and_binaries) => { + condition_binaries.append(&mut and_binaries); + } + ConstantBinary::Scope { min: Bound::Unbounded, max: Bound::Unbounded } => (), + source => condition_binaries.push(source), + } + } + // Sort + condition_binaries.sort_by(|a, b| { + let op = |binary: &ConstantBinary| { + match binary { + ConstantBinary::Scope { min, .. } => min.clone(), + ConstantBinary::Eq(val) => Bound::Included(val.clone()), + ConstantBinary::NotEq(val) => Bound::Excluded(val.clone()), + _ => unreachable!() + } + }; + + Self::bound_compared(&op(a), &op(b), true) + .unwrap_or(Ordering::Equal) + }); + + let mut merged_binaries: Vec = Vec::new(); + + for condition in condition_binaries { + let op = |binary: &ConstantBinary| { + match binary { + ConstantBinary::Scope { min, max } => (min.clone(), max.clone()), + ConstantBinary::Eq(val) => (Bound::Unbounded, Bound::Included(val.clone())), + ConstantBinary::NotEq(val) => (Bound::Unbounded, Bound::Excluded(val.clone())), + _ => unreachable!() + } + }; + let mut is_push = merged_binaries.is_empty(); + + for binary in merged_binaries.iter_mut().rev() { + if let ConstantBinary::Scope { max, .. } = binary { + let (condition_min, condition_max) = op(&condition); + let is_lt_min = Self::bound_compared(max, &condition_min, false) + .unwrap_or(Ordering::Equal) + .is_lt(); + let is_lt_max = Self::bound_compared(max, &condition_max, false) + .unwrap_or(Ordering::Equal) + .is_lt(); + + if !is_lt_min && is_lt_max { + let _ = mem::replace(max, condition_max); + } else if !matches!(condition, ConstantBinary::Scope {..}) { + is_push = is_lt_max; + } else if is_lt_min && is_lt_max { + is_push = true + } + + break + } + } + + if is_push { + merged_binaries.push(condition); + } + } + + Ok(merged_binaries) + }, + ConstantBinary::And(binaries) => Ok(binaries), + source => Ok(vec![source]), + } + } + + pub fn scope_aggregation(&mut self) -> Result<(), TypeError> { + match self { + ConstantBinary::Or(binaries) => { + for binary in binaries { + binary.scope_aggregation()? + } + } + binary => binary._scope_aggregation()? + } + + Ok(()) + } + + fn bound_compared(left_bound: &Bound, right_bound: &Bound, is_min: bool) -> Option { + let op = |is_min, order: Ordering| { + if is_min { + order + } else { + order.reverse() + } + }; + + match (left_bound, right_bound) { + (Bound::Unbounded, Bound::Unbounded) => Some(Ordering::Equal), + (Bound::Unbounded, _) => Some(op(is_min, Ordering::Less)), + (_, Bound::Unbounded) => Some(op(is_min, Ordering::Greater)), + (Bound::Included(left), Bound::Included(right)) => left.partial_cmp(right), + (Bound::Included(left), Bound::Excluded(right)) => { + left.partial_cmp(right) + .map(|order| order.then(op(is_min, Ordering::Less))) + }, + (Bound::Excluded(left), Bound::Excluded(right)) => left.partial_cmp(right), + (Bound::Excluded(left), Bound::Included(right)) => { + left.partial_cmp(right) + .map(|order| order.then(op(is_min, Ordering::Greater))) + }, + } + } + + // Tips: It only makes sense if the condition is and aggregation + fn _scope_aggregation(&mut self) -> Result<(), TypeError> { + if let ConstantBinary::And(binaries) = self { + let mut scope_min = Bound::Unbounded; + let mut scope_max = Bound::Unbounded; + let mut eq_set = HashSet::with_hasher(RandomState::new()); + + let sort_op = |binary: &&ConstantBinary| { + match binary { + ConstantBinary::Scope { .. } => 3, + ConstantBinary::NotEq(_) => 2, + ConstantBinary::Eq(_) => 1, + ConstantBinary::And(_) | ConstantBinary::Or(_) => 0 + } + }; + + // Aggregate various ranges to get the minimum range + for binary in binaries.iter().sorted_by_key(sort_op) { + match binary { + ConstantBinary::Scope { min, max } => { + // Skip if eq or noteq exists + if !eq_set.is_empty() { continue } + + if let Some(order) = Self::bound_compared(&scope_min, &min, true) { + if order.is_lt() { + scope_min = min.clone(); + } + } + + if let Some(order) = Self::bound_compared(&scope_max, &max, false) { + if order.is_gt() { + scope_max = max.clone(); + } + } + } + ConstantBinary::Eq(val) => { + let _ = eq_set.insert(val.clone()); + }, + ConstantBinary::NotEq(val) => { + let _ = eq_set.remove(val); + }, + ConstantBinary::Or(_) | ConstantBinary::And(_) => return Err(TypeError::InvalidType) + } + } + + let eq_option = eq_set.into_iter() + .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .next() + .map(|val| ConstantBinary::Eq(val)); + + if let Some(eq) = eq_option { + let _ = mem::replace(self, eq); + } else if !matches!((&scope_min, &scope_max), (Bound::Unbounded, Bound::Unbounded)) { + let scope_binary = ConstantBinary::Scope { + min: scope_min, + max: scope_max, + }; + + let _ = mem::replace(self, scope_binary); + } else { + let _ = mem::replace(self, ConstantBinary::And(vec![])); + } + } + + Ok(()) + } +} + +enum Replace { + Binary(ReplaceBinary), + Unary(ReplaceUnary), +} + +struct ReplaceBinary { + column_expr: ScalarExpression, + val_expr: ScalarExpression, + op: BinaryOperator, + ty: LogicalType, + is_column_left: bool +} + +struct ReplaceUnary { + child_expr: ScalarExpression, + op: UnaryOperator, + ty: LogicalType, +} + +impl ScalarExpression { + pub fn exist_column(&self, col_id: &ColumnId) -> bool { + match self { + ScalarExpression::ColumnRef(col) => col.id == Some(*col_id), + ScalarExpression::Alias { expr, .. } => expr.exist_column(col_id), + ScalarExpression::TypeCast { expr, .. } => expr.exist_column(col_id), + ScalarExpression::IsNull { expr } => expr.exist_column(col_id), + ScalarExpression::Unary { expr, .. } => expr.exist_column(col_id), + ScalarExpression::Binary { left_expr, right_expr, .. } => { + left_expr.exist_column(col_id) || right_expr.exist_column(col_id) + } + _ => false + } + } + + fn unpack_val(&self) -> Option { + match self { + ScalarExpression::Constant(val) => Some(val.clone()), + ScalarExpression::Alias { expr, .. } => expr.unpack_val(), + ScalarExpression::TypeCast { expr, ty, .. } => { + expr.unpack_val() + .and_then(|val| DataValue::clone(&val) + .cast(ty).ok() + .map(Arc::new)) + } + ScalarExpression::IsNull { expr } => { + let is_null = expr.unpack_val().map(|val| val.is_null()); + + Some(Arc::new(DataValue::Boolean(is_null))) + }, + ScalarExpression::Unary { expr, op, .. } => { + let val = expr.unpack_val()?; + + unary_op(&val, op).ok() + .map(Arc::new) + + } + ScalarExpression::Binary { left_expr, right_expr, op, .. } => { + let left = left_expr.unpack_val()?; + let right = right_expr.unpack_val()?; + + binary_op(&left, &right, op).ok() + .map(Arc::new) + } + _ => None + } + } + + fn unpack_col(&self) -> Option { + match self { + ScalarExpression::ColumnRef(col) => Some(col.clone()), + ScalarExpression::Alias { expr, .. } => expr.unpack_col(), + ScalarExpression::Unary { expr, .. } => expr.unpack_col(), + _ => None + } + } + + pub fn simplify(&mut self) -> Result<(), TypeError> { + self._simplify(&mut None) + } + + // Tips: Indirect expressions like `ScalarExpression::Alias` will be lost + fn _simplify(&mut self, fix_option: &mut Option) -> Result<(), TypeError> { + match self { + ScalarExpression::Binary { left_expr, right_expr, op, ty } => { + Self::fix_expr(fix_option, left_expr, right_expr, op)?; + + // `(c1 - 1) and (c1 + 2)` cannot fix! + Self::fix_expr(fix_option, right_expr, left_expr, op)?; + + if matches!(op, BinaryOperator::Plus | BinaryOperator::Divide + | BinaryOperator::Minus | BinaryOperator::Multiply) + { + match (left_expr.unpack_col(), right_expr.unpack_col()) { + (Some(_), Some(_)) => (), + (Some(col), None) => { + fix_option.replace(Replace::Binary(ReplaceBinary{ + column_expr: ScalarExpression::ColumnRef(col), + val_expr: right_expr.as_ref().clone(), + op: *op, + ty: *ty, + is_column_left: true, + })); + } + (None, Some(col)) => { + fix_option.replace(Replace::Binary(ReplaceBinary{ + column_expr: ScalarExpression::ColumnRef(col), + val_expr: left_expr.as_ref().clone(), + op: *op, + ty: *ty, + is_column_left: false, + })); + } + _ => () + } + } + } + ScalarExpression::Alias { expr, .. } => expr._simplify(fix_option)?, + ScalarExpression::TypeCast { expr, .. } => { + if let Some(val) = expr.unpack_val() { + let _ = mem::replace(self, ScalarExpression::Constant(val)); + } + }, + ScalarExpression::IsNull { expr, .. } => { + if let Some(val) = expr.unpack_val() { + let _ = mem::replace(self, ScalarExpression::Constant( + Arc::new(DataValue::Boolean(Some(val.is_null()))) + )); + } + }, + ScalarExpression::Unary { expr, op, ty } => { + if let Some(val) = expr.unpack_val() { + let new_expr = ScalarExpression::Constant( + Arc::new(unary_op(&val, op)?) + ); + let _ = mem::replace(self, new_expr); + } else { + let _ = fix_option.replace(Replace::Unary( + ReplaceUnary { + child_expr: expr.as_ref().clone(), + op: *op, + ty: *ty, + } + )); + } + }, + _ => () + } + + Ok(()) + } + + fn fix_expr( + fix_option: &mut Option, + left_expr: &mut Box, + right_expr: &mut Box, + op: &mut BinaryOperator, + ) -> Result<(), TypeError> { + left_expr._simplify(fix_option)?; + + if let Some(replace) = fix_option.take() { + match replace { + Replace::Binary(binary) => Self::fix_binary(binary, left_expr, right_expr, op), + Replace::Unary(unary) => { + Self::fix_unary(unary, left_expr, right_expr, op); + Self::fix_expr(fix_option, left_expr, right_expr, op)?; + }, + } + } + Ok(()) + } + + fn fix_unary( + replace_unary: ReplaceUnary, + col_expr: &mut Box, + val_expr: &mut Box, + op: &mut BinaryOperator + ) { + let ReplaceUnary { child_expr, op: fix_op, ty: fix_ty } = replace_unary; + let _ = mem::replace(col_expr, Box::new(child_expr)); + let _ = mem::replace(val_expr, Box::new(ScalarExpression::Unary { + op: fix_op, + expr: val_expr.clone(), + ty: fix_ty, + })); + let _ = mem::replace(op, match fix_op { + UnaryOperator::Plus => *op, + UnaryOperator::Minus => { + match *op { + BinaryOperator::Plus => BinaryOperator::Minus, + BinaryOperator::Minus => BinaryOperator::Plus, + BinaryOperator::Multiply => BinaryOperator::Divide, + BinaryOperator::Divide => BinaryOperator::Multiply, + BinaryOperator::Gt => BinaryOperator::Lt, + BinaryOperator::Lt => BinaryOperator::Gt, + BinaryOperator::GtEq => BinaryOperator::LtEq, + BinaryOperator::LtEq => BinaryOperator::GtEq, + source_op => source_op + } + } + UnaryOperator::Not => { + match *op { + BinaryOperator::Gt => BinaryOperator::Lt, + BinaryOperator::Lt => BinaryOperator::Gt, + BinaryOperator::GtEq => BinaryOperator::LtEq, + BinaryOperator::LtEq => BinaryOperator::GtEq, + source_op => source_op + } + } + }); + } + + fn fix_binary( + replace_binary: ReplaceBinary, + left_expr: &mut Box, + right_expr: &mut Box, + op: &mut BinaryOperator + ) { + let ReplaceBinary { column_expr, val_expr, op: fix_op, ty: fix_ty, is_column_left } = replace_binary; + let op_flip = |op: BinaryOperator| { + match op { + BinaryOperator::Plus => BinaryOperator::Minus, + BinaryOperator::Minus => BinaryOperator::Plus, + BinaryOperator::Multiply => BinaryOperator::Divide, + BinaryOperator::Divide => BinaryOperator::Multiply, + _ => unreachable!() + } + }; + let comparison_flip = |op: BinaryOperator| { + match op { + BinaryOperator::Gt => BinaryOperator::Lt, + BinaryOperator::GtEq => BinaryOperator::LtEq, + BinaryOperator::Lt => BinaryOperator::Gt, + BinaryOperator::LtEq => BinaryOperator::GtEq, + source_op => source_op + } + }; + let (fixed_op, fixed_left_expr, fixed_right_expr) = if is_column_left { + (op_flip(fix_op), right_expr.clone(), Box::new(val_expr)) + } else { + if matches!(fix_op, BinaryOperator::Minus | BinaryOperator::Multiply) { + let _ = mem::replace(op, comparison_flip(*op)); + } + (fix_op, Box::new(val_expr), right_expr.clone()) + }; + + let _ = mem::replace(left_expr, Box::new(column_expr)); + let _ = mem::replace(right_expr, Box::new(ScalarExpression::Binary { + op: fixed_op, + left_expr: fixed_left_expr, + right_expr: fixed_right_expr, + ty: fix_ty, + })); + } + + /// The definition of Or is not the Or in the Where condition. + /// The And and Or of ConstantBinary are concerned with the data range that needs to be aggregated. + /// - `ConstantBinary::And`: Aggregate the minimum range of all conditions in and + /// - `ConstantBinary::Or`: Rearrange and sort the range of each OR data + pub fn convert_binary(&self, col_id: &ColumnId) -> Result, TypeError> { + match self { + ScalarExpression::Binary { left_expr, right_expr, op, .. } => { + match (left_expr.convert_binary(col_id)?, right_expr.convert_binary(col_id)?) { + (Some(left_binary), Some(right_binary)) => { + match (left_binary, right_binary) { + (ConstantBinary::And(mut left), ConstantBinary::And(mut right)) + | (ConstantBinary::Or(mut left), ConstantBinary::Or(mut right)) => { + left.append(&mut right); + + Ok(Some(ConstantBinary::And(left))) + } + (ConstantBinary::And(mut left), ConstantBinary::Or(mut right)) => { + right.append(&mut left); + + Ok(Some(ConstantBinary::Or(right))) + } + (ConstantBinary::Or(mut left), ConstantBinary::And(mut right)) => { + left.append(&mut right); + + Ok(Some(ConstantBinary::Or(left))) + } + (ConstantBinary::And(mut binaries), binary) + | (binary, ConstantBinary::And(mut binaries)) => { + binaries.push(binary); + + Ok(Some(ConstantBinary::And(binaries))) + } + (ConstantBinary::Or(mut binaries), binary) + | (binary, ConstantBinary::Or(mut binaries)) => { + binaries.push(binary); + + Ok(Some(ConstantBinary::Or(binaries))) + } + (left, right) => { + match op { + BinaryOperator::And => { + Ok(Some(ConstantBinary::And(vec![left, right]))) + } + BinaryOperator::Or => { + Ok(Some(ConstantBinary::Or(vec![left, right]))) + } + BinaryOperator::Xor => todo!(), + _ => Ok(None) + } + } + } + }, + (None, None) => { + if let (Some(col), Some(val)) = + (left_expr.unpack_col(), right_expr.unpack_val()) + { + return Ok(Self::new_binary(col_id, *op, col, val, false)); + } + if let (Some(val), Some(col)) = + (left_expr.unpack_val(), right_expr.unpack_col()) + { + return Ok(Self::new_binary(col_id, *op, col, val, true)); + } + + return Ok(None); + } + (Some(binary), None) => Ok(Self::check_or(col_id, right_expr, op, binary)), + (None, Some(binary)) => Ok(Self::check_or(col_id, left_expr, op, binary)), + } + }, + ScalarExpression::Alias { expr, .. } => expr.convert_binary(col_id), + ScalarExpression::TypeCast { expr, .. } => expr.convert_binary(col_id), + ScalarExpression::IsNull { expr } => expr.convert_binary(col_id), + ScalarExpression::Unary { expr, .. } => expr.convert_binary(col_id), + _ => Ok(None), + } + } + + /// check if: c1 > c2 or c1 > 1 + /// this case it makes no sense to just extract c1 > 1 + fn check_or( + col_id: &ColumnId, + right_expr: &Box, + op: &BinaryOperator, + binary: ConstantBinary + ) -> Option { + if matches!(op, BinaryOperator::Or) && right_expr.exist_column(col_id) { + return None + } + + Some(binary) + } + + fn new_binary(col_id: &ColumnId, mut op: BinaryOperator, col: ColumnRef, val: ValueRef, is_flip: bool) -> Option { + if col.id.unwrap() != *col_id { + return None; + } + + if is_flip { + op = match op { + BinaryOperator::Gt => BinaryOperator::Lt, + BinaryOperator::Lt => BinaryOperator::Gt, + BinaryOperator::GtEq => BinaryOperator::LtEq, + BinaryOperator::LtEq => BinaryOperator::GtEq, + source_op => source_op + }; + } + + match op { + BinaryOperator::Gt => { + Some(ConstantBinary::Scope { + min: Bound::Excluded(val.clone()), + max: Bound::Unbounded + }) + } + BinaryOperator::Lt => { + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val.clone()), + }) + } + BinaryOperator::GtEq => { + Some(ConstantBinary::Scope { + min: Bound::Included(val.clone()), + max: Bound::Unbounded + }) + } + BinaryOperator::LtEq => { + Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val.clone()), + }) + } + BinaryOperator::Eq | BinaryOperator::Spaceship => { + Some(ConstantBinary::Eq(val.clone())) + }, + BinaryOperator::NotEq => { + Some(ConstantBinary::NotEq(val.clone())) + }, + _ => None + } + } +} + +#[cfg(test)] +mod test { + use std::collections::Bound; + use std::sync::Arc; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::expression::{BinaryOperator, ScalarExpression}; + use crate::expression::simplify::ConstantBinary; + use crate::types::errors::TypeError; + use crate::types::LogicalType; + use crate::types::value::DataValue; + + #[test] + fn test_convert_binary_simple() -> Result<(), TypeError> { + let col_1 = Arc::new(ColumnCatalog { + id: Some(0), + name: "c1".to_string(), + table_name: None, + nullable: false, + desc: ColumnDesc { + column_datatype: LogicalType::Integer, + is_primary: false, + is_unique: false, + }, + }); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + + let binary_eq = ScalarExpression::Binary { + op: BinaryOperator::Eq, + left_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_eq, ConstantBinary::Eq(val_1.clone())); + + let binary_not_eq = ScalarExpression::Binary { + op: BinaryOperator::NotEq, + left_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + right_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_not_eq, ConstantBinary::NotEq(val_1.clone())); + + let binary_lt = ScalarExpression::Binary { + op: BinaryOperator::Lt, + left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_lt, ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_1.clone()) + }); + + let binary_lteq = ScalarExpression::Binary { + op: BinaryOperator::LtEq, + left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_lteq, ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_1.clone()) + }); + + let binary_gt = ScalarExpression::Binary { + op: BinaryOperator::Gt, + left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_gt, ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Unbounded + }); + + let binary_gteq = ScalarExpression::Binary { + op: BinaryOperator::GtEq, + left_expr: Box::new(ScalarExpression::ColumnRef(col_1.clone())), + right_expr: Box::new(ScalarExpression::Constant(val_1.clone())), + ty: LogicalType::Boolean, + }.convert_binary(&0)?.unwrap(); + + assert_eq!(binary_gteq, ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Unbounded + }); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_eq_noteq() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::And(vec![ + ConstantBinary::Eq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ConstantBinary::Eq(val_2.clone()), + ConstantBinary::NotEq(val_3.clone()), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Eq(val_0) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_eq_noteq_cover() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::And(vec![ + ConstantBinary::Eq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ConstantBinary::Eq(val_2.clone()), + ConstantBinary::NotEq(val_3.clone()), + + ConstantBinary::NotEq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ConstantBinary::NotEq(val_2.clone()), + ConstantBinary::NotEq(val_3.clone()), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::And(vec![]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_scope() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded, + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Excluded(val_2.clone()), + } + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_mixed() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded, + }, + ConstantBinary::Eq(val_1.clone()), + ConstantBinary::Eq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Eq(val_0.clone()) + ); + + Ok(()) + } + + #[test] + fn test_rearrange() -> Result<(), TypeError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let val_5 = Arc::new(DataValue::Int32(Some(5))); + + let val_6 = Arc::new(DataValue::Int32(Some(6))); + let val_7 = Arc::new(DataValue::Int32(Some(7))); + let val_8 = Arc::new(DataValue::Int32(Some(8))); + + let val_10 = Arc::new(DataValue::Int32(Some(10))); + + let binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_6.clone()), + max: Bound::Included(val_10.clone()) + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_2.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_3.clone()) + }, + ConstantBinary::Scope { + min: Bound::Included(val_6.clone()), + max: Bound::Included(val_7.clone()) + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded, + }, + ConstantBinary::NotEq(val_8.clone()), + ConstantBinary::Eq(val_5.clone()), + ConstantBinary::Eq(val_0.clone()), + ConstantBinary::Eq(val_1.clone()), + ]); + + assert_eq!( + binary.rearrange()?, + vec![ + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Included(val_3.clone()), + }, + ConstantBinary::Eq(val_5.clone()), + ConstantBinary::Scope { + min: Bound::Included(val_6.clone()), + max: Bound::Included(val_10.clone()), + } + ] + ); + + Ok(()) + } +} \ No newline at end of file diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index 3c0b2b1d..3a2f6dfc 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -1184,7 +1184,7 @@ mod test { } #[test] - fn test_binary_op_Utf8_compare()->Result<(),TypeError>{ + fn test_binary_op_utf8_compare()->Result<(),TypeError>{ assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Gt)?, DataValue::Boolean(Some(false))); assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("b".to_string())), &BinaryOperator::Lt)?, DataValue::Boolean(Some(true))); assert_eq!(binary_op(&DataValue::Utf8(Some("a".to_string())), &DataValue::Utf8(Some("a".to_string())), &BinaryOperator::GtEq)?, DataValue::Boolean(Some(true))); diff --git a/src/lib.rs b/src/lib.rs index 44cc5c24..8635388f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ #![feature(generators)] #![feature(iterator_try_collect)] #![feature(slice_pattern)] +#![feature(bound_map)] extern crate core; pub mod binder; diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index 063c84de..37bf04cd 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -1,10 +1,11 @@ use crate::optimizer::core::pattern::Pattern; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; /// A rule is to transform logically equivalent expression pub trait Rule { /// The pattern to determine whether the rule can be applied. fn pattern(&self) -> &Pattern; - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph); + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError>; } \ No newline at end of file diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index ddbd0579..081c313d 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -3,6 +3,7 @@ use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::batch::{HepBatch, HepBatchStrategy}; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::heuristic::matcher::HepMatcher; +use crate::optimizer::OptimizerError; use crate::optimizer::rule::RuleImpl; use crate::planner::LogicalPlan; @@ -24,7 +25,7 @@ impl HepOptimizer { self } - pub fn find_best(&mut self) -> LogicalPlan { + pub fn find_best(&mut self) -> Result { let batches = self.batches.clone(); for batch in batches { @@ -32,7 +33,7 @@ impl HepOptimizer { let mut iteration = 1usize; while iteration <= batch.strategy.max_iteration && !batch_over { - if self.apply_batch(&batch) { + if self.apply_batch(&batch)? { iteration += 1; } else { batch_over = true @@ -40,31 +41,31 @@ impl HepOptimizer { } } - self.graph.to_plan() + Ok(self.graph.to_plan()) } - fn apply_batch(&mut self, HepBatch{ rules, strategy, .. }: &HepBatch) -> bool { + fn apply_batch(&mut self, HepBatch{ rules, strategy, .. }: &HepBatch) -> Result { let start_ver = self.graph.version; for rule in rules { for node_id in self.graph.nodes_iter(strategy.match_order, None) { - if self.apply_rule(rule, node_id) { + if self.apply_rule(rule, node_id)? { break; } } } - start_ver != self.graph.version + Ok(start_ver != self.graph.version) } - fn apply_rule(&mut self, rule: &RuleImpl, node_id: HepNodeId) -> bool { + fn apply_rule(&mut self, rule: &RuleImpl, node_id: HepNodeId) -> Result { let after_version = self.graph.version; if HepMatcher::new(rule.pattern(), node_id, &self.graph).match_opt_expr() { - rule.apply(node_id, &mut self.graph); + rule.apply(node_id, &mut self.graph)?; } - after_version != self.graph.version + Ok(after_version != self.graph.version) } } \ No newline at end of file diff --git a/src/optimizer/mod.rs b/src/optimizer/mod.rs index 6d35cdf9..8f937373 100644 --- a/src/optimizer/mod.rs +++ b/src/optimizer/mod.rs @@ -1,6 +1,18 @@ +use crate::types::errors::TypeError; + /// The architecture and some components, /// such as (/core) are referenced from sqlrs mod core; pub mod heuristic; -pub mod rule; \ No newline at end of file +pub mod rule; + +#[derive(thiserror::Error, Debug)] +pub enum OptimizerError { + #[error("type error")] + TypeError( + #[source] + #[from] + TypeError + ) +} \ No newline at end of file diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs index e9a022ce..ca404ca5 100644 --- a/src/optimizer/rule/column_pruning.rs +++ b/src/optimizer/rule/column_pruning.rs @@ -1,14 +1,15 @@ -use std::collections::HashSet; use itertools::Itertools; use lazy_static::lazy_static; +use crate::catalog::ColumnRef; use crate::expression::ScalarExpression; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; +use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::operator::Operator; use crate::planner::operator::project::ProjectOperator; -use crate::types::ColumnId; lazy_static! { static ref PUSH_PROJECT_INTO_SCAN_RULE: Pattern = { @@ -43,7 +44,7 @@ impl Rule for PushProjectIntoScan { &PUSH_PROJECT_INTO_SCAN_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { if let Operator::Project(project_op) = graph.operator(node_id) { let child_index = graph.children_at(node_id)[0]; if let Operator::Scan(scan_op) = graph.operator(child_index) { @@ -62,6 +63,8 @@ impl Rule for PushProjectIntoScan { ); } } + + Ok(()) } } @@ -73,76 +76,111 @@ impl Rule for PushProjectThroughChild { &PUSH_PROJECT_THROUGH_CHILD_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { let node_operator = graph.operator(node_id); - let input_refs = node_operator.project_input_refs(); if let Operator::Project(_) = node_operator { let child_index = graph.children_at(node_id)[0]; - let mut node_referenced_columns = node_operator.referenced_columns(); + let node_referenced_columns = node_operator.referenced_columns(); let child_operator = graph.operator(child_index); let child_referenced_columns = child_operator.referenced_columns(); - let is_child_agg = matches!(child_operator, Operator::Aggregate(_)); - - // When the aggregate function is a child node, - // the pushdown will lose the corresponding ColumnRef due to `InputRef`. - // Therefore, it is necessary to map the InputRef to the corresponding ColumnRef - // and push it down. - if is_child_agg && !input_refs.is_empty() { - node_referenced_columns.append( - &mut child_operator.agg_mapping_col_refs(&input_refs) - ) - } + let op = |col: &ColumnRef| format!("{:?}.{:?}", col.table_name, col.id); - let intersection_columns_ids = child_referenced_columns - .iter() - .map(|col| col.id) - .chain( - node_referenced_columns + match child_operator { + // When the aggregate function is a child node, + // the pushdown will lose the corresponding ColumnRef due to `InputRef`. + // Therefore, it is necessary to map the InputRef to the corresponding ColumnRef + // and push it down. + Operator::Aggregate(AggregateOperator { agg_calls, .. }) => { + let grandson_id = graph.children_at(child_index)[0]; + let columns = node_operator + .project_input_refs() .iter() - .map(|col| col.id) - ) - .collect::>(); + .filter_map(|expr| { + if let ScalarExpression::InputRef { index, .. } = expr { + Some(agg_calls[*index].clone()) + } else { + None + } + }) + .map(|expr| expr.referenced_columns()) + .flatten() + .chain(node_referenced_columns.into_iter()) + .chain(child_referenced_columns.into_iter()) + .unique_by(op) + .map(|col| ScalarExpression::ColumnRef(col)) + .collect_vec(); - if intersection_columns_ids.is_empty() { - return; - } + Self::add_project_node(graph, child_index, columns, grandson_id); + } + Operator::Join(_) => { + let parent_referenced_columns = node_referenced_columns + .into_iter() + .chain(child_referenced_columns.into_iter()) + .unique_by(op) + .collect_vec(); - for grandson_id in graph.children_at(child_index) { - let mut columns = graph.operator(grandson_id) - .referenced_columns() - .into_iter() - .unique_by(|col| col.id) - .filter(|u| intersection_columns_ids.contains(&u.id)) - .map(|col| ScalarExpression::ColumnRef(col)) - .collect_vec(); + for grandson_id in graph.children_at(child_index) { + let grandson_referenced_column = graph + .operator(grandson_id) + .referenced_columns(); - if !is_child_agg && !input_refs.is_empty() { - // Tips: Aggregation InputRefs fields take precedence - columns = input_refs.iter() - .cloned() - .chain(columns) - .collect_vec(); + // for PushLimitThroughJoin + if grandson_referenced_column.is_empty() { + return Ok(()) + } + let grandson_table_name = grandson_referenced_column[0] + .table_name + .clone(); + let columns = parent_referenced_columns.iter() + .filter(|col| col.table_name == grandson_table_name) + .cloned() + .map(|col| ScalarExpression::ColumnRef(col)) + .collect_vec(); + + Self::add_project_node(graph, child_index, columns, grandson_id); + } } + _ => { + let grandson_id = graph.children_at(child_index)[0]; + let mut columns = node_operator.project_input_refs(); + + let mut referenced_columns = node_referenced_columns + .into_iter() + .chain(child_referenced_columns.into_iter()) + .unique_by(op) + .map(|col| ScalarExpression::ColumnRef(col)) + .collect_vec(); - if !columns.is_empty() { - graph.add_node( - child_index, - Some(grandson_id), - OptExprNode::OperatorRef( - Operator::Project(ProjectOperator { columns }) - ) - ); + columns.append(&mut referenced_columns); + + Self::add_project_node(graph, child_index, columns, grandson_id); } } } + + Ok(()) + } +} + +impl PushProjectThroughChild { + fn add_project_node(graph: &mut HepGraph, child_index: HepNodeId, columns: Vec, grandson_id: HepNodeId) { + if !columns.is_empty() { + graph.add_node( + child_index, + Some(grandson_id), + OptExprNode::OperatorRef( + Operator::Project(ProjectOperator { columns }) + ) + ); + } } } #[cfg(test)] mod tests { use crate::binder::test::select_sql_run; - use crate::execution::ExecutorError; + use crate::db::DatabaseError; use crate::optimizer::heuristic::batch::{HepBatchStrategy}; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; @@ -150,7 +188,7 @@ mod tests { use crate::planner::operator::Operator; #[tokio::test] - async fn test_project_into_table_scan() -> Result<(), ExecutorError> { + async fn test_project_into_table_scan() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1").await?; let best_plan = HepOptimizer::new(plan.clone()) @@ -159,7 +197,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![RuleImpl::PushProjectIntoScan] ) - .find_best(); + .find_best()?; assert_eq!(best_plan.childrens.len(), 0); match best_plan.operator { @@ -173,7 +211,7 @@ mod tests { } #[tokio::test] - async fn test_project_through_child_on_join() -> Result<(), ExecutorError> { + async fn test_project_through_child_on_join() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c3 from t1 left join t2 on c1 = c3").await?; let best_plan = HepOptimizer::new(plan.clone()) @@ -184,7 +222,7 @@ mod tests { RuleImpl::PushProjectThroughChild, RuleImpl::PushProjectIntoScan ] - ).find_best(); + ).find_best()?; assert_eq!(best_plan.childrens.len(), 1); match best_plan.operator { diff --git a/src/optimizer/rule/combine_operators.rs b/src/optimizer/rule/combine_operators.rs index 52fc80bd..bf22a66c 100644 --- a/src/optimizer/rule/combine_operators.rs +++ b/src/optimizer/rule/combine_operators.rs @@ -4,6 +4,7 @@ use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; use crate::optimizer::rule::is_subset_exprs; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::Operator; @@ -38,7 +39,7 @@ impl Rule for CollapseProject { &COLLAPSE_PROJECT_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { if let Operator::Project(op) = graph.operator(node_id) { let child_id = graph.children_at(node_id)[0]; if let Operator::Project(child_op) = graph.operator(child_id) { @@ -47,6 +48,8 @@ impl Rule for CollapseProject { } } } + + Ok(()) } } @@ -58,7 +61,7 @@ impl Rule for CombineFilter { &COMBINE_FILTERS_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { if let Operator::Filter(op) = graph.operator(node_id) { let child_id = graph.children_at(node_id)[0]; if let Operator::Filter(child_op) = graph.operator(child_id) { @@ -78,6 +81,8 @@ impl Rule for CombineFilter { graph.remove_node(child_id, false); } } + + Ok(()) } } @@ -85,7 +90,7 @@ impl Rule for CombineFilter { mod tests { use std::sync::Arc; use crate::binder::test::select_sql_run; - use crate::execution::ExecutorError; + use crate::db::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::expression::ScalarExpression::Constant; use crate::optimizer::core::opt_expr::OptExprNode; @@ -98,7 +103,7 @@ mod tests { use crate::types::value::DataValue; #[tokio::test] - async fn test_collapse_project() -> Result<(), ExecutorError> { + async fn test_collapse_project() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c2 from t1").await?; let mut optimizer = HepOptimizer::new(plan.clone()) @@ -120,7 +125,7 @@ mod tests { optimizer.graph.add_root(OptExprNode::OperatorRef(new_project_op)); - let best_plan = optimizer.find_best(); + let best_plan = optimizer.find_best()?; if let Operator::Project(op) = &best_plan.operator { assert_eq!(op.columns.len(), 1); @@ -138,7 +143,7 @@ mod tests { } #[tokio::test] - async fn test_combine_filter() -> Result<(), ExecutorError> { + async fn test_combine_filter() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 where c1 > 1").await?; let mut optimizer = HepOptimizer::new(plan.clone()) @@ -169,7 +174,7 @@ mod tests { OptExprNode::OperatorRef(new_filter_op) ); - let best_plan = optimizer.find_best(); + let best_plan = optimizer.find_best()?; if let Operator::Filter(op) = &best_plan.childrens[0].operator { if let ScalarExpression::Binary { op, .. } = &op.predicate { diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs index f93dec56..bca3f4c6 100644 --- a/src/optimizer/rule/mod.rs +++ b/src/optimizer/rule/mod.rs @@ -2,15 +2,19 @@ use crate::expression::ScalarExpression; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; use crate::optimizer::rule::column_pruning::{PushProjectIntoScan, PushProjectThroughChild}; use crate::optimizer::rule::combine_operators::{CollapseProject, CombineFilter}; use crate::optimizer::rule::pushdown_limit::{LimitProjectTranspose, EliminateLimits, PushLimitThroughJoin, PushLimitIntoScan}; use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; +use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan; +use crate::optimizer::rule::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; mod pushdown_limit; mod pushdown_predicates; +mod simplification; #[derive(Debug, Copy, Clone)] pub enum RuleImpl { @@ -26,7 +30,11 @@ pub enum RuleImpl { PushLimitThroughJoin, PushLimitIntoTableScan, // PushDown predicates - PushPredicateThroughJoin + PushPredicateThroughJoin, + // Tips: need to be used with `SimplifyFilter` + PushPredicateIntoScan, + // Simplification + SimplifyFilter } impl Rule for RuleImpl { @@ -41,10 +49,12 @@ impl Rule for RuleImpl { RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.pattern(), RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.pattern(), RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.pattern(), + RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.pattern(), + RuleImpl::SimplifyFilter => SimplifyFilter {}.pattern(), } } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError>{ match self { RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.apply(node_id, graph), RuleImpl::PushProjectThroughChild => PushProjectThroughChild {}.apply(node_id, graph), @@ -55,6 +65,8 @@ impl Rule for RuleImpl { RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.apply(node_id, graph), RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.apply(node_id, graph), RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.apply(node_id, graph), + RuleImpl::SimplifyFilter => SimplifyFilter {}.apply(node_id, graph), + RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.apply(node_id, graph) } } } diff --git a/src/optimizer/rule/pushdown_limit.rs b/src/optimizer/rule/pushdown_limit.rs index 69e8d94e..8fd161f0 100644 --- a/src/optimizer/rule/pushdown_limit.rs +++ b/src/optimizer/rule/pushdown_limit.rs @@ -5,6 +5,7 @@ use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; use crate::planner::operator::join::JoinType; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::Operator; @@ -54,11 +55,13 @@ impl Rule for LimitProjectTranspose { &LIMIT_PROJECT_TRANSPOSE_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { graph.swap_node( node_id, graph.children_at(node_id)[0] ); + + Ok(()) } } @@ -71,7 +74,7 @@ impl Rule for EliminateLimits { &ELIMINATE_LIMITS_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { if let Operator::Limit(op) = graph.operator(node_id) { let child_id = graph.children_at(node_id)[0]; if let Operator::Limit(child_op) = graph.operator(child_id) { @@ -89,6 +92,8 @@ impl Rule for EliminateLimits { ); } } + + Ok(()) } } @@ -105,29 +110,31 @@ impl Rule for PushLimitThroughJoin { &PUSH_LIMIT_THROUGH_JOIN_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { - let child_id = graph.children_at(node_id)[0]; - let join_type = if let Operator::Join(op) = graph.operator(child_id) { - Some(op.join_type) - } else { - None - }; - - if let Some(ty) = join_type { - if let Some(grandson_id) = match ty { - JoinType::Left => Some(graph.children_at(child_id)[0]), - JoinType::Right => Some(graph.children_at(child_id)[1]), - _ => None - } { - let limit_node = graph.remove_node(node_id, false).unwrap(); - - graph.add_node( - child_id, - Some(grandson_id), - limit_node - ); + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { + if let Operator::Limit(op) = graph.operator(node_id) { + let child_id = graph.children_at(node_id)[0]; + let join_type = if let Operator::Join(op) = graph.operator(child_id) { + Some(op.join_type) + } else { + None + }; + + if let Some(ty) = join_type { + if let Some(grandson_id) = match ty { + JoinType::Left => Some(graph.children_at(child_id)[0]), + JoinType::Right => Some(graph.children_at(child_id)[1]), + _ => None + } { + graph.add_node( + child_id, + Some(grandson_id), + OptExprNode::OperatorRef(Operator::Limit(op.clone())) + ); + } } } + + Ok(()) } } @@ -139,7 +146,7 @@ impl Rule for PushLimitIntoScan { &PUSH_LIMIT_INTO_TABLE_SCAN_RULE } - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { if let Operator::Limit(limit_op) = graph.operator(node_id) { let child_index = graph.children_at(node_id)[0]; if let Operator::Scan(scan_op) = graph.operator(child_index) { @@ -154,13 +161,15 @@ impl Rule for PushLimitIntoScan { ); } } + + Ok(()) } } #[cfg(test)] mod tests { use crate::binder::test::select_sql_run; - use crate::execution::ExecutorError; + use crate::db::DatabaseError; use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; @@ -169,7 +178,7 @@ mod tests { use crate::planner::operator::Operator; #[tokio::test] - async fn test_limit_project_transpose() -> Result<(), ExecutorError> { + async fn test_limit_project_transpose() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c2 from t1 limit 1").await?; let best_plan = HepOptimizer::new(plan.clone()) @@ -178,7 +187,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![RuleImpl::LimitProjectTranspose] ) - .find_best(); + .find_best()?; if let Operator::Project(_) = &best_plan.operator { @@ -196,7 +205,7 @@ mod tests { } #[tokio::test] - async fn test_eliminate_limits() -> Result<(), ExecutorError> { + async fn test_eliminate_limits() -> Result<(), DatabaseError> { let plan = select_sql_run("select c1, c2 from t1 limit 1 offset 1").await?; let mut optimizer = HepOptimizer::new(plan.clone()) @@ -215,7 +224,7 @@ mod tests { OptExprNode::OperatorRef(Operator::Limit(new_limit_op)) ); - let best_plan = optimizer.find_best(); + let best_plan = optimizer.find_best()?; if let Operator::Limit(op) = &best_plan.operator { assert_eq!(op.limit, 1); @@ -232,7 +241,7 @@ mod tests { } #[tokio::test] - async fn test_push_limit_through_join() -> Result<(), ExecutorError> { + async fn test_push_limit_through_join() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 left join t2 on c1 = c3 limit 1").await?; let best_plan = HepOptimizer::new(plan.clone()) @@ -244,24 +253,24 @@ mod tests { RuleImpl::PushLimitThroughJoin ] ) - .find_best(); + .find_best()?; - if let Operator::Join(_) = &best_plan.childrens[0].operator { + if let Operator::Join(_) = &best_plan.childrens[0].childrens[0].operator { } else { - unreachable!("Should be a project operator") + unreachable!("Should be a join operator") } - if let Operator::Limit(op) = &best_plan.childrens[0].childrens[0].operator { + if let Operator::Limit(op) = &best_plan.childrens[0].childrens[0].childrens[0].operator { assert_eq!(op.limit, 1); } else { - unreachable!("Should be a project operator") + unreachable!("Should be a limit operator") } Ok(()) } #[tokio::test] - async fn test_push_limit_into_table_scan() -> Result<(), ExecutorError> { + async fn test_push_limit_into_table_scan() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 limit 1 offset 1").await?; let best_plan = HepOptimizer::new(plan.clone()) @@ -273,7 +282,7 @@ mod tests { RuleImpl::PushLimitIntoTableScan ] ) - .find_best(); + .find_best()?; if let Operator::Scan(op) = &best_plan.childrens[0].operator { assert_eq!(op.limit, (Some(1), Some(1))) diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs index c2242964..1cef28e8 100644 --- a/src/optimizer/rule/pushdown_predicates.rs +++ b/src/optimizer/rule/pushdown_predicates.rs @@ -7,6 +7,7 @@ use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::core::pattern::PatternChildrenPredicate; +use crate::optimizer::OptimizerError; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; @@ -23,6 +24,16 @@ lazy_static! { } }; + static ref PUSH_PREDICATE_INTO_SCAN: Pattern = { + Pattern { + predicate: |op| matches!(op, Operator::Filter(_)), + children: PatternChildrenPredicate::Predicate(vec![Pattern { + predicate: |op| matches!(op, Operator::Scan(_)), + children: PatternChildrenPredicate::None, + }]), + } + }; + // TODO static ref PUSH_PREDICATE_THROUGH_NON_JOIN: Pattern = { Pattern { @@ -92,11 +103,11 @@ impl Rule for PushPredicateThroughJoin { } // TODO: pushdown_predicates need to consider output columns - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { let child_id = graph.children_at(node_id)[0]; if let Operator::Join(child_op) = graph.operator(child_id) { if !matches!(child_op.join_type, JoinType::Inner | JoinType::Left | JoinType::Right) { - return ; + return Ok(()); } let join_childs = graph.children_at(child_id); @@ -194,22 +205,109 @@ impl Rule for PushPredicateThroughJoin { graph.remove_node(node_id, false); } } + + Ok(()) + } +} + +pub struct PushPredicateIntoScan { + +} + +impl Rule for PushPredicateIntoScan { + fn pattern(&self) -> &Pattern { + &PUSH_PREDICATE_INTO_SCAN + } + + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { + if let Operator::Filter(op) = graph.operator(node_id) { + let child_id = graph.children_at(node_id)[0]; + if let Operator::Scan(child_op) = graph.operator(child_id) { + if child_op.index_by.is_some() { + return Ok(()) + } + + //FIXME: now only support unique + for meta in &child_op.index_metas { + let mut option = op.predicate.convert_binary(&meta.column_ids[0])?; + + if let Some(mut binary) = option.take() { + binary.scope_aggregation()?; + let rearrange_binaries = binary.rearrange()?; + + if !rearrange_binaries.is_empty() { + let mut scan_by_index = child_op.clone(); + scan_by_index.index_by = Some((meta.clone(), rearrange_binaries)); + + // The constant expression extracted in prewhere is used to + // reduce the data scanning range and cannot replace the role of Filter. + graph.replace_node( + child_id, + OptExprNode::OperatorRef( + Operator::Scan(scan_by_index) + ) + ); + + return Ok(()) + } + } + } + } + } + + Ok(()) } } #[cfg(test)] mod tests { + use std::collections::Bound; + use std::sync::Arc; use crate::binder::test::select_sql_run; - use crate::execution::ExecutorError; + use crate::db::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; + use crate::expression::simplify::ConstantBinary::Scope; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; use crate::planner::operator::Operator; use crate::types::LogicalType; + use crate::types::value::DataValue; + + #[tokio::test] + async fn test_push_predicate_into_scan() -> Result<(), DatabaseError> { + // 1 - c2 < 0 => c2 > 1 + let plan = select_sql_run("select * from t1 where -(1 - c2) > 0").await?; + + let best_plan = HepOptimizer::new(plan) + .batch( + "simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .batch( + "test_push_predicate_into_scan".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::PushPredicateIntoScan] + ) + .find_best()?; + + if let Operator::Scan(op) = &best_plan.childrens[0].childrens[0].operator { + let mock_binaries = vec![Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(1)))), + max: Bound::Unbounded + }]; + + assert_eq!(op.index_by.clone().unwrap().1, mock_binaries); + } else { + unreachable!("Should be a filter operator") + } + + Ok(()) + } #[tokio::test] - async fn test_push_predicate_through_join_in_left_join() -> Result<(), ExecutorError> { + async fn test_push_predicate_through_join_in_left_join() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 left join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; let best_plan = HepOptimizer::new(plan) @@ -218,7 +316,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![RuleImpl::PushPredicateThroughJoin] ) - .find_best(); + .find_best()?; if let Operator::Filter(op) = &best_plan.childrens[0].operator { match op.predicate { @@ -250,7 +348,7 @@ mod tests { } #[tokio::test] - async fn test_push_predicate_through_join_in_right_join() -> Result<(), ExecutorError> { + async fn test_push_predicate_through_join_in_right_join() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 right join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; let best_plan = HepOptimizer::new(plan) @@ -259,7 +357,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![RuleImpl::PushPredicateThroughJoin] ) - .find_best(); + .find_best()?; if let Operator::Filter(op) = &best_plan.childrens[0].operator { match op.predicate { @@ -291,7 +389,7 @@ mod tests { } #[tokio::test] - async fn test_push_predicate_through_join_in_inner_join() -> Result<(), ExecutorError> { + async fn test_push_predicate_through_join_in_inner_join() -> Result<(), DatabaseError> { let plan = select_sql_run("select * from t1 inner join t2 on c1 = c3 where c1 > 1 and c3 < 2").await?; let best_plan = HepOptimizer::new(plan) @@ -300,7 +398,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![RuleImpl::PushPredicateThroughJoin] ) - .find_best(); + .find_best()?; if let Operator::Join(_) = &best_plan.childrens[0].operator { diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs new file mode 100644 index 00000000..796668b7 --- /dev/null +++ b/src/optimizer/rule/simplification.rs @@ -0,0 +1,289 @@ +use lazy_static::lazy_static; +use crate::optimizer::core::opt_expr::OptExprNode; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::Rule; +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::OptimizerError; +use crate::planner::operator::Operator; +lazy_static! { + static ref SIMPLIFY_FILTER_RULE: Pattern = { + Pattern { + predicate: |op| matches!(op, Operator::Filter(_)), + children: PatternChildrenPredicate::None, + } + }; +} + +#[derive(Copy, Clone)] +pub struct SimplifyFilter; + +impl Rule for SimplifyFilter { + fn pattern(&self) -> &Pattern { + &SIMPLIFY_FILTER_RULE + } + + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { + if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() { + filter_op.predicate.simplify()?; + + graph.replace_node( + node_id, + OptExprNode::OperatorRef(Operator::Filter(filter_op)) + ) + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::collections::Bound; + use std::sync::Arc; + use crate::binder::test::select_sql_run; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::DatabaseError; + use crate::expression::{BinaryOperator, ScalarExpression}; + use crate::expression::simplify::ConstantBinary; + use crate::optimizer::heuristic::batch::HepBatchStrategy; + use crate::optimizer::heuristic::optimizer::HepOptimizer; + use crate::optimizer::rule::RuleImpl; + use crate::planner::LogicalPlan; + use crate::planner::operator::filter::FilterOperator; + use crate::planner::operator::Operator; + use crate::types::LogicalType; + use crate::types::value::DataValue; + + #[tokio::test] + async fn test_simplify_filter_single_column() -> Result<(), DatabaseError> { + // c1 + 1 < -1 => c1 < -2 + let plan_1 = select_sql_run("select * from t1 where -(c1 + 1) > 1").await?; + // 1 - c1 < -1 => c1 > 2 + let plan_2 = select_sql_run("select * from t1 where -(1 - c1) > 1").await?; + // c1 < -1 + let plan_3 = select_sql_run("select * from t1 where -c1 > 1").await?; + // c1 > 0 + let plan_4 = select_sql_run("select * from t1 where c1 + 1 > 1").await?; + + // c1 + 1 < -1 => c1 < -2 + let plan_5 = select_sql_run("select * from t1 where 1 < -(c1 + 1)").await?; + // 1 - c1 < -1 => c1 > 2 + let plan_6 = select_sql_run("select * from t1 where 1 < -(1 - c1)").await?; + // c1 < -1 + let plan_7 = select_sql_run("select * from t1 where 1 < -c1").await?; + // c1 > 0 + let plan_8 = select_sql_run("select * from t1 where 1 < c1 + 1").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .find_best()?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op.predicate.convert_binary(&0).unwrap()); + + Ok(filter_op.predicate.convert_binary(&0).unwrap()) + } else { + Ok(None) + } + }; + + assert_eq!(op(plan_1, "-(c1 + 1) > 1")?, op(plan_5, "1 < -(c1 + 1)")?); + assert_eq!(op(plan_2, "-(1 - c1) > 1")?, op(plan_6, "1 < -(1 - c1)")?); + assert_eq!(op(plan_3, "-c1 > 1")?, op(plan_7, "1 < -c1")?); + assert_eq!(op(plan_4, "c1 + 1 > 1")?, op(plan_8, "1 < c1 + 1")?); + + Ok(()) + } + + #[tokio::test] + async fn test_simplify_filter_repeating_column() -> Result<(), DatabaseError> { + let plan = select_sql_run("select * from t1 where -(c1 + 1) > c2").await?; + + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .find_best()?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + let c1_col = ColumnCatalog { + id: Some( + 0, + ), + name: "c1".to_string(), + table_name: Some( + Arc::new("t1".to_string()), + ), + nullable: false, + desc: ColumnDesc { + column_datatype: LogicalType::Integer, + is_primary: true, + is_unique: false, + }, + }; + let c2_col = ColumnCatalog { + id: Some( + 1, + ), + name: "c2".to_string(), + table_name: Some( + Arc::new("t1".to_string()), + ), + nullable: false, + desc: ColumnDesc { + column_datatype: LogicalType::Integer, + is_primary: false, + is_unique: true, + }, + }; + + // -(c1 + 1) > c2 => c1 < -c2 - 1 + assert_eq!( + filter_op.predicate, + ScalarExpression::Binary { + op: BinaryOperator::Lt, + left_expr: Box::new(ScalarExpression::Binary { + op: BinaryOperator::Minus, + left_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c1_col))), + right_expr: Box::new(ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(-1))))), + ty: LogicalType::Integer, + }), + right_expr: Box::new(ScalarExpression::ColumnRef(Arc::new(c2_col))), + ty: LogicalType::Boolean, + } + ) + } else { + unreachable!() + } + + Ok(()) + } + + #[tokio::test] + async fn test_simplify_filter_multiple_column() -> Result<(), DatabaseError> { + // c1 + 1 < -1 => c1 < -2 + let plan_1 = select_sql_run("select * from t1 where -(c1 + 1) > 1 and -(1 - c2) > 1").await?; + // 1 - c1 < -1 => c1 > 2 + let plan_2 = select_sql_run("select * from t1 where -(1 - c1) > 1 and -(c2 + 1) > 1").await?; + // c1 < -1 + let plan_3 = select_sql_run("select * from t1 where -c1 > 1 and c2 + 1 > 1").await?; + // c1 > 0 + let plan_4 = select_sql_run("select * from t1 where c1 + 1 > 1 and -c2 > 1").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .find_best()?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + }; + + let op_1 = op(plan_1, "-(c1 + 1) > 1 and -(1 - c2) > 1")?.unwrap(); + let op_2 = op(plan_2, "-(1 - c1) > 1 and -(c2 + 1) > 1")?.unwrap(); + let op_3 = op(plan_3, "-c1 > 1 and c2 + 1 > 1")?.unwrap(); + let op_4 = op(plan_4, "c1 + 1 > 1 and -c2 > 1")?.unwrap(); + + let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!(cb_1_c1, Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + })); + + let cb_1_c2 = op_1.predicate.convert_binary(&1).unwrap(); + println!("op_1 => c2: {:#?}", cb_1_c2); + assert_eq!(cb_1_c2, Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), + max: Bound::Unbounded + })); + + let cb_2_c1 = op_2.predicate.convert_binary(&0).unwrap(); + println!("op_2 => c1: {:#?}", cb_2_c1); + assert_eq!(cb_2_c1, Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(2)))), + max: Bound::Unbounded + })); + + let cb_2_c2 = op_2.predicate.convert_binary(&1).unwrap(); + println!("op_2 => c2: {:#?}", cb_2_c2); + assert_eq!(cb_1_c1, Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))) + })); + + let cb_3_c1 = op_3.predicate.convert_binary(&0).unwrap(); + println!("op_3 => c1: {:#?}", cb_3_c1); + assert_eq!(cb_3_c1, Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + })); + + let cb_3_c2 = op_3.predicate.convert_binary(&1).unwrap(); + println!("op_3 => c2: {:#?}", cb_3_c2); + assert_eq!(cb_3_c2, Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), + max: Bound::Unbounded + })); + + let cb_4_c1 = op_4.predicate.convert_binary(&0).unwrap(); + println!("op_4 => c1: {:#?}", cb_4_c1); + assert_eq!(cb_4_c1, Some(ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), + max: Bound::Unbounded + })); + + let cb_4_c2 = op_4.predicate.convert_binary(&1).unwrap(); + println!("op_4 => c2: {:#?}", cb_4_c2); + assert_eq!(cb_4_c2, Some(ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-1)))) + })); + + Ok(()) + } + + #[tokio::test] + async fn test_simplify_filter_multiple_column_in_or() -> Result<(), DatabaseError> { + // c1 + 1 < -1 => c1 < -2 + let plan_1 = select_sql_run("select * from t1 where c1 > c2 or c1 > 1").await?; + + let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::SimplifyFilter] + ) + .find_best()?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + }; + + let op_1 = op(plan_1, "c1 > c2 or c1 > 1")?.unwrap(); + + let cb_1_c1 = op_1.predicate.convert_binary(&0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!(cb_1_c1, None); + + Ok(()) + } +} \ No newline at end of file diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 39cd1e18..3e5560bb 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -125,11 +125,7 @@ impl Operator { .collect_vec() } Operator::Scan(op) => { - op.sort_fields - .iter() - .map(|field| &field.expr) - .chain(op.columns.iter()) - .chain(op.pre_where.iter()) + op.columns.iter() .flat_map(|expr| expr.referenced_columns()) .collect_vec() } diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index b6666234..d40d2d1f 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -1,24 +1,26 @@ use itertools::Itertools; use crate::catalog::{TableCatalog, TableName}; use crate::expression::ScalarExpression; +use crate::expression::simplify::ConstantBinary; use crate::planner::LogicalPlan; use crate::storage::Bounds; +use crate::types::index::IndexMetaRef; -use super::{sort::SortField, Operator}; +use super::Operator; #[derive(Debug, PartialEq, Clone)] pub struct ScanOperator { + pub index_metas: Vec, + pub table_name: TableName, pub columns: Vec, // Support push down limit. pub limit: Bounds, - // IndexScan only - pub sort_fields: Vec, // IndexScan only // Support push down predicate. // If pre_where is simple predicate, for example: a > 1 then can calculate directly when read data. - pub pre_where: Vec, + pub index_by: Option<(IndexMetaRef, Vec)>, } impl ScanOperator { pub fn new(table_name: TableName, table_catalog: &TableCatalog) -> LogicalPlan { @@ -31,11 +33,12 @@ impl ScanOperator { LogicalPlan { operator: Operator::Scan(ScanOperator { + index_metas: table_catalog.indexes.clone(), table_name, columns, - sort_fields: vec![], - pre_where: vec![], + limit: (None, None), + index_by: None, }), childrens: vec![], } diff --git a/src/storage/kip.rs b/src/storage/kip.rs index 22490d39..bb08d65c 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -1,18 +1,24 @@ -use std::collections::Bound; +use std::collections::{Bound, VecDeque}; use std::collections::hash_map::RandomState; +use std::mem; +use std::ops::SubAssign; use std::path::PathBuf; use std::sync::Arc; use async_trait::async_trait; use kip_db::kernel::lsm::mvcc::TransactionIter; use kip_db::kernel::lsm::{mvcc, storage}; -use kip_db::kernel::lsm::iterator::Iter; +use kip_db::kernel::lsm::iterator::Iter as KipDBIter; use kip_db::kernel::lsm::storage::Config; -use kip_db::kernel::Storage as Kip_Storage; +use kip_db::kernel::Storage as KipDBStorage; use kip_db::kernel::utils::lru_cache::ShardingLruCache; use crate::catalog::{ColumnCatalog, TableCatalog, TableName}; -use crate::storage::{Bounds, Projections, Storage, StorageError, Table, Transaction}; +use crate::expression::simplify::ConstantBinary; +use crate::storage::{Bounds, Projections, Storage, StorageError, Transaction, Iter, tuple_projection, IndexIter}; use crate::storage::table_codec::TableCodec; +use crate::types::errors::TypeError; +use crate::types::index::{Index, IndexMeta, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; +use crate::types::value::ValueRef; #[derive(Clone)] pub struct KipStorage { @@ -27,34 +33,118 @@ impl KipStorage { Ok(KipStorage { cache: Arc::new(ShardingLruCache::new( - 128, + 32, 16, RandomState::default(), )?), inner: Arc::new(storage), }) } + + fn column_collect(name: &String, tx: &mvcc::Transaction) -> Result<(Vec, Option), StorageError> { + let (column_min, column_max) = TableCodec::columns_bound(name); + let mut column_iter = tx.iter(Bound::Included(&column_min), Bound::Included(&column_max))?; + + let mut columns = vec![]; + let mut name_option = None; + + while let Some((_, value_option)) = column_iter.try_next().ok().flatten() { + if let Some(value) = value_option { + let (table_name, column) = TableCodec::decode_column(&value)?; + + if name != table_name.as_str() { + return Ok((vec![], None)); + } + let _ = name_option.insert(table_name); + + columns.push(column); + } + } + + Ok((columns, name_option)) + } + + fn index_meta_collect(name: &String, tx: &mvcc::Transaction) -> Option> { + let (index_min, index_max) = TableCodec::index_meta_bound(name); + let mut index_metas = vec![]; + let mut index_iter = tx.iter(Bound::Included(&index_min), Bound::Included(&index_max)).ok()?; + + while let Some((_, value_option)) = index_iter.try_next().ok().flatten() { + if let Some(value) = value_option { + if let Some(index_meta) = TableCodec::decode_index_meta(&value).ok() { + index_metas.push(Arc::new(index_meta)); + } + } + } + + Some(index_metas) + } + + fn _drop_data(table: &mut KipTransaction, min: &[u8], max: &[u8]) -> Result<(), StorageError> { + let mut iter = table.tx.iter(Bound::Included(&min), Bound::Included(&max))?; + let mut data_keys = vec![]; + + while let Some((key, value_option)) = iter.try_next()? { + if value_option.is_some() { + data_keys.push(key); + } + } + drop(iter); + + for key in data_keys { + table.tx.remove(&key)? + } + + Ok(()) + } + + fn create_index_meta_for_table( + tx: &mut mvcc::Transaction, + table: &mut TableCatalog + ) -> Result<(), StorageError> { + let table_name = table.name.clone(); + + for col in table.all_columns() + .into_iter() + .filter(|col| col.desc.is_unique) + { + if let Some(col_id) = col.id { + let meta = IndexMeta { + id: 0, + column_ids: vec![col_id], + name: format!("uk_{}", col.name), + is_unique: true, + }; + let meta_ref = table.add_index_meta(meta); + let (key, value) = TableCodec::encode_index_meta(&table_name, meta_ref)?; + + tx.set(key, value); + } + } + Ok(()) + } } #[async_trait] impl Storage for KipStorage { - type TableType = KipTable; + type TransactionType = KipTransaction; async fn create_table(&self, table_name: TableName, columns: Vec) -> Result { - let table = TableCatalog::new(table_name.clone(), columns)?; + let mut tx = self.inner.new_transaction().await; + let mut table_catalog = TableCatalog::new(table_name.clone(), columns)?; - for (key, value) in table.columns - .iter() - .filter_map(|(_, col)| TableCodec::encode_column(col)) - { - self.inner.set(key, value).await?; + Self::create_index_meta_for_table(&mut tx, &mut table_catalog)?; + + for (_, column) in &table_catalog.columns { + let (key, value) = TableCodec::encode_column(column)?; + tx.set(key, value); } - let (k, v)= TableCodec::encode_root_table(table_name.as_str(), table.columns.len()) - .ok_or(StorageError::Serialization)?; + let (k, v)= TableCodec::encode_root_table(&table_name)?; self.inner.set(k, v).await?; - self.cache.put(table_name.to_string(), table); + tx.commit().await?; + self.cache.put(table_name.to_string(), table_catalog); Ok(table_name) } @@ -77,9 +167,7 @@ impl Storage for KipStorage { for col_key in col_keys { tx.remove(&col_key)? } - let (k, _) = TableCodec::encode_root_table(name.as_str(),0) - .ok_or(StorageError::Serialization)?; - tx.remove(&k)?; + tx.remove(&TableCodec::encode_root_table_key(name))?; tx.commit().await?; let _ = self.cache.remove(name); @@ -88,59 +176,41 @@ impl Storage for KipStorage { } async fn drop_data(&self, name: &String) -> Result<(), StorageError> { - if let Some(mut table) = self.table(name).await { - let (min, max) = table.table_codec.tuple_bound(); - let mut iter = table.tx.iter(Bound::Included(&min), Bound::Included(&max))?; - let mut data_keys = vec![]; - - while let Some((key, value_option)) = iter.try_next()? { - if value_option.is_some() { - data_keys.push(key); - } - } - drop(iter); + if let Some(mut transaction) = self.transaction(name).await { - for col_key in data_keys { - table.tx.remove(&col_key)? - } - table.tx.commit().await?; + let (tuple_min, tuple_max) = transaction.table_codec.tuple_bound(); + Self::_drop_data(&mut transaction, &tuple_min, &tuple_max)?; + + let (index_min, index_max) = transaction.table_codec.all_index_bound(); + Self::_drop_data(&mut transaction, &index_min, &index_max)?; + + transaction.tx.commit().await?; } Ok(()) } - async fn table(&self, name: &String) -> Option { - let table_codec = self.table_catalog(name) + async fn transaction(&self, name: &String) -> Option { + let table_codec = self.table(name) .await .map(|catalog| TableCodec { table: catalog.clone() })?; let tx = self.inner.new_transaction().await; - Some(KipTable { table_codec, tx, }) + Some(KipTransaction { table_codec, tx, }) } - async fn table_catalog(&self, name: &String) -> Option<&TableCatalog> { + async fn table(&self, name: &String) -> Option<&TableCatalog> { let mut option = self.cache.get(name); if option.is_none() { - let (min, max) = TableCodec::columns_bound(name); let tx = self.inner.new_transaction().await; - let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max)).ok()?; - - let mut columns = vec![]; - let mut name_option = None; - - while let Some((key, value_option)) = iter.try_next().ok().flatten() { - if let Some(value) = value_option { - if let Some((table_name, column)) = TableCodec::decode_column(&key, &value) { - if name != table_name.as_str() { return None; } - let _ = name_option.insert(table_name); + // TODO: unify the data into a `Meta` prefix and use one iteration to collect all data + let (columns, name_option) = Self::column_collect(name, &tx).ok()?; + let indexes = Self::index_meta_collect(name, &tx)?; - columns.push(column); - } - } - } - - if let Some(catalog) = name_option.and_then(|table_name| TableCatalog::new(table_name, columns).ok()) { + if let Some(catalog) = name_option + .and_then(|table_name| TableCatalog::new_with_indexes(table_name, columns, indexes).ok()) + { option = self.cache.get_or_insert(name.to_string(), |_| Ok(catalog)).ok(); } } @@ -148,39 +218,39 @@ impl Storage for KipStorage { option } - async fn show_tables(&self) -> Option> { + async fn show_tables(&self) -> Result, StorageError> { let mut tables = vec![]; let (min, max) = TableCodec::root_table_bound(); let tx = self.inner.new_transaction().await; - let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max)).ok()?; + let mut iter = tx.iter(Bound::Included(&min), Bound::Included(&max))?; - while let Some((key, value_option)) = iter.try_next().ok().flatten() { + while let Some((_, value_option)) = iter.try_next().ok().flatten() { if let Some(value) = value_option { - if let Some((table_name, column_count)) = TableCodec::decode_root_table(&key, &value) { - tables.push((table_name,column_count)); - } + let table_name = TableCodec::decode_root_table(&value)?; + + tables.push(table_name); } } - Some(tables) + Ok(tables) } } -pub struct KipTable { +pub struct KipTransaction { table_codec: TableCodec, tx: mvcc::Transaction } #[async_trait] -impl Table for KipTable { - type TransactionType<'a> = KipTraction<'a>; +impl Transaction for KipTransaction { + type IterType<'a> = KipIter<'a>; - fn read(&self, bounds: Bounds, projections: Projections) -> Result, StorageError> { + fn read(&self, bounds: Bounds, projections: Projections) -> Result, StorageError> { let (min, max) = self.table_codec.tuple_bound(); let iter = self.tx.iter(Bound::Included(&min), Bound::Included(&max))?; - Ok(KipTraction { + Ok(KipIter { offset: bounds.0.unwrap_or(0), limit: bounds.1, projections, @@ -189,6 +259,94 @@ impl Table for KipTable { }) } + fn read_by_index( + &self, + (offset_option, mut limit_option): Bounds, + projections: Projections, + index_meta: IndexMetaRef, + binaries: Vec + ) -> Result, StorageError> { + let mut tuple_ids = Vec::new(); + let mut offset = offset_option.unwrap_or(0); + + for binary in binaries { + if matches!(limit_option, Some(0)) { + break; + } + + match binary { + ConstantBinary::Scope { min, max } => { + let mut iter = self.scope_to_iter(&index_meta, min, max)?; + + while let Some((_, value_option)) = iter.try_next()? { + if let Some(value) = value_option { + for id in TableCodec::decode_index(&value)? { + if Self::offset_move(&mut offset) { continue; } + + tuple_ids.push(id); + + if Self::limit_move(&mut limit_option) { break; } + } + } + + if matches!(limit_option, Some(0)) { + break; + } + } + } + ConstantBinary::Eq(val) => { + if Self::offset_move(&mut offset) { continue; } + + let key = self.val_to_key(&index_meta, val)?; + + if let Some(bytes) = self.tx.get(&key)? { + tuple_ids.append(&mut TableCodec::decode_index(&bytes)?) + } + + let _ = Self::limit_move(&mut limit_option); + } + _ => () + } + } + + Ok(IndexIter { + projections, + table_codec: &self.table_codec, + tuple_ids: VecDeque::from(tuple_ids), + tx: &self.tx, + }) + } + + fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError> { + let (key, value) = self.table_codec.encode_index(&index, &tuple_ids)?; + + if let Some(bytes) = self.tx.get(&key)? { + if is_unique { + let old_tuple_ids = TableCodec::decode_index(&bytes)?; + + if old_tuple_ids[0] != tuple_ids[0] { + return Err(StorageError::DuplicateUniqueValue); + } else { + return Ok(()); + } + } else { + todo!("联合索引") + } + } + + self.tx.set(key, value); + + Ok(()) + } + + fn del_index(&mut self, index: &Index) -> Result<(), StorageError> { + let key = self.table_codec.encode_index_key(&index)?; + + self.tx.remove(&key)?; + + Ok(()) + } + fn append(&mut self, tuple: Tuple, is_overwrite: bool) -> Result<(), StorageError> { let (key, value) = self.table_codec.encode_tuple(&tuple)?; @@ -214,7 +372,71 @@ impl Table for KipTable { } } -pub struct KipTraction<'a> { +impl KipTransaction { + fn val_to_key(&self, index_meta: &IndexMetaRef, val: ValueRef) -> Result, TypeError> { + let index = Index::new(index_meta.id, vec![val]); + + self.table_codec.encode_index_key(&index) + } + + fn scope_to_iter( + &self, + index_meta: &IndexMetaRef, + min: Bound, + max: Bound + ) -> Result { + let bound_encode = |bound: Bound| -> Result<_, StorageError> { + match bound { + Bound::Included(val) => { + Ok(Bound::Included(self.val_to_key(&index_meta, val)?)) + }, + Bound::Excluded(val) => { + Ok(Bound::Excluded(self.val_to_key(&index_meta, val)?)) + } + Bound::Unbounded => Ok(Bound::Unbounded) + } + }; + let check_bound = |value: &mut Bound>, bound: Vec| { + if matches!(value, Bound::Unbounded) { + let _ = mem::replace(value, Bound::Included(bound)); + } + }; + let (bound_min, bound_max) = self.table_codec.index_bound(&index_meta.id); + + let mut encode_min = bound_encode(min)?; + check_bound(&mut encode_min, bound_min); + + let mut encode_max = bound_encode(max)?; + check_bound(&mut encode_max, bound_max); + + Ok(self.tx.iter( + encode_min.as_ref().map(Vec::as_slice), + encode_max.as_ref().map(Vec::as_slice), + )?) + } + + fn offset_move(offset: &mut usize) -> bool { + if *offset > 0 { + offset.sub_assign(1); + + true + } else { + false + } + } + + fn limit_move(limit_option: &mut Option) -> bool { + if let Some(limit) = limit_option { + limit.sub_assign(1); + + return *limit == 0; + } + + false + } +} + +pub struct KipIter<'a> { offset: usize, limit: Option, projections: Projections, @@ -222,7 +444,7 @@ pub struct KipTraction<'a> { iter: TransactionIter<'a> } -impl Transaction for KipTraction<'_> { +impl Iter for KipIter<'_> { fn next_tuple(&mut self) -> Result, StorageError> { while self.offset > 0 { let _ = self.iter.try_next()?; @@ -237,25 +459,13 @@ impl Transaction for KipTraction<'_> { while let Some(item) = self.iter.try_next()? { if let (_, Some(value)) = item { - let tuple = self.table_codec.decode_tuple(&value); - - let projection_len = self.projections.len(); - - let mut columns = Vec::with_capacity(projection_len); - let mut values = Vec::with_capacity(projection_len); - - for expr in self.projections.iter() { - values.push(expr.eval_column(&tuple)?); - columns.push(expr.output_columns(&tuple)); - } - - self.limit = self.limit.map(|num| num - 1); + let tuple = tuple_projection( + &mut self.limit, + &self.projections, + self.table_codec.decode_tuple(&value) + )?; - return Ok(Some(Tuple { - id: tuple.id, - columns, - values, - })) + return Ok(Some(tuple)) } } @@ -265,14 +475,18 @@ impl Transaction for KipTraction<'_> { #[cfg(test)] mod test { + use std::collections::{Bound, VecDeque}; use std::sync::Arc; use itertools::Itertools; use tempfile::TempDir; use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::{Database, DatabaseError}; use crate::expression::ScalarExpression; + use crate::expression::simplify::ConstantBinary; use crate::storage::kip::KipStorage; - use crate::storage::{Storage, StorageError, Transaction, Table}; + use crate::storage::{Storage, StorageError, Iter, Transaction, IndexIter}; use crate::storage::memory::test::data_filling; + use crate::storage::table_codec::TableCodec; use crate::types::LogicalType; use crate::types::value::DataValue; @@ -284,12 +498,12 @@ mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true) + ColumnDesc::new(LogicalType::Integer, true, false) )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false) + ColumnDesc::new(LogicalType::Boolean, false, false) )), ]; @@ -298,24 +512,100 @@ mod test { .collect_vec(); let table_id = storage.create_table(Arc::new("test".to_string()), source_columns).await?; - let table_catalog = storage.table_catalog(&"test".to_string()).await; + let table_catalog = storage.table(&"test".to_string()).await; assert!(table_catalog.is_some()); assert!(table_catalog.unwrap().get_column_id_by_name(&"c1".to_string()).is_some()); - let mut table = storage.table(&table_id).await.unwrap(); - data_filling(columns, &mut table)?; + let mut transaction = storage.transaction(&table_id).await.unwrap(); + data_filling(columns, &mut transaction)?; - let mut tx = table.read( + let mut iter = transaction.read( (Some(1), Some(1)), vec![ScalarExpression::InputRef { index: 0, ty: LogicalType::Integer }] )?; - let option_1 = tx.next_tuple()?; + let option_1 = iter.next_tuple()?; assert_eq!(option_1.unwrap().id, Some(Arc::new(DataValue::Int32(Some(2))))); - let option_2 = tx.next_tuple()?; + let option_2 = iter.next_tuple()?; assert_eq!(option_2, None); Ok(()) } + + #[tokio::test] + async fn test_index_iter() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kipsql = Database::with_kipdb(temp_dir.path()).await?; + + let _ = kipsql.run("create table t1 (a int primary key)").await?; + let _ = kipsql.run("insert into t1 (a) values (0), (1), (2)").await?; + + let table = kipsql.storage.table(&"t1".to_string()).await.unwrap().clone(); + let projections = table.all_columns() + .into_iter() + .map(|col| ScalarExpression::ColumnRef(col)) + .collect_vec(); + let codec = TableCodec { + table, + }; + let tx = kipsql.storage.transaction(&"t1".to_string()).await.unwrap(); + let tuple_ids = vec![ + Arc::new(DataValue::Int32(Some(0))), + Arc::new(DataValue::Int32(Some(1))), + Arc::new(DataValue::Int32(Some(2))), + ]; + let mut iter = IndexIter { + projections, + table_codec: &codec, + tuple_ids: VecDeque::from(tuple_ids.clone()), + tx: &tx.tx, + }; + let mut result = Vec::new(); + + while let Some(tuple) = iter.next_tuple()? { + result.push(tuple.id.unwrap()); + } + + assert_eq!(result, tuple_ids); + + Ok(()) + } + + #[tokio::test] + async fn test_read_by_index() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let kipsql = Database::with_kipdb(temp_dir.path()).await?; + + let _ = kipsql.run("create table t1 (a int primary key, b int unique)").await?; + let _ = kipsql.run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2)").await?; + + let table = kipsql.storage.table(&"t1".to_string()).await.unwrap().clone(); + let projections = table.all_columns() + .into_iter() + .map(|col| ScalarExpression::ColumnRef(col)) + .collect_vec(); + let transaction = kipsql.storage.transaction(&"t1".to_string()).await.unwrap(); + let mut iter = transaction.read_by_index( + (Some(0), Some(1)), + projections, + table.indexes[0].clone(), + vec![ + ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(0)))), + max: Bound::Unbounded + } + ] + ).unwrap(); + + while let Some(tuple) = iter.next_tuple()? { + assert_eq!(tuple.id, Some(Arc::new(DataValue::Int32(Some(1))))); + assert_eq!(tuple.values, vec![ + Arc::new(DataValue::Int32(Some(1))), + Arc::new(DataValue::Int32(Some(1))) + ]) + } + + Ok(()) + } } \ No newline at end of file diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 6549e32a..aaf241a8 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -4,7 +4,9 @@ use std::slice; use std::sync::Arc; use async_trait::async_trait; use crate::catalog::{ColumnCatalog, RootCatalog, TableCatalog, TableName}; -use crate::storage::{Bounds, Projections, Storage, StorageError, Table, Transaction}; +use crate::expression::simplify::ConstantBinary; +use crate::storage::{Bounds, Projections, Storage, StorageError, Transaction, Iter, tuple_projection, IndexIter}; +use crate::types::index::{Index, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; // WARRING: Only single-threaded and tested using @@ -51,7 +53,7 @@ struct StorageInner { #[async_trait] impl Storage for MemStorage { - type TableType = MemTable; + type TransactionType = MemTable; async fn create_table(&self, table_name: TableName, columns: Vec) -> Result { let new_table = MemTable { @@ -91,7 +93,7 @@ impl Storage for MemStorage { Ok(()) } - async fn table(&self, name: &String) -> Option { + async fn transaction(&self, name: &String) -> Option { unsafe { self.inner .as_ptr() @@ -104,7 +106,7 @@ impl Storage for MemStorage { } } - async fn table_catalog(&self, name: &String) -> Option<&TableCatalog> { + async fn table(&self, name: &String) -> Option<&TableCatalog> { unsafe { self.inner .as_ptr() @@ -115,7 +117,7 @@ impl Storage for MemStorage { } } - async fn show_tables(&self) -> Option> { + async fn show_tables(&self) -> Result, StorageError> { todo!() } } @@ -144,10 +146,10 @@ impl Debug for MemTable { } #[async_trait] -impl Table for MemTable { - type TransactionType<'a> = MemTraction<'a>; +impl Transaction for MemTable { + type IterType<'a> = MemTraction<'a>; - fn read(&self, bounds: Bounds, projection: Projections) -> Result, StorageError> { + fn read(&self, bounds: Bounds, projection: Projections) -> Result, StorageError> { unsafe { Ok( MemTraction { @@ -160,6 +162,18 @@ impl Table for MemTable { } } + fn read_by_index(&self, bounds: Bounds, projection: Projections, index_meta: IndexMetaRef, binaries: Vec) -> Result, StorageError> { + todo!() + } + + fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError> { + todo!() + } + + fn del_index(&mut self, _index: &Index) -> Result<(), StorageError> { + todo!() + } + fn append(&mut self, tuple: Tuple, is_overwrite: bool) -> Result<(), StorageError> { let tuples = unsafe { self.tuples @@ -203,7 +217,7 @@ pub struct MemTraction<'a> { iter: slice::Iter<'a, Tuple> } -impl Transaction for MemTraction<'_> { +impl Iter for MemTraction<'_> { fn next_tuple(&mut self) -> Result, StorageError> { while self.offset > 0 { let _ = self.iter.next(); @@ -219,25 +233,7 @@ impl Transaction for MemTraction<'_> { self.iter .next() .cloned() - .map(|tuple| { - let projection_len = self.projections.len(); - - let mut columns = Vec::with_capacity(projection_len); - let mut values = Vec::with_capacity(projection_len); - - for expr in self.projections.iter() { - values.push(expr.eval_column(&tuple)?); - columns.push(expr.output_columns(&tuple)); - } - - self.limit = self.limit.map(|num| num - 1); - - Ok(Tuple { - id: tuple.id, - columns, - values, - }) - }) + .map(|tuple| tuple_projection(&mut self.limit, &self.projections, tuple)) .transpose() } } @@ -249,12 +245,12 @@ pub(crate) mod test { use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; use crate::expression::ScalarExpression; use crate::storage::memory::MemStorage; - use crate::storage::{Storage, StorageError, Table, Transaction}; + use crate::storage::{Storage, StorageError, Transaction, Iter}; use crate::types::LogicalType; use crate::types::tuple::Tuple; use crate::types::value::DataValue; - pub fn data_filling(columns: Vec, table: &mut impl Table) -> Result<(), StorageError> { + pub fn data_filling(columns: Vec, table: &mut impl Transaction) -> Result<(), StorageError> { table.append(Tuple { id: Some(Arc::new(DataValue::Int32(Some(1)))), columns: columns.clone(), @@ -282,12 +278,12 @@ pub(crate) mod test { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true) + ColumnDesc::new(LogicalType::Integer, true, false) )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false) + ColumnDesc::new(LogicalType::Boolean, false, false) )), ]; @@ -297,22 +293,22 @@ pub(crate) mod test { let table_id = storage.create_table(Arc::new("test".to_string()), source_columns).await?; - let table_catalog = storage.table_catalog(&"test".to_string()).await; + let table_catalog = storage.table(&"test".to_string()).await; assert!(table_catalog.is_some()); assert!(table_catalog.unwrap().get_column_id_by_name(&"c1".to_string()).is_some()); - let mut table = storage.table(&table_id).await.unwrap(); - data_filling(columns, &mut table)?; + let mut transaction = storage.transaction(&table_id).await.unwrap(); + data_filling(columns, &mut transaction)?; - let mut tx = table.read( + let mut iter = transaction.read( (Some(1), Some(1)), vec![ScalarExpression::InputRef { index: 0, ty: LogicalType::Integer }] )?; - let option_1 = tx.next_tuple()?; + let option_1 = iter.next_tuple()?; assert_eq!(option_1.unwrap().id, Some(Arc::new(DataValue::Int32(Some(2))))); - let option_2 = tx.next_tuple()?; + let option_2 = iter.next_tuple()?; assert_eq!(option_2, None); Ok(()) diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 98f2ccdc..553e9662 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -2,17 +2,23 @@ pub mod memory; mod table_codec; pub mod kip; +use std::collections::VecDeque; +use std::ops::SubAssign; use async_trait::async_trait; use kip_db::error::CacheError; +use kip_db::kernel::lsm::mvcc; use kip_db::KernelError; use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog, TableName}; use crate::expression::ScalarExpression; +use crate::expression::simplify::ConstantBinary; +use crate::storage::table_codec::TableCodec; use crate::types::errors::TypeError; +use crate::types::index::{Index, IndexMetaRef}; use crate::types::tuple::{Tuple, TupleId}; #[async_trait] pub trait Storage: Sync + Send + Clone + 'static { - type TableType: Table; + type TransactionType: Transaction; async fn create_table( &self, @@ -23,10 +29,10 @@ pub trait Storage: Sync + Send + Clone + 'static { async fn drop_table(&self, name: &String) -> Result<(), StorageError>; async fn drop_data(&self, name: &String) -> Result<(), StorageError>; - async fn table(&self, name: &String) -> Option; - async fn table_catalog(&self, name: &String) -> Option<&TableCatalog>; + async fn transaction(&self, name: &String) -> Option; + async fn table(&self, name: &String) -> Option<&TableCatalog>; - async fn show_tables(&self) -> Option>; + async fn show_tables(&self) -> Result, StorageError>; } /// Optional bounds of the reader, of the form (offset, limit). @@ -34,8 +40,8 @@ pub(crate) type Bounds = (Option, Option); type Projections = Vec; #[async_trait] -pub trait Table: Sync + Send + 'static { - type TransactionType<'a>: Transaction; +pub trait Transaction: Sync + Send + 'static { + type IterType<'a>: Iter; /// The bounds is applied to the whole data batches, not per batch. /// @@ -44,7 +50,19 @@ pub trait Table: Sync + Send + 'static { &self, bounds: Bounds, projection: Projections, - ) -> Result, StorageError>; + ) -> Result, StorageError>; + + fn read_by_index( + &self, + bounds: Bounds, + projection: Projections, + index_meta: IndexMetaRef, + binaries: Vec + ) -> Result, StorageError>; + + fn add_index(&mut self, index: Index, tuple_ids: Vec, is_unique: bool) -> Result<(), StorageError>; + + fn del_index(&mut self, index: &Index) -> Result<(), StorageError>; fn append(&mut self, tuple: Tuple, is_overwrite: bool) -> Result<(), StorageError>; @@ -53,10 +71,61 @@ pub trait Table: Sync + Send + 'static { async fn commit(self) -> Result<(), StorageError>; } -pub trait Transaction: Sync + Send { +// TODO: Table return optimization +pub struct IndexIter<'a> { + projections: Projections, + table_codec: &'a TableCodec, + tuple_ids: VecDeque, + tx: &'a mvcc::Transaction +} + +impl Iter for IndexIter<'_> { + fn next_tuple(&mut self) -> Result, StorageError> { + if let Some(tuple_id) = self.tuple_ids.pop_front() { + let key = self.table_codec.encode_tuple_key(&tuple_id)?; + + Ok(self.tx.get(&key)? + .map(|bytes| tuple_projection( + &mut None, + &self.projections, + self.table_codec.decode_tuple(&bytes) + )) + .transpose()?) + } else { + Ok(None) + } + } +} + +pub trait Iter: Sync + Send { fn next_tuple(&mut self) -> Result, StorageError>; } +pub(crate) fn tuple_projection( + limit: &mut Option, + projections: &Projections, + tuple: Tuple +) -> Result { + let projection_len = projections.len(); + let mut columns = Vec::with_capacity(projection_len); + let mut values = Vec::with_capacity(projection_len); + + for expr in projections.iter() { + values.push(expr.eval_column(&tuple)?); + columns.push(expr.output_columns(&tuple)); + } + + if let Some(num) = limit { + num.sub_assign(1); + } + + Ok(Tuple { + id: tuple.id, + columns, + values, + }) +} + #[derive(thiserror::Error, Debug)] pub enum StorageError { #[error("catalog error")] @@ -74,9 +143,8 @@ pub enum StorageError { #[error("The same primary key data already exists")] DuplicatePrimaryKey, - #[error("Serialization error")] - Serialization, - + #[error("The column has been declared unique and the value already exists")] + DuplicateUniqueValue, } impl From for StorageError { diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index aab27fef..e8afa7fa 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -1,45 +1,135 @@ -use std::sync::Arc; use bytes::Bytes; -use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; +use lazy_static::lazy_static; +use crate::catalog::{ColumnCatalog, TableCatalog, TableName}; use crate::types::errors::TypeError; +use crate::types::index::{Index, IndexId, IndexMeta}; use crate::types::tuple::{Tuple, TupleId}; const BOUND_MIN_TAG: u8 = 0; const BOUND_MAX_TAG: u8 = 1; - -const COLUMNS_ID_LEN: usize = 10; +lazy_static! { + static ref ROOT_BYTES: Vec = { + b"Root".to_vec() + }; +} #[derive(Clone)] pub struct TableCodec { pub table: TableCatalog } +#[derive(Copy, Clone)] +enum CodecType { + Column, + IndexMeta, + Index, + Tuple, + Root, +} + impl TableCodec { + /// TableName + Type + /// + /// Tips: Root full key = key_prefix + fn key_prefix(ty: CodecType, table_name: &String) -> Vec { + let mut table_bytes = table_name + .clone() + .into_bytes(); + + match ty { + CodecType::Column => { + table_bytes.push(b'0'); + } + CodecType::IndexMeta => { + table_bytes.push(b'1'); + } + CodecType::Index => { + table_bytes.push(b'2'); + } + CodecType::Tuple => { + table_bytes.push(b'3'); + } + CodecType::Root => { + let mut bytes = ROOT_BYTES.clone(); + bytes.push(BOUND_MIN_TAG); + bytes.append(&mut table_bytes); + + table_bytes = bytes + } + } + + table_bytes + } + pub fn tuple_bound(&self) -> (Vec, Vec) { let op = |bound_id| { - format!( - "{}_Data_{}", - self.table.name, - bound_id - ) + let mut key_prefix = Self::key_prefix(CodecType::Tuple, &self.table.name); + + key_prefix.push(bound_id); + key_prefix + }; + + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + } + + pub fn index_meta_bound(name: &String) -> (Vec, Vec) { + let op = |bound_id| { + let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, name); + + key_prefix.push(bound_id); + key_prefix + }; + + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + } + + pub fn index_bound(&self, index_id: &IndexId) -> (Vec, Vec) { + let op = |bound_id| { + let mut key_prefix = Self::key_prefix(CodecType::Index, &self.table.name); + + key_prefix.push(BOUND_MIN_TAG); + key_prefix.append(&mut index_id.to_be_bytes().to_vec()); + key_prefix.push(bound_id); + key_prefix + }; + + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + } + + pub fn all_index_bound(&self) -> (Vec, Vec) { + let op = |bound_id| { + let mut key_prefix = Self::key_prefix(CodecType::Index, &self.table.name); + + key_prefix.push(bound_id); + key_prefix + }; + + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) + } + + pub fn root_table_bound() -> (Vec, Vec) { + let op = |bound_id| { + let mut key_prefix = ROOT_BYTES.clone(); + + key_prefix.push(bound_id); + key_prefix }; - (op(BOUND_MIN_TAG).into_bytes(), op(BOUND_MAX_TAG).into_bytes()) + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } pub fn columns_bound(name: &String) -> (Vec, Vec) { let op = |bound_id| { - format!( - "{}_Catalog_{}", - name, - bound_id - ) + let mut key_prefix = Self::key_prefix(CodecType::Column, &name); + + key_prefix.push(bound_id); + key_prefix }; - (op(BOUND_MIN_TAG).into_bytes(), op(BOUND_MAX_TAG).into_bytes()) + (op(BOUND_MIN_TAG), op(BOUND_MAX_TAG)) } - /// Key: TableName_Data_0_RowID(Sorted) + /// Key: TableName_Tuple_0_RowID(Sorted) /// Value: Tuple pub fn encode_tuple(&self, tuple: &Tuple) -> Result<(Bytes, Bytes), TypeError> { let tuple_id = tuple @@ -52,91 +142,100 @@ impl TableCodec { } pub fn encode_tuple_key(&self, tuple_id: &TupleId) -> Result, TypeError> { - let string_key = format!( - "{}_Data_0_{}", - self.table.name, - tuple_id.to_primary_key()?, - ); + let mut key_prefix = Self::key_prefix(CodecType::Tuple, &self.table.name); + key_prefix.push(BOUND_MIN_TAG); + + tuple_id.to_primary_key(&mut key_prefix)?; - Ok(string_key.into_bytes()) + Ok(key_prefix) } pub fn decode_tuple(&self, bytes: &[u8]) -> Tuple { Tuple::deserialize_from(self.table.all_columns(), bytes) } + /// Key: TableName_IndexMeta_0_IndexID + /// Value: IndexMeta + pub fn encode_index_meta(name: &String, index_meta: &IndexMeta) -> Result<(Bytes, Bytes), TypeError> { + let mut key_prefix = Self::key_prefix(CodecType::IndexMeta, &name); + key_prefix.push(BOUND_MIN_TAG); + key_prefix.append(&mut index_meta.id.to_be_bytes().to_vec()); + + Ok((Bytes::from(key_prefix), Bytes::from(bincode::serialize(&index_meta)?))) + } + + pub fn decode_index_meta(bytes: &[u8]) -> Result { + Ok(bincode::deserialize(bytes)?) + } + + /// NonUnique Index: + /// Key: TableName_Index_0_IndexID_0_DataValue1_DataValue2 .. + /// Value: TupleIDs + /// + /// Unique Index: + /// Key: TableName_Index_0_IndexID_0_DataValue + /// Value: TupleIDs + /// + /// Tips: The unique index has only one ColumnID and one corresponding DataValue, + /// so it can be positioned directly. + pub fn encode_index(&self, index: &Index, tuple_ids: &[TupleId]) -> Result<(Bytes, Bytes), TypeError> { + let key = self.encode_index_key(index)?; + + Ok((Bytes::from(key), Bytes::from(bincode::serialize(tuple_ids)?))) + } + + pub fn encode_index_key(&self, index: &Index) -> Result, TypeError> { + let mut key_prefix = Self::key_prefix(CodecType::Index, &self.table.name); + key_prefix.push(BOUND_MIN_TAG); + key_prefix.append(&mut index.id.to_be_bytes().to_vec()); + key_prefix.push(BOUND_MIN_TAG); + + for col_v in &index.column_values { + col_v.to_index_key(&mut key_prefix)?; + } + + Ok(key_prefix) + } + + pub fn decode_index(bytes: &[u8]) -> Result, TypeError> { + Ok(bincode::deserialize(bytes)?) + } + /// Key: TableName_Catalog_0_ColumnName_ColumnId /// Value: ColumnCatalog /// /// Tips: the `0` for bound range - pub fn encode_column(col: &ColumnRef) -> Option<(Bytes, Bytes)> { - let table_name = col.table_name.as_ref()?; + pub fn encode_column(col: &ColumnCatalog) -> Result<(Bytes, Bytes), TypeError> { + let bytes = bincode::serialize(col)?; + let mut key_prefix = Self::key_prefix(CodecType::Column, col.table_name.as_ref().unwrap()); - bincode::serialize(&col).ok() - .map(|bytes| { - let key = format!( - "{}_Catalog_{}_{}_{:0width$}", - table_name, - BOUND_MIN_TAG, - col.name, - col.id, - width = COLUMNS_ID_LEN - ); + key_prefix.push(BOUND_MIN_TAG); + key_prefix.append(&mut col.id.unwrap().to_be_bytes().to_vec()); - (Bytes::from(key.into_bytes()), Bytes::from(bytes)) - }) + Ok((Bytes::from(key_prefix), Bytes::from(bytes))) } - pub fn decode_column(key: &[u8], bytes: &[u8]) -> Option<(TableName, ColumnCatalog)> { - String::from_utf8(key.to_owned()).ok()? - .split("_") - .nth(0) - .and_then(|table_name| { - bincode::deserialize::(bytes).ok() - .and_then(|col| { - Some((Arc::new(table_name.to_string()), col)) - }) - }) + pub fn decode_column(bytes: &[u8]) -> Result<(TableName, ColumnCatalog), TypeError> { + let column = bincode::deserialize::(bytes)?; + + Ok((column.table_name.clone().unwrap(), column)) } /// Key: RootCatalog_0_TableName - /// Value: ColumnCount - pub fn encode_root_table(table_name: &str,column_count:usize) -> Option<(Bytes, Bytes)> { - let key = format!( - "RootCatalog_{}_{}", - BOUND_MIN_TAG, - table_name, - ); + /// Value: TableName + pub fn encode_root_table(table_name: &String) -> Result<(Bytes, Bytes), TypeError> { + let key = Self::encode_root_table_key(table_name); - bincode::serialize(&column_count).ok() - .map(|bytes| { - (Bytes::from(key.into_bytes()), Bytes::from(bytes)) - }) + Ok((Bytes::from(key), Bytes::from(table_name.clone().into_bytes()))) } - // TODO: value is reserved for saving meta-information - pub fn decode_root_table(key: &[u8], bytes: &[u8]) -> Option<(String,usize)> { - String::from_utf8(key.to_owned()).ok()? - .split("_") - .nth(2) - .and_then(|table_name| { - bincode::deserialize::(bytes).ok() - .and_then(|name| { - Some((table_name.to_string(), name)) - }) - }) + pub fn encode_root_table_key(table_name: &String) -> Vec { + Self::key_prefix(CodecType::Root, &table_name) } - pub fn root_table_bound() -> (Vec, Vec) { - let op = |bound_id| { - format!( - "RootCatalog_{}", - bound_id, - ) - }; - - (op(BOUND_MIN_TAG).into_bytes(), op(BOUND_MAX_TAG).into_bytes()) + pub fn decode_root_table(bytes: &[u8]) -> Result { + Ok(String::from_utf8(bytes.to_vec())?) } } @@ -145,11 +244,13 @@ mod tests { use std::collections::BTreeSet; use std::ops::Bound; use std::sync::Arc; + use bytes::Bytes; use itertools::Itertools; use rust_decimal::Decimal; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; - use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec}; + use crate::storage::table_codec::TableCodec; use crate::types::errors::TypeError; + use crate::types::index::{Index, IndexMeta}; use crate::types::LogicalType; use crate::types::tuple::Tuple; use crate::types::value::DataValue; @@ -159,12 +260,12 @@ mod tests { ColumnCatalog::new( "c1".into(), false, - ColumnDesc::new(LogicalType::Integer, true) + ColumnDesc::new(LogicalType::Integer, true, false) ), ColumnCatalog::new( "c2".into(), false, - ColumnDesc::new(LogicalType::Decimal(None,None), false) + ColumnDesc::new(LogicalType::Decimal(None,None), false, false) ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap(); @@ -184,17 +285,8 @@ mod tests { Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), ] }; + let (_, bytes) = codec.encode_tuple(&tuple)?; - let (key, bytes) = codec.encode_tuple(&tuple)?; - - assert_eq!( - String::from_utf8(key.to_vec()).ok().unwrap(), - format!( - "{}_Data_0_{}", - table_catalog.name, - tuple.id.clone().unwrap().to_primary_key()?, - ) - ); assert_eq!(codec.decode_tuple(&bytes), tuple); Ok(()) @@ -203,40 +295,51 @@ mod tests { #[test] fn test_root_catalog() { let (table_catalog, _) = build_table_codec(); - let (key, bytes) = TableCodec::encode_root_table(&table_catalog.name,2).unwrap(); - - assert_eq!( - String::from_utf8(key.to_vec()).ok().unwrap(), - format!( - "RootCatalog_0_{}", - table_catalog.name, - ) - ); + let (_, bytes) = TableCodec::encode_root_table(&table_catalog.name).unwrap(); - let (table_name, column_count) = TableCodec::decode_root_table(&key, &bytes).unwrap(); + let table_name = TableCodec::decode_root_table(&bytes).unwrap(); assert_eq!(table_name, table_catalog.name.as_str()); - assert_eq!(column_count, 2); + } + + #[test] + fn test_table_codec_index_meta() -> Result<(), TypeError> { + let index_meta = IndexMeta { + id: 0, + column_ids: vec![0], + name: "index_1".to_string(), + is_unique: false, + }; + let (_, bytes) = TableCodec::encode_index_meta(&"T1".to_string(), &index_meta)?; + + assert_eq!(TableCodec::decode_index_meta(&bytes)?, index_meta); + + Ok(()) + } + + #[test] + fn test_table_codec_index() -> Result<(), TypeError> { + let (_, codec) = build_table_codec(); + + let index = Index { + id: 0, + column_values: vec![Arc::new(DataValue::Int32(Some(0)))], + }; + let tuple_ids = vec![Arc::new(DataValue::Int32(Some(0)))]; + let (_, bytes) = codec.encode_index(&index, &tuple_ids)?; + + assert_eq!(TableCodec::decode_index(&bytes)?, tuple_ids); + + Ok(()) } #[test] fn test_table_codec_column() { let (table_catalog, _) = build_table_codec(); let col = table_catalog.all_columns()[0].clone(); - let (key, bytes) = TableCodec::encode_column(&col).unwrap(); - - assert_eq!( - String::from_utf8(key.to_vec()).ok().unwrap(), - format!( - "{}_Catalog_0_{}_{:0width$}", - table_catalog.name, - col.name, - col.id, - width = COLUMNS_ID_LEN - ) - ); + let (_, bytes) = TableCodec::encode_column(&col).unwrap(); - let (table_name, decode_col) = TableCodec::decode_column(&key, &bytes).unwrap(); + let (table_name, decode_col) = TableCodec::decode_column(&bytes).unwrap(); assert_eq!(&decode_col, col.as_ref()); assert_eq!(table_name, table_catalog.name); @@ -245,53 +348,205 @@ mod tests { #[test] fn test_table_codec_column_bound() { let mut set = BTreeSet::new(); - let op = |str: &str| { - str.to_string().into_bytes() + let op = |col_id: usize, table_name: &str| { + let mut col = ColumnCatalog::new( + "".to_string(), + false, + ColumnDesc { + column_datatype: LogicalType::Invalid, + is_primary: false, + is_unique: false, + } + ); + + col.table_name = Some(Arc::new(table_name.to_string())); + col.id = Some(col_id as u32); + + let (key, _) = TableCodec::encode_column(&col).unwrap(); + key }; - set.insert(op("T0_Catalog_0_C0_0")); - set.insert(op("T0_Catalog_0_C1_1")); - set.insert(op("T0_Catalog_0_C2_2")); + set.insert(op(0, "T0")); + set.insert(op(1, "T0")); + set.insert(op(2, "T0")); - set.insert(op("T1_Catalog_0_C0_0")); - set.insert(op("T1_Catalog_0_C1_1")); - set.insert(op("T1_Catalog_0_C2_2")); + set.insert(op(0, "T1")); + set.insert(op(1, "T1")); + set.insert(op(2, "T1")); - set.insert(op("T2_Catalog_0_C0_0")); - set.insert(op("T2_Catalog_0_C1_1")); - set.insert(op("T2_Catalog_0_C2_2")); + set.insert(op(0, "T2")); + set.insert(op(0, "T2")); + set.insert(op(0, "T2")); let (min, max) = TableCodec::columns_bound( &Arc::new("T1".to_string()) ); + let vec = set + .range::, Bound<&Bytes>)>(( + Bound::Included(&Bytes::from(min)), + Bound::Included(&Bytes::from(max)) + )) + .collect_vec(); + + assert_eq!(vec.len(), 3); + + assert_eq!(vec[0], &op(0, "T1")); + assert_eq!(vec[1], &op(1, "T1")); + assert_eq!(vec[2], &op(2, "T1")); + } + + #[test] + fn test_table_codec_index_meta_bound() { + let mut set = BTreeSet::new(); + let op = |index_id: usize, table_name: &str| { + let index_meta = IndexMeta { + id: index_id as u32, + column_ids: vec![], + name: "".to_string(), + is_unique: false, + }; + + let (key, _) = TableCodec::encode_index_meta(&table_name.to_string(), &index_meta).unwrap(); + key + }; + + set.insert(op(0, "T0")); + set.insert(op(1, "T0")); + set.insert(op(2, "T0")); + + set.insert(op(0, "T1")); + set.insert(op(1, "T1")); + set.insert(op(2, "T1")); + + set.insert(op(0, "T2")); + set.insert(op(1, "T2")); + set.insert(op(2, "T2")); + + let (min, max) = TableCodec::index_meta_bound(&"T1".to_string()); + + let vec = set + .range::, Bound<&Bytes>)>(( + Bound::Included(&Bytes::from(min)), + Bound::Included(&Bytes::from(max)) + )) + .collect_vec(); + + assert_eq!(vec.len(), 3); + + assert_eq!(vec[0], &op(0, "T1")); + assert_eq!(vec[1], &op(1, "T1")); + assert_eq!(vec[2], &op(2, "T1")); + } + + #[test] + fn test_table_codec_index_bound() { + let mut set = BTreeSet::new(); + let table_codec = TableCodec { + table: TableCatalog::new(Arc::new("T0".to_string()), vec![]).unwrap(), + }; + + let op = |value: DataValue, index_id: usize, table_codec: &TableCodec| { + let index = Index { + id: index_id as u32, + column_values: vec![Arc::new(value)], + }; + + table_codec.encode_index_key(&index).unwrap() + }; + + set.insert(op(DataValue::Int32(Some(0)), 0, &table_codec)); + set.insert(op(DataValue::Int32(Some(1)), 0, &table_codec)); + set.insert(op(DataValue::Int32(Some(2)), 0, &table_codec)); + + set.insert(op(DataValue::Int32(Some(0)), 1, &table_codec)); + set.insert(op(DataValue::Int32(Some(1)), 1, &table_codec)); + set.insert(op(DataValue::Int32(Some(2)), 1, &table_codec)); + + set.insert(op(DataValue::Int32(Some(0)), 2, &table_codec)); + set.insert(op(DataValue::Int32(Some(1)), 2, &table_codec)); + set.insert(op(DataValue::Int32(Some(2)), 2, &table_codec)); + + println!("{:#?}", set); + + let (min, max) = table_codec.index_bound(&1); + + println!("{:?}", min); + println!("{:?}", max); + let vec = set .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) .collect_vec(); - assert_eq!(String::from_utf8(vec[0].clone()).unwrap(), "T1_Catalog_0_C0_0"); - assert_eq!(String::from_utf8(vec[1].clone()).unwrap(), "T1_Catalog_0_C1_1"); - assert_eq!(String::from_utf8(vec[2].clone()).unwrap(), "T1_Catalog_0_C2_2"); + assert_eq!(vec.len(), 3); + + assert_eq!(vec[0], &op(DataValue::Int32(Some(0)), 1, &table_codec)); + assert_eq!(vec[1], &op(DataValue::Int32(Some(1)), 1, &table_codec)); + assert_eq!(vec[2], &op(DataValue::Int32(Some(2)), 1, &table_codec)); + } + + #[test] + fn test_table_codec_index_all_bound() { + let mut set = BTreeSet::new(); + let op = |value: DataValue, index_id: usize, table_name: &str| { + let index = Index { + id: index_id as u32, + column_values: vec![Arc::new(value)], + }; + + TableCodec { + table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap() + }.encode_index_key(&index).unwrap() + }; + + set.insert(op(DataValue::Int32(Some(0)), 0, "T0")); + set.insert(op(DataValue::Int32(Some(1)), 0, "T0")); + set.insert(op(DataValue::Int32(Some(2)), 0, "T0")); + + set.insert(op(DataValue::Int32(Some(0)), 0, "T1")); + set.insert(op(DataValue::Int32(Some(1)), 0, "T1")); + set.insert(op(DataValue::Int32(Some(2)), 0, "T1")); + + set.insert(op(DataValue::Int32(Some(0)), 0, "T2")); + set.insert(op(DataValue::Int32(Some(1)), 0, "T2")); + set.insert(op(DataValue::Int32(Some(2)), 0, "T2")); + + let table_codec = TableCodec { + table: TableCatalog::new(Arc::new("T1".to_string()), vec![]).unwrap(), + }; + let (min, max) = table_codec.all_index_bound(); + + let vec = set + .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) + .collect_vec(); + + assert_eq!(vec.len(), 3); + + assert_eq!(vec[0], &op(DataValue::Int32(Some(0)), 0, "T1")); + assert_eq!(vec[1], &op(DataValue::Int32(Some(1)), 0, "T1")); + assert_eq!(vec[2], &op(DataValue::Int32(Some(2)), 0, "T1")); } #[test] fn test_table_codec_tuple_bound() { let mut set = BTreeSet::new(); - let op = |str: &str| { - str.to_string().into_bytes() + let op = |tuple_id: DataValue, table_name: &str| { + TableCodec { + table: TableCatalog::new(Arc::new(table_name.to_string()), vec![]).unwrap() + }.encode_tuple_key(&Arc::new(tuple_id)).unwrap() }; - set.insert(op("T0_Data_0_0000000000000000000")); - set.insert(op("T0_Data_0_0000000000000000001")); - set.insert(op("T0_Data_0_0000000000000000002")); + set.insert(op(DataValue::Int32(Some(0)), "T0")); + set.insert(op(DataValue::Int32(Some(1)), "T0")); + set.insert(op(DataValue::Int32(Some(2)), "T0")); - set.insert(op("T1_Data_0_0000000000000000000")); - set.insert(op("T1_Data_0_0000000000000000001")); - set.insert(op("T1_Data_0_0000000000000000002")); + set.insert(op(DataValue::Int32(Some(0)), "T1")); + set.insert(op(DataValue::Int32(Some(1)), "T1")); + set.insert(op(DataValue::Int32(Some(2)), "T1")); - set.insert(op("T2_Data_0_0000000000000000000")); - set.insert(op("T2_Data_0_0000000000000000001")); - set.insert(op("T2_Data_0_0000000000000000002")); + set.insert(op(DataValue::Int32(Some(0)), "T2")); + set.insert(op(DataValue::Int32(Some(1)), "T2")); + set.insert(op(DataValue::Int32(Some(2)), "T2")); let table_codec = TableCodec { table: TableCatalog::new(Arc::new("T1".to_string()), vec![]).unwrap(), @@ -302,21 +557,27 @@ mod tests { .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) .collect_vec(); - assert_eq!(String::from_utf8(vec[0].clone()).unwrap(), "T1_Data_0_0000000000000000000"); - assert_eq!(String::from_utf8(vec[1].clone()).unwrap(), "T1_Data_0_0000000000000000001"); - assert_eq!(String::from_utf8(vec[2].clone()).unwrap(), "T1_Data_0_0000000000000000002"); + assert_eq!(vec.len(), 3); + + assert_eq!(vec[0], &op(DataValue::Int32(Some(0)), "T1")); + assert_eq!(vec[1], &op(DataValue::Int32(Some(1)), "T1")); + assert_eq!(vec[2], &op(DataValue::Int32(Some(2)), "T1")); } #[test] fn test_root_codec_name_bound(){ let mut set = BTreeSet::new(); - let op = |str: &str| { - str.to_string().into_bytes() + let op = |table_name: &str| { + TableCodec::encode_root_table_key(&table_name.to_string()) }; - set.insert(op("RootCatalog_0_T0")); - set.insert(op("RootCatalog_0_T1")); - set.insert(op("RootCatalog_0_T2")); + set.insert(b"A".to_vec()); + + set.insert(op("T0")); + set.insert(op("T1")); + set.insert(op("T2")); + + set.insert(b"Z".to_vec()); let (min, max) = TableCodec::root_table_bound(); @@ -324,9 +585,9 @@ mod tests { .range::, (Bound<&Vec>, Bound<&Vec>)>((Bound::Included(&min), Bound::Included(&max))) .collect_vec(); - assert_eq!(String::from_utf8(vec[0].clone()).unwrap(), "RootCatalog_0_T0"); - assert_eq!(String::from_utf8(vec[1].clone()).unwrap(), "RootCatalog_0_T1"); - assert_eq!(String::from_utf8(vec[2].clone()).unwrap(), "RootCatalog_0_T2"); + assert_eq!(vec[0], &op("T0")); + assert_eq!(vec[1], &op("T1")); + assert_eq!(vec[2], &op("T2")); } } \ No newline at end of file diff --git a/src/types/errors.rs b/src/types/errors.rs index c9476f62..9ec307a1 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -1,5 +1,6 @@ use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; use std::str::ParseBoolError; +use std::string::FromUtf8Error; use chrono::ParseError; #[derive(thiserror::Error, Debug)] @@ -46,10 +47,22 @@ pub enum TypeError { #[from] ParseError, ), + #[error("bindcode")] + Bincode( + #[source] + #[from] + Box + ), #[error("try from decimal")] TryFromDecimal( #[source] #[from] rust_decimal::Error, ), + #[error("from utf8")] + FromUtf8Error( + #[source] + #[from] + FromUtf8Error, + ) } diff --git a/src/types/index.rs b/src/types/index.rs new file mode 100644 index 00000000..59c89a05 --- /dev/null +++ b/src/types/index.rs @@ -0,0 +1,29 @@ +use std::sync::Arc; +use serde::{Deserialize, Serialize}; +use crate::types::ColumnId; +use crate::types::value::ValueRef; + +pub type IndexId = u32; +pub type IndexMetaRef = Arc; + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct IndexMeta { + pub id: IndexId, + pub column_ids: Vec, + pub name: String, + pub is_unique:bool +} + +pub struct Index { + pub id: IndexId, + pub column_values: Vec, +} + +impl Index { + pub fn new(id: IndexId, column_values: Vec) -> Self { + Index { + id, + column_values, + } + } +} \ No newline at end of file diff --git a/src/types/mod.rs b/src/types/mod.rs index de67aed9..b2fbd65e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,41 +1,15 @@ pub mod errors; pub mod value; pub mod tuple; +pub mod index; -use std::sync::atomic::AtomicU32; -use std::sync::atomic::Ordering::{Acquire, Release}; use serde::{Deserialize, Serialize}; -use integer_encoding::FixedInt; use sqlparser::ast::ExactNumberInfo; use strum_macros::AsRefStr; use crate::types::errors::TypeError; -static ID_BUF: AtomicU32 = AtomicU32::new(0); - -pub(crate) struct IdGenerator { } - -impl IdGenerator { - pub(crate) fn encode_to_raw() -> Vec { - ID_BUF - .load(Acquire) - .encode_fixed_vec() - } - - pub(crate) fn from_raw(buf: &[u8]) { - Self::init(u32::decode_fixed(buf)) - } - - pub(crate) fn init(init_value: u32) { - ID_BUF.store(init_value, Release) - } - - pub(crate) fn build() -> u32 { - ID_BUF.fetch_add(1, Release) - } -} - pub type ColumnId = u32; /// Sqlrs type conversion: @@ -78,7 +52,7 @@ impl LogicalType { LogicalType::UBigint => Some(8), LogicalType::Float => Some(4), LogicalType::Double => Some(8), - /// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal + /// Note: The non-fixed length type's raw_len is None e.g. Varchar LogicalType::Varchar(_) => None, LogicalType::Decimal(_, _) => Some(16), LogicalType::Date => Some(4), @@ -320,32 +294,3 @@ impl std::fmt::Display for LogicalType { write!(f, "{}", self.as_ref()) } } - -#[cfg(test)] -mod test { - use std::sync::atomic::Ordering::Release; - - use crate::types::{IdGenerator, ID_BUF}; - - /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 - #[test] - #[ignore] - fn test_id_generator() { - assert_eq!(IdGenerator::build(), 0); - assert_eq!(IdGenerator::build(), 1); - - let buf = IdGenerator::encode_to_raw(); - test_id_generator_reset(); - - assert_eq!(IdGenerator::build(), 0); - - IdGenerator::from_raw(&buf); - - assert_eq!(IdGenerator::build(), 2); - assert_eq!(IdGenerator::build(), 3); - } - - fn test_id_generator_reset() { - ID_BUF.store(0, Release) - } -} diff --git a/src/types/tuple.rs b/src/types/tuple.rs index bf569c38..dc6db97c 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -18,7 +18,7 @@ pub struct Tuple { impl Tuple { pub fn deserialize_from(columns: Vec, bytes: &[u8]) -> Self { - fn bit_index(bits: u8, i: usize) -> bool { + fn is_none(bits: u8, i: usize) -> bool { bits & (1 << (7 - i)) > 0 } @@ -32,7 +32,7 @@ impl Tuple { for (i, col) in columns.iter().enumerate() { let logic_type = col.datatype(); - if bit_index(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { + if is_none(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { values.push(Arc::new(DataValue::none(logic_type))); } else if let Some(len) = logic_type.raw_len() { /// fixed length (e.g.: int) @@ -125,62 +125,62 @@ mod tests { Arc::new(ColumnCatalog::new( "c1".to_string(), false, - ColumnDesc::new(LogicalType::Integer, true) + ColumnDesc::new(LogicalType::Integer, true, false) )), Arc::new(ColumnCatalog::new( "c2".to_string(), false, - ColumnDesc::new(LogicalType::UInteger, false) + ColumnDesc::new(LogicalType::UInteger, false, false) )), Arc::new(ColumnCatalog::new( "c3".to_string(), false, - ColumnDesc::new(LogicalType::Varchar(Some(2)), false) + ColumnDesc::new(LogicalType::Varchar(Some(2)), false, false) )), Arc::new(ColumnCatalog::new( "c4".to_string(), false, - ColumnDesc::new(LogicalType::Smallint, false) + ColumnDesc::new(LogicalType::Smallint, false, false) )), Arc::new(ColumnCatalog::new( "c5".to_string(), false, - ColumnDesc::new(LogicalType::USmallint, false) + ColumnDesc::new(LogicalType::USmallint, false, false) )), Arc::new(ColumnCatalog::new( "c6".to_string(), false, - ColumnDesc::new(LogicalType::Float, false) + ColumnDesc::new(LogicalType::Float, false, false) )), Arc::new(ColumnCatalog::new( "c7".to_string(), false, - ColumnDesc::new(LogicalType::Double, false) + ColumnDesc::new(LogicalType::Double, false, false) )), Arc::new(ColumnCatalog::new( "c8".to_string(), false, - ColumnDesc::new(LogicalType::Tinyint, false) + ColumnDesc::new(LogicalType::Tinyint, false, false) )), Arc::new(ColumnCatalog::new( "c9".to_string(), false, - ColumnDesc::new(LogicalType::UTinyint, false) + ColumnDesc::new(LogicalType::UTinyint, false, false) )), Arc::new(ColumnCatalog::new( "c10".to_string(), false, - ColumnDesc::new(LogicalType::Boolean, false) + ColumnDesc::new(LogicalType::Boolean, false, false) )), Arc::new(ColumnCatalog::new( "c11".to_string(), false, - ColumnDesc::new(LogicalType::DateTime, false) + ColumnDesc::new(LogicalType::DateTime, false, false) )), Arc::new(ColumnCatalog::new( "c12".to_string(), false, - ColumnDesc::new(LogicalType::Date, false) + ColumnDesc::new(LogicalType::Date, false, false) )), ]; diff --git a/src/types/value.rs b/src/types/value.rs index 94b856f3..5f097d68 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -11,6 +11,7 @@ use lazy_static::lazy_static; use rust_decimal::Decimal; use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; use rust_decimal::prelude::FromPrimitive; use crate::types::errors::TypeError; @@ -25,9 +26,12 @@ lazy_static! { pub const DATE_FMT: &str = "%Y-%m-%d"; pub const DATE_TIME_FMT: &str = "%Y-%m-%d %H:%M:%S"; +const ENCODE_GROUP_SIZE: usize = 8; +const ENCODE_MARKER: u8 = 0xFF; + pub type ValueRef = Arc; -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub enum DataValue { Null, Boolean(Option), @@ -145,14 +149,10 @@ impl PartialOrd for DataValue { } } -macro_rules! signed_to_primary_key { - ($ty:ty, $EXPR:expr) => {{ - if $EXPR.is_negative() { - $EXPR ^ (-1 ^ <$ty>::MIN) - } else { - $EXPR - } - }}; +macro_rules! encode_u { + ($b:ident, $u:expr) => { + $b.extend_from_slice(&$u.to_be_bytes()) + }; } impl Eq for DataValue {} @@ -209,12 +209,16 @@ impl DataValue { } (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { if let Some(len) = full_len { - val.mantissa().ilog10() + 1 > *len as u32 - } else if let Some(len) = scale_len { - val.scale() > *len as u32 - } else { - false + if val.mantissa().ilog10() + 1 > *len as u32 { + return Err(TypeError::TooLong) + } + } + if let Some(len) = scale_len { + if val.scale() > *len as u32 { + return Err(TypeError::TooLong) + } } + false } _ => false }; @@ -382,19 +386,126 @@ impl DataValue { } } - pub fn to_primary_key(&self) -> Result { + // EncodeBytes guarantees the encoded value is in ascending order for comparison, + // encoding with the following rule: + // + // [group1][marker1]...[groupN][markerN] + // group is 8 bytes slice which is padding with 0. + // marker is `0xFF - padding 0 count` + // + // For example: + // + // [] -> [0, 0, 0, 0, 0, 0, 0, 0, 247] + // [1, 2, 3] -> [1, 2, 3, 0, 0, 0, 0, 0, 250] + // [1, 2, 3, 0] -> [1, 2, 3, 0, 0, 0, 0, 0, 251] + // [1, 2, 3, 4, 5, 6, 7, 8] -> [1, 2, 3, 4, 5, 6, 7, 8, 255, 0, 0, 0, 0, 0, 0, 0, 0, 247] + // + // Refer: https://github.com/facebook/mysql-5.6/wiki/MyRocks-record-format#memcomparable-format + fn encode_bytes(b: &mut Vec, data: &[u8]) { + let d_len = data.len(); + let realloc_size = (d_len / ENCODE_GROUP_SIZE + 1) * (ENCODE_GROUP_SIZE + 1); + Self::realloc_bytes(b, realloc_size); + + let mut idx = 0; + while idx <= d_len { + let remain = d_len - idx; + let pad_count: usize; + + if remain >= ENCODE_GROUP_SIZE { + b.extend_from_slice(&data[idx..idx + ENCODE_GROUP_SIZE]); + pad_count = 0; + } else { + pad_count = ENCODE_GROUP_SIZE - remain; + b.extend_from_slice(&data[idx..]); + b.extend_from_slice(&vec![0; pad_count]); + } + + b.push(ENCODE_MARKER - pad_count as u8); + idx += ENCODE_GROUP_SIZE; + } + } + + fn realloc_bytes(b: &mut Vec, size: usize) { + let len = b.len(); + + if size > len { + b.reserve(size - len); + b.resize(size, 0); + } + } + + pub fn to_primary_key(&self, b: &mut Vec) -> Result<(), TypeError> { match self { - DataValue::Int8(option) => option.map(|v| format!("{:0width$}", signed_to_primary_key!(i8, v), width = 4)), - DataValue::Int16(option) => option.map(|v| format!("{:0width$}", signed_to_primary_key!(i16, v), width = 6)), - DataValue::Int32(option) => option.map(|v| format!("{:0width$}", signed_to_primary_key!(i32, v), width = 11)), - DataValue::Int64(option) => option.map(|v| format!("{:0width$}", signed_to_primary_key!(i64, v), width = 20)), - DataValue::UInt8(option) => option.map(|v| format!("{:0width$}", v, width = 3)), - DataValue::UInt16(option) => option.map(|v| format!("{:0width$}", v, width = 5)), - DataValue::UInt32(option) => option.map(|v| format!("{:0width$}", v, width = 10)), - DataValue::UInt64(option) => option.map(|v| format!("{:0width$}", v, width = 20)), - DataValue::Utf8(option) => option.clone(), - _ => return Err(TypeError::InvalidType), - }.ok_or(TypeError::NotNull) + DataValue::Int8(Some(v)) => encode_u!(b, *v as u8 ^ 0x80_u8), + DataValue::Int16(Some(v)) => encode_u!(b, *v as u16 ^ 0x8000_u16), + DataValue::Int32(Some(v)) => encode_u!(b, *v as u32 ^ 0x80000000_u32), + DataValue::Int64(Some(v)) => encode_u!(b, *v as u64 ^ 0x8000000000000000_u64), + DataValue::UInt8(Some(v)) => encode_u!(b, v), + DataValue::UInt16(Some(v)) => encode_u!(b, v), + DataValue::UInt32(Some(v)) => encode_u!(b, v), + DataValue::UInt64(Some(v)) => encode_u!(b, v), + DataValue::Utf8(Some(v)) => Self::encode_bytes(b, v.as_bytes()), + value => { + return if value.is_null() { + Err(TypeError::NotNull) + } else { + Err(TypeError::InvalidType) + } + } + } + + Ok(()) + } + + pub fn to_index_key(&self, b: &mut Vec) -> Result<(), TypeError> { + match self { + DataValue::Int8(Some(v)) => encode_u!(b, *v as u8 ^ 0x80_u8), + DataValue::Int16(Some(v)) => encode_u!(b, *v as u16 ^ 0x8000_u16), + DataValue::Int32(Some(v)) | DataValue::Date32(Some(v)) => { + encode_u!(b, *v as u32 ^ 0x80000000_u32) + }, + DataValue::Int64(Some(v)) | DataValue::Date64(Some(v)) => { + encode_u!(b, *v as u64 ^ 0x8000000000000000_u64) + }, + DataValue::UInt8(Some(v)) => encode_u!(b, v), + DataValue::UInt16(Some(v)) => encode_u!(b, v), + DataValue::UInt32(Some(v)) => encode_u!(b, v), + DataValue::UInt64(Some(v)) => encode_u!(b, v), + DataValue::Utf8(Some(v)) => Self::encode_bytes(b, v.as_bytes()), + DataValue::Boolean(Some(v)) => b.push(if *v { b'1' } else { b'0' }), + DataValue::Float32(Some(f)) => { + let mut u = f.to_bits(); + + if *f >= 0_f32 { + u |= 0x80000000_u32; + } else { + u = !u; + } + + encode_u!(b, u); + }, + DataValue::Float64(Some(f)) => { + let mut u = f.to_bits(); + + if *f >= 0_f64 { + u |= 0x8000000000000000_u64; + } else { + u = !u; + } + + encode_u!(b, u); + }, + DataValue::Decimal(Some(_v)) => todo!(), + value => { + return if value.is_null() { + todo!() + } else { + Err(TypeError::InvalidType) + } + }, + } + + Ok(()) } pub fn cast(self, to: &LogicalType) -> Result { @@ -870,42 +981,89 @@ mod test { #[test] fn test_to_primary_key() -> Result<(), TypeError> { - let key_i8_1 = DataValue::Int8(Some(i8::MIN)).to_primary_key()?; - let key_i8_2 = DataValue::Int8(Some(-1_i8)).to_primary_key()?; - let key_i8_3 = DataValue::Int8(Some(i8::MAX)).to_primary_key()?; + let mut key_i8_1 = Vec::new(); + let mut key_i8_2 = Vec::new(); + let mut key_i8_3 = Vec::new(); + + DataValue::Int8(Some(i8::MIN)).to_primary_key(&mut key_i8_1)?; + DataValue::Int8(Some(-1_i8)).to_primary_key(&mut key_i8_2)?; + DataValue::Int8(Some(i8::MAX)).to_primary_key(&mut key_i8_3)?; - println!("{} < {}", key_i8_1, key_i8_2); - println!("{} < {}", key_i8_2, key_i8_3); + println!("{:?} < {:?}", key_i8_1, key_i8_2); + println!("{:?} < {:?}", key_i8_2, key_i8_3); assert!(key_i8_1 < key_i8_2); assert!(key_i8_2 < key_i8_3); - let key_i16_1 = DataValue::Int16(Some(i16::MIN)).to_primary_key()?; - let key_i16_2 = DataValue::Int16(Some(-1_i16)).to_primary_key()?; - let key_i16_3 = DataValue::Int16(Some(i16::MAX)).to_primary_key()?; + let mut key_i16_1 = Vec::new(); + let mut key_i16_2 = Vec::new(); + let mut key_i16_3 = Vec::new(); - println!("{} < {}", key_i16_1, key_i16_2); - println!("{} < {}", key_i16_2, key_i16_3); + DataValue::Int16(Some(i16::MIN)).to_primary_key(&mut key_i16_1)?; + DataValue::Int16(Some(-1_i16)).to_primary_key(&mut key_i16_2)?; + DataValue::Int16(Some(i16::MAX)).to_primary_key(&mut key_i16_3)?; + + println!("{:?} < {:?}", key_i16_1, key_i16_2); + println!("{:?} < {:?}", key_i16_2, key_i16_3); assert!(key_i16_1 < key_i16_2); assert!(key_i16_2 < key_i16_3); - let key_i32_1 = DataValue::Int32(Some(i32::MIN)).to_primary_key()?; - let key_i32_2 = DataValue::Int32(Some(-1_i32)).to_primary_key()?; - let key_i32_3 = DataValue::Int32(Some(i32::MAX)).to_primary_key()?; + let mut key_i32_1 = Vec::new(); + let mut key_i32_2 = Vec::new(); + let mut key_i32_3 = Vec::new(); + + DataValue::Int32(Some(i32::MIN)).to_primary_key(&mut key_i32_1)?; + DataValue::Int32(Some(-1_i32)).to_primary_key(&mut key_i32_2)?; + DataValue::Int32(Some(i32::MAX)).to_primary_key(&mut key_i32_3)?; - println!("{} < {}", key_i32_1, key_i32_2); - println!("{} < {}", key_i32_2, key_i32_3); + println!("{:?} < {:?}", key_i32_1, key_i32_2); + println!("{:?} < {:?}", key_i32_2, key_i32_3); assert!(key_i32_1 < key_i32_2); assert!(key_i32_2 < key_i32_3); - let key_i64_1 = DataValue::Int64(Some(i64::MIN)).to_primary_key()?; - let key_i64_2 = DataValue::Int64(Some(-1_i64)).to_primary_key()?; - let key_i64_3 = DataValue::Int64(Some(i64::MAX)).to_primary_key()?; + let mut key_i64_1 = Vec::new(); + let mut key_i64_2 = Vec::new(); + let mut key_i64_3 = Vec::new(); - println!("{} < {}", key_i64_1, key_i64_2); - println!("{} < {}", key_i64_2, key_i64_3); + DataValue::Int64(Some(i64::MIN)).to_primary_key(&mut key_i64_1)?; + DataValue::Int64(Some(-1_i64)).to_primary_key(&mut key_i64_2)?; + DataValue::Int64(Some(i64::MAX)).to_primary_key(&mut key_i64_3)?; + + println!("{:?} < {:?}", key_i64_1, key_i64_2); + println!("{:?} < {:?}", key_i64_2, key_i64_3); assert!(key_i64_1 < key_i64_2); assert!(key_i64_2 < key_i64_3); Ok(()) } + + #[test] + fn test_to_index_key_f() -> Result<(), TypeError> { + let mut key_f32_1 = Vec::new(); + let mut key_f32_2 = Vec::new(); + let mut key_f32_3 = Vec::new(); + + DataValue::Float32(Some(f32::MIN)).to_index_key(&mut key_f32_1)?; + DataValue::Float32(Some(-1_f32)).to_index_key(&mut key_f32_2)?; + DataValue::Float32(Some(f32::MAX)).to_index_key(&mut key_f32_3)?; + + println!("{:?} < {:?}", key_f32_1, key_f32_2); + println!("{:?} < {:?}", key_f32_2, key_f32_3); + assert!(key_f32_1 < key_f32_2); + assert!(key_f32_2 < key_f32_3); + + let mut key_f64_1 = Vec::new(); + let mut key_f64_2 = Vec::new(); + let mut key_f64_3 = Vec::new(); + + DataValue::Float64(Some(f64::MIN)).to_index_key(&mut key_f64_1)?; + DataValue::Float64(Some(-1_f64)).to_index_key(&mut key_f64_2)?; + DataValue::Float64(Some(f64::MAX)).to_index_key(&mut key_f64_3)?; + + println!("{:?} < {:?}", key_f64_1, key_f64_2); + println!("{:?} < {:?}", key_f64_2, key_f64_3); + assert!(key_f64_1 < key_f64_2); + assert!(key_f64_2 < key_f64_3); + + Ok(()) + } }