From 0270f6b4372754beca1c1e68156cfab73c0b8789 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Wed, 14 Feb 2024 21:55:41 +0800 Subject: [PATCH] feat: support `Union` (#139) * feat: support `Union` * docs: add `DataValue::Tuple` and `DQL::Union` --- README.md | 2 + src/binder/expr.rs | 4 +- src/binder/mod.rs | 14 +++ src/binder/select.rs | 102 ++++++++++++++++-- src/catalog/column.rs | 7 ++ src/errors.rs | 2 +- src/execution/volcano/dql/mod.rs | 1 + src/execution/volcano/dql/union.rs | 45 ++++++++ src/execution/volcano/mod.rs | 7 ++ src/expression/mod.rs | 7 +- .../rule/normalization/column_pruning.rs | 6 +- .../rule/normalization/expression_remapper.rs | 3 +- src/planner/mod.rs | 12 ++- src/planner/operator/mod.rs | 44 +++++++- src/planner/operator/union.rs | 51 +++++++++ src/types/tuple.rs | 2 +- src/types/tuple_builder.rs | 10 +- tests/slt/sql_2016/E071_01.slt | 24 ++--- tests/slt/sql_2016/E071_02.slt | 13 ++- tests/slt/union | 55 ++++++++++ 20 files changed, 360 insertions(+), 51 deletions(-) create mode 100644 src/execution/volcano/dql/union.rs create mode 100644 src/planner/operator/union.rs create mode 100644 tests/slt/union diff --git a/README.md b/README.md index af9a3ed3..a5224a82 100755 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ implement_from_tuple!( - [x] Show Tables - [x] Explain - [x] Describe + - [x] Union - DML - [x] Insert - [x] Insert Overwrite @@ -162,6 +163,7 @@ implement_from_tuple!( - Varchar - Date - DateTime + - Tuple ## Roadmap - SQL 2016 diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 544686ae..51c0749d 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -96,8 +96,8 @@ impl<'a, T: Transaction> Binder<'a, T> { if sub_query_schema.len() != 1 { return Err(DatabaseError::MisMatch( - "expects only one expression to be returned".to_string(), - "the expression returned by the subquery".to_string(), + "expects only one expression to be returned", + "the expression returned by the subquery", )); } let column = sub_query_schema[0].clone(); diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 33854dbc..59c34bb9 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -240,6 +240,20 @@ impl<'a, T: Transaction> Binder<'a, T> { }; Ok(plan) } + + pub fn bind_set_expr(&mut self, set_expr: &SetExpr) -> Result { + match set_expr { + SetExpr::Select(select) => self.bind_select(select, &[]), + SetExpr::Query(query) => self.bind_query(query), + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => self.bind_set_operation(op, set_quantifier, left, right), + _ => todo!(), + } + } } fn lower_ident(ident: &Ident) -> String { diff --git a/src/binder/select.rs b/src/binder/select.rs index 9a751b02..477a7651 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -22,14 +22,15 @@ use crate::execution::volcano::dql::join::joins_nullable; use crate::expression::{AliasType, BinaryOperator}; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; +use crate::planner::operator::union::UnionOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::tuple::Schema; +use crate::types::tuple::{Schema, SchemaRef}; use crate::types::LogicalType; use itertools::Itertools; use sqlparser::ast::{ Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, - SelectItem, SetExpr, TableAlias, TableFactor, TableWithJoins, + SelectItem, SetExpr, SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins, }; impl<'a, T: Transaction> Binder<'a, T> { @@ -41,6 +42,12 @@ impl<'a, T: Transaction> Binder<'a, T> { let mut plan = match query.body.borrow() { SetExpr::Select(select) => self.bind_select(select, &query.order_by), SetExpr::Query(query) => self.bind_query(query), + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => self.bind_set_operation(op, set_quantifier, left, right), _ => unimplemented!(), }?; @@ -54,7 +61,7 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(plan) } - fn bind_select( + pub(crate) fn bind_select( &mut self, select: &Select, orderby: &[OrderByExpr], @@ -107,6 +114,90 @@ impl<'a, T: Transaction> Binder<'a, T> { Ok(plan) } + pub(crate) fn bind_set_operation( + &mut self, + op: &SetOperator, + set_quantifier: &SetQuantifier, + left: &SetExpr, + right: &SetExpr, + ) -> Result { + let is_all = match set_quantifier { + SetQuantifier::All => true, + SetQuantifier::Distinct | SetQuantifier::None => false, + }; + let mut left_plan = self.bind_set_expr(left)?; + let mut right_plan = self.bind_set_expr(right)?; + let fn_eq = |left_schema: &SchemaRef, right_schema: &SchemaRef| { + let left_len = left_schema.len(); + + if left_len != right_schema.len() { + return false; + } + for i in 0..left_len { + if left_schema[i].datatype() != right_schema[i].datatype() { + return false; + } + } + true + }; + match (op, is_all) { + (SetOperator::Union, true) => { + let left_schema = left_plan.output_schema(); + let right_schema = right_plan.output_schema(); + + if !fn_eq(left_schema, right_schema) { + return Err(DatabaseError::MisMatch( + "the output types on the left", + "the output types on the right", + )); + } + Ok(UnionOperator::build( + left_schema.clone(), + right_schema.clone(), + left_plan, + right_plan, + )) + } + (SetOperator::Union, false) => { + let left_schema = left_plan.output_schema(); + let right_schema = right_plan.output_schema(); + + if !fn_eq(left_schema, right_schema) { + return Err(DatabaseError::MisMatch( + "the output types on the left", + "the output types on the right", + )); + } + let union_op = Operator::Union(UnionOperator { + left_schema_ref: left_schema.clone(), + right_schema_ref: right_schema.clone(), + }); + let distinct_exprs = left_schema + .iter() + .cloned() + .map(ScalarExpression::ColumnRef) + .collect_vec(); + + Ok(self.bind_distinct( + LogicalPlan::new(union_op, vec![left_plan, right_plan]), + distinct_exprs, + )) + } + (SetOperator::Intersect, true) => { + todo!() + } + (SetOperator::Intersect, false) => { + todo!() + } + (SetOperator::Except, true) => { + todo!() + } + (SetOperator::Except, false) => { + todo!() + } + } + } + pub(crate) fn bind_table_ref( &mut self, from: &[TableWithJoins], @@ -192,10 +283,7 @@ impl<'a, T: Transaction> Binder<'a, T> { .ok_or(DatabaseError::TableNotFound)?; if alias_column.len() != table.columns_len() { - return Err(DatabaseError::MisMatch( - "Alias".to_string(), - "Columns".to_string(), - )); + return Err(DatabaseError::MisMatch("alias", "columns")); } let aliases_with_columns = alias_column .iter() diff --git a/src/catalog/column.rs b/src/catalog/column.rs index ee5a6e04..eb69ecee 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -63,6 +63,13 @@ impl ColumnCatalog { &self.summary.name } + pub fn full_name(&self) -> String { + if let Some(table_name) = self.table_name() { + return format!("{}.{}", table_name, self.name()); + } + self.name().to_string() + } + pub fn table_name(&self) -> Option<&TableName> { self.summary.table_name.as_ref() } diff --git a/src/errors.rs b/src/errors.rs index 9d895faa..afda9b63 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -80,7 +80,7 @@ pub enum DatabaseError { FromUtf8Error, ), #[error("{0} and {1} do not match")] - MisMatch(String, String), + MisMatch(&'static str, &'static str), #[error("io: {0}")] IO( #[source] diff --git a/src/execution/volcano/dql/mod.rs b/src/execution/volcano/dql/mod.rs index f4505959..b413dece 100644 --- a/src/execution/volcano/dql/mod.rs +++ b/src/execution/volcano/dql/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod projection; pub(crate) mod seq_scan; pub(crate) mod show_table; pub(crate) mod sort; +pub(crate) mod union; pub(crate) mod values; #[cfg(test)] diff --git a/src/execution/volcano/dql/union.rs b/src/execution/volcano/dql/union.rs new file mode 100644 index 00000000..6f6cf763 --- /dev/null +++ b/src/execution/volcano/dql/union.rs @@ -0,0 +1,45 @@ +use crate::errors::DatabaseError; +use crate::execution::volcano::{build_read, BoxedExecutor, ReadExecutor}; +use crate::planner::LogicalPlan; +use crate::storage::Transaction; +use crate::types::tuple::Tuple; +use futures_async_stream::try_stream; + +pub struct Union { + left_input: LogicalPlan, + right_input: LogicalPlan, +} + +impl From<(LogicalPlan, LogicalPlan)> for Union { + fn from((left_input, right_input): (LogicalPlan, LogicalPlan)) -> Self { + Union { + left_input, + right_input, + } + } +} + +impl ReadExecutor for Union { + fn execute(self, transaction: &T) -> BoxedExecutor { + self._execute(transaction) + } +} + +impl Union { + #[try_stream(boxed, ok = Tuple, error = DatabaseError)] + pub async fn _execute(self, transaction: &T) { + let Union { + left_input, + right_input, + } = self; + + #[for_await] + for tuple in build_read(left_input, transaction) { + yield tuple?; + } + #[for_await] + for tuple in build_read(right_input, transaction) { + yield tuple?; + } + } +} diff --git a/src/execution/volcano/mod.rs b/src/execution/volcano/mod.rs index 89c6e9ed..c585bf37 100644 --- a/src/execution/volcano/mod.rs +++ b/src/execution/volcano/mod.rs @@ -25,6 +25,7 @@ use crate::execution::volcano::dql::projection::Projection; use crate::execution::volcano::dql::seq_scan::SeqScan; use crate::execution::volcano::dql::show_table::ShowTables; use crate::execution::volcano::dql::sort::Sort; +use crate::execution::volcano::dql::union::Union; use crate::execution::volcano::dql::values::Values; use crate::planner::operator::{Operator, PhysicalOption}; use crate::planner::LogicalPlan; @@ -109,6 +110,12 @@ pub fn build_read(plan: LogicalPlan, transaction: &T) -> BoxedEx Explain::from(input).execute(transaction) } Operator::Describe(op) => Describe::from(op).execute(transaction), + Operator::Union(_) => { + let left_input = childrens.remove(0); + let right_input = childrens.remove(0); + + Union::from((left_input, right_input)).execute(transaction) + } _ => unreachable!(), } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index e8a41115..13a61402 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -364,12 +364,7 @@ impl ScalarExpression { pub fn output_name(&self) -> String { match self { ScalarExpression::Constant(value) => format!("{}", value), - ScalarExpression::ColumnRef(col) => { - if let Some(table_name) = col.table_name() { - return format!("{}.{}", table_name, col.name()); - } - col.name().to_string() - } + ScalarExpression::ColumnRef(col) => col.full_name(), ScalarExpression::Alias { alias, expr } => match alias { AliasType::Name(alias) => alias.to_string(), AliasType::Expr(alias_expr) => { diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index c052b7c3..d3503775 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -102,7 +102,11 @@ impl ColumnPruning { .retain(|(_, column)| column_references.contains(column.summary())); } } - Operator::Sort(_) | Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => { + Operator::Sort(_) + | Operator::Limit(_) + | Operator::Join(_) + | Operator::Filter(_) + | Operator::Union(_) => { let temp_columns = operator.referenced_columns(false); // why? let mut column_references = column_references; diff --git a/src/optimizer/rule/normalization/expression_remapper.rs b/src/optimizer/rule/normalization/expression_remapper.rs index 6d97730a..a9f759fa 100644 --- a/src/optimizer/rule/normalization/expression_remapper.rs +++ b/src/optimizer/rule/normalization/expression_remapper.rs @@ -92,7 +92,8 @@ impl ExpressionRemapper { | Operator::DropTable(_) | Operator::Truncate(_) | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => (), + | Operator::CopyToFile(_) + | Operator::Union(_) => (), } if let Some(exprs) = operator.output_exprs() { *output_exprs = exprs; diff --git a/src/planner/mod.rs b/src/planner/mod.rs index e0da3152..18cd62ab 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -1,6 +1,8 @@ pub mod operator; use crate::catalog::TableName; +use crate::planner::operator::union::UnionOperator; +use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::{Operator, PhysicalOption}; use crate::types::tuple::SchemaRef; use itertools::Itertools; @@ -83,7 +85,15 @@ impl LogicalPlan { .collect_vec(); Arc::new(out_columns) } - Operator::Values(op) => op.schema_ref.clone(), + Operator::Values(ValuesOperator { schema_ref, .. }) => schema_ref.clone(), + Operator::Union(UnionOperator { + left_schema_ref, + right_schema_ref, + }) => { + let mut schema = Vec::clone(left_schema_ref); + schema.extend_from_slice(right_schema_ref.as_slice()); + Arc::new(schema) + } Operator::Dummy | Operator::Show | Operator::Explain diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index 03855fde..976fc165 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -15,6 +15,7 @@ pub mod project; pub mod scan; pub mod sort; pub mod truncate; +pub mod union; pub mod update; pub mod values; @@ -31,6 +32,7 @@ use crate::planner::operator::drop_table::DropTableOperator; use crate::planner::operator::insert::InsertOperator; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::truncate::TruncateOperator; +use crate::planner::operator::union::UnionOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::values::ValuesOperator; use crate::types::index::IndexInfo; @@ -59,6 +61,7 @@ pub enum Operator { Show, Explain, Describe(DescribeOperator), + Union(UnionOperator), // DML Insert(InsertOperator), Update(UpdateOperator), @@ -124,13 +127,24 @@ impl Operator { .collect_vec(), ), Operator::Sort(_) | Operator::Limit(_) => None, - Operator::Values(op) => Some( - op.schema_ref + Operator::Values(ValuesOperator { schema_ref, .. }) => Some( + schema_ref .iter() .cloned() .map(ScalarExpression::ColumnRef) .collect_vec(), ), + Operator::Union(UnionOperator { + left_schema_ref, + right_schema_ref, + }) => Some( + left_schema_ref + .iter() + .chain(right_schema_ref.iter()) + .cloned() + .map(ScalarExpression::ColumnRef) + .collect_vec(), + ), Operator::Show | Operator::Explain | Operator::Describe(_) @@ -189,10 +203,31 @@ impl Operator { .map(|field| &field.expr) .flat_map(|expr| expr.referenced_columns(only_column_ref)) .collect_vec(), - Operator::Values(op) => Vec::clone(&op.schema_ref), + Operator::Values(ValuesOperator { schema_ref, .. }) => Vec::clone(schema_ref), + Operator::Union(UnionOperator { + left_schema_ref, + right_schema_ref, + }) => { + let mut schema = Vec::clone(left_schema_ref); + schema.extend_from_slice(right_schema_ref.as_slice()); + schema + } Operator::Analyze(op) => op.columns.clone(), Operator::Delete(op) => vec![op.primary_key_column.clone()], - _ => vec![], + Operator::Dummy + | Operator::Limit(_) + | Operator::Show + | Operator::Explain + | Operator::Describe(_) + | Operator::Insert(_) + | Operator::Update(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::CreateTable(_) + | Operator::DropTable(_) + | Operator::Truncate(_) + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) => vec![], } } } @@ -223,6 +258,7 @@ impl fmt::Display for Operator { Operator::Truncate(op) => write!(f, "{}", op), Operator::CopyFromFile(op) => write!(f, "{}", op), Operator::CopyToFile(_) => todo!(), + Operator::Union(op) => write!(f, "{}", op), } } } diff --git a/src/planner/operator/union.rs b/src/planner/operator/union.rs new file mode 100644 index 00000000..071fa05f --- /dev/null +++ b/src/planner/operator/union.rs @@ -0,0 +1,51 @@ +use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; +use crate::types::tuple::SchemaRef; +use itertools::Itertools; +use std::fmt; +use std::fmt::Formatter; +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct UnionOperator { + pub left_schema_ref: SchemaRef, + pub right_schema_ref: SchemaRef, +} + +impl UnionOperator { + pub fn build( + left_schema_ref: SchemaRef, + right_schema_ref: SchemaRef, + left_plan: LogicalPlan, + right_plan: LogicalPlan, + ) -> LogicalPlan { + LogicalPlan::new( + Operator::Union(UnionOperator { + left_schema_ref, + right_schema_ref, + }), + vec![left_plan, right_plan], + ) + } +} + +impl fmt::Display for UnionOperator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let left_columns = self + .left_schema_ref + .iter() + .map(|column| column.name().to_string()) + .join(", "); + let right_columns = self + .right_schema_ref + .iter() + .map(|column| column.name().to_string()) + .join(", "); + + write!( + f, + "Union left: [{}], right: [{}]", + left_columns, right_columns + )?; + + Ok(()) + } +} diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 692270e8..7d748c79 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -146,7 +146,7 @@ pub fn create_table(tuples: &[Tuple]) -> Table { let mut header = Vec::new(); for col in tuples[0].schema_ref.iter() { - header.push(Cell::new(col.name().to_string())); + header.push(Cell::new(col.full_name())); } table.set_header(header); diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index d3032e07..5bbb225a 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -30,10 +30,7 @@ impl<'a> TupleBuilder<'a> { values: Vec, ) -> Result { if values.len() != self.schema_ref.len() { - return Err(DatabaseError::MisMatch( - "types".to_string(), - "values".to_string(), - )); + return Err(DatabaseError::MisMatch("types", "values")); } Ok(Tuple { @@ -61,10 +58,7 @@ impl<'a> TupleBuilder<'a> { values.push(data_value); } if values.len() != self.schema_ref.len() { - return Err(DatabaseError::MisMatch( - "types".to_string(), - "values".to_string(), - )); + return Err(DatabaseError::MisMatch("types", "values")); } Ok(Tuple { diff --git a/tests/slt/sql_2016/E071_01.slt b/tests/slt/sql_2016/E071_01.slt index 9bc61efb..a2928ac7 100644 --- a/tests/slt/sql_2016/E071_01.slt +++ b/tests/slt/sql_2016/E071_01.slt @@ -1,19 +1,19 @@ # E071-01: UNION DISTINCT table operator -# TODO: Support `UNION\UNION DISTINCT` +statement ok +CREATE TABLE TABLE_E071_01_01_011 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E071_01_01_011 ( ID INT PRIMARY KEY, A INT ); +statement ok +CREATE TABLE TABLE_E071_01_01_012 ( ID INT PRIMARY KEY, B INT ); -# statement ok -# CREATE TABLE TABLE_E071_01_01_012 ( ID INT PRIMARY KEY, B INT ); +query I +SELECT A FROM TABLE_E071_01_01_011 UNION DISTINCT SELECT B FROM TABLE_E071_01_01_012 -# SELECT A FROM TABLE_E071_01_01_011 UNION DISTINCT SELECT B FROM TABLE_E071_01_01_012 +statement ok +CREATE TABLE TABLE_E071_01_01_021 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E071_01_01_021 ( ID INT PRIMARY KEY, A INT ); +statement ok +CREATE TABLE TABLE_E071_01_01_022 ( ID INT PRIMARY KEY, B INT ); -# statement ok -# CREATE TABLE TABLE_E071_01_01_022 ( ID INT PRIMARY KEY, B INT ); - -# SELECT A FROM TABLE_E071_01_01_021 UNION SELECT B FROM TABLE_E071_01_01_022 +query I +SELECT A FROM TABLE_E071_01_01_021 UNION SELECT B FROM TABLE_E071_01_01_022 diff --git a/tests/slt/sql_2016/E071_02.slt b/tests/slt/sql_2016/E071_02.slt index d53cf3ea..be33cf81 100644 --- a/tests/slt/sql_2016/E071_02.slt +++ b/tests/slt/sql_2016/E071_02.slt @@ -1,11 +1,10 @@ # E071-02: UNION ALL table operator -# TODO: Support `UNION ALL` +statement ok +CREATE TABLE TABLE_E071_02_01_011 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E071_02_01_011 ( ID INT PRIMARY KEY, A INT ); +statement ok +CREATE TABLE TABLE_E071_02_01_012 ( ID INT PRIMARY KEY, B INT ); -# statement ok -# CREATE TABLE TABLE_E071_02_01_012 ( ID INT PRIMARY KEY, B INT ); - -# SELECT A FROM TABLE_E071_02_01_011 UNION ALL SELECT B FROM TABLE_E071_02_01_012 +query I +SELECT A FROM TABLE_E071_02_01_011 UNION ALL SELECT B FROM TABLE_E071_02_01_012 diff --git a/tests/slt/union b/tests/slt/union new file mode 100644 index 00000000..da561b6b --- /dev/null +++ b/tests/slt/union @@ -0,0 +1,55 @@ +query I rowsort +select 1 union select 2 +---- +1 +2 + +query I rowsort +select 1 union select 2 + 1 +---- +1 +3 + +query I rowsort +select 1 union select 1 +---- +1 + +query I rowsort +select 1 union all select 1 +---- +1 +1 + +query T +select (1, 2) union select (2, 1) union select (1, 2) +---- +(1, 2) +(2, 1) + +statement ok +create table t1(id int primary key, v1 int unique) + +statement ok +insert into t1 values (1,1), (2,2), (3,3), (4,4) + +query I +select v1 from t1 union select * from t1 +1 +2 +3 +4 + +query I rowsort +select v1 from t1 union all select * from t1 +1 +1 +2 +2 +3 +3 +4 +4 + +statement ok +drop t1 \ No newline at end of file