From b326b56861179a06ea10149b76ad337548a7f104 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 15:36:55 +0800 Subject: [PATCH 01/13] add join test data Signed-off-by: Veeupup --- README.md | 2 +- data/department.csv | 4 ++++ data/employee.csv | 6 ++++++ data/rank.csv | 4 ++++ test_data.csv => data/test_data.csv | 0 src/datasource/csv.rs | 4 ++-- src/main.rs | 2 +- src/physical_plan/limit.rs | 2 +- src/physical_plan/projection.rs | 2 +- src/physical_plan/scan.rs | 2 +- src/physical_plan/selection.rs | 2 +- src/planner/mod.rs | 2 +- src/sql/planner.rs | 2 +- 13 files changed, 24 insertions(+), 10 deletions(-) create mode 100644 data/department.csv create mode 100644 data/employee.csv create mode 100644 data/rank.csv rename test_data.csv => data/test_data.csv (100%) diff --git a/README.md b/README.md index 9e796cf..0578c82 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; diff --git a/data/department.csv b/data/department.csv new file mode 100644 index 0000000..0002fc6 --- /dev/null +++ b/data/department.csv @@ -0,0 +1,4 @@ +id, info +1, IT +2, Marketing +3, Human Resource \ No newline at end of file diff --git a/data/employee.csv b/data/employee.csv new file mode 100644 index 0000000..15005d8 --- /dev/null +++ b/data/employee.csv @@ -0,0 +1,6 @@ +id, name, department_id, rank +1, vee, 1, 0 +2, lynne, 1, 0 +3, Alex, 2, 1 +4, jack, 2, 1 +5, mike, 3, 2 \ No newline at end of file diff --git a/data/rank.csv b/data/rank.csv new file mode 100644 index 0000000..3edf643 --- /dev/null +++ b/data/rank.csv @@ -0,0 +1,4 @@ +id, rank_name +0, master +1, diamond +2, grandmaster \ No newline at end of file diff --git a/test_data.csv b/data/test_data.csv similarity index 100% rename from test_data.csv 
rename to data/test_data.csv diff --git a/src/datasource/csv.rs b/src/datasource/csv.rs index 8bc2814..f81e363 100644 --- a/src/datasource/csv.rs +++ b/src/datasource/csv.rs @@ -103,7 +103,7 @@ mod tests { #[test] fn test_infer_schema() -> Result<()> { - let table = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let table = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let schema = table.schema(); let excepted = Arc::new(Schema::new(vec![ @@ -127,7 +127,7 @@ mod tests { #[test] fn test_read_from_csv() -> Result<()> { - let table = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let table = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let batches = table.scan(None)?; diff --git a/src/main.rs b/src/main.rs index c15f9cd..69c51da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; diff --git a/src/physical_plan/limit.rs b/src/physical_plan/limit.rs index 0a5cca5..6ad6db7 100644 --- a/src/physical_plan/limit.rs +++ b/src/physical_plan/limit.rs @@ -64,7 +64,7 @@ mod tests { #[test] fn test_physical_scan() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let scan_plan = ScanPlan::create(source, None); let limit_plan = PhysicalLimitPlan::create(scan_plan, 2); diff --git a/src/physical_plan/projection.rs b/src/physical_plan/projection.rs index 9bdcdbb..6e3a728 100644 --- a/src/physical_plan/projection.rs +++ b/src/physical_plan/projection.rs @@ -78,7 +78,7 @@ mod tests { #[test] fn test_projection() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = 
CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let schema = Arc::new(Schema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), diff --git a/src/physical_plan/scan.rs b/src/physical_plan/scan.rs index e2e81e8..2b68600 100644 --- a/src/physical_plan/scan.rs +++ b/src/physical_plan/scan.rs @@ -48,7 +48,7 @@ mod tests { #[test] fn test_physical_scan() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let scan_plan = ScanPlan::create(source, None); diff --git a/src/physical_plan/selection.rs b/src/physical_plan/selection.rs index 664113e..d7f7766 100644 --- a/src/physical_plan/selection.rs +++ b/src/physical_plan/selection.rs @@ -125,7 +125,7 @@ mod tests { #[test] fn test_selection() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let schema = Arc::new(Schema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 6dab6fb..6828578 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -117,7 +117,7 @@ mod tests { fn test_scan_projection() -> Result<()> { // construct let mut catalog = Catalog::default(); - catalog.add_csv_table("t1", "test_data.csv")?; + catalog.add_csv_table("t1", "data/test_data.csv")?; let source = catalog.get_table_df("t1")?; let exprs = vec![ LogicalExpr::column("id".to_string()), diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 491ad72..d1ae473 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -211,7 +211,7 @@ mod tests { #[test] fn select_with_projection_filter() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; { let ret = db.run_sql("select id, 
name from t1")?; From 114aec98ad40b5304769ed4c013fdbb4c61fec67 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 17:06:24 +0800 Subject: [PATCH 02/13] column add table name Signed-off-by: Veeupup --- src/logical_plan/dataframe.rs | 12 ++++++------ src/logical_plan/expression.rs | 13 ++++++++----- src/logical_plan/plan.rs | 12 ++++++------ src/planner/mod.rs | 8 ++++---- src/sql/planner.rs | 2 +- 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index 7230e68..e236da3 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -107,18 +107,18 @@ mod tests { let _plan = catalog .get_table_df("empty")? .filter(LogicalExpr::BinaryExpr(BinaryExpr { - left: Box::new(LogicalExpr::column("state".to_string())), + left: Box::new(LogicalExpr::column(None, "state".to_string())), op: Operator::Eq, right: Box::new(LogicalExpr::Literal(ScalarValue::Utf8(Some( "CO".to_string(), )))), })) .project(vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("first_name".to_string()), - LogicalExpr::column("last_name".to_string()), - LogicalExpr::column("state".to_string()), - LogicalExpr::column("salary".to_string()), + LogicalExpr::column(None, "id".to_string()), + LogicalExpr::column(None, "first_name".to_string()), + LogicalExpr::column(None, "last_name".to_string()), + LogicalExpr::column(None, "state".to_string()), + LogicalExpr::column(None, "salary".to_string()), ]); Ok(()) diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index 8a8d67a..70142d6 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -47,8 +47,8 @@ pub enum LogicalExpr { } impl LogicalExpr { - pub fn column(name: String) -> LogicalExpr { - LogicalExpr::Column(Column(name)) + pub fn column(table: Option, name: String) -> LogicalExpr { + LogicalExpr::Column(Column { table, name }) } /// TODO(veeupup): consider return Vec @@ -62,9 +62,9 @@ 
impl LogicalExpr { field.is_nullable(), )) } - LogicalExpr::Column(col) => { + LogicalExpr::Column(Column { name, .. }) => { for field in input.schema().fields() { - if field.name() == col.0.as_str() { + if field.name() == name.as_str() { return Ok(field.clone()); } } @@ -91,7 +91,10 @@ impl LogicalExpr { /// A named reference to a qualified field in a schema. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Column(pub String); +pub struct Column { + pub table: Option, + pub name: String, +} #[derive(Debug, Clone)] diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index 2dc96e4..2dbce5b 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -156,7 +156,7 @@ mod tests { }); let filter_expr = LogicalExpr::BinaryExpr(BinaryExpr { - left: Box::new(LogicalExpr::column("state".to_string())), + left: Box::new(LogicalExpr::column(None, "state".to_string())), op: Operator::Eq, right: Box::new(LogicalExpr::Literal(ScalarValue::Utf8(Some( "CO".to_string(), @@ -169,11 +169,11 @@ mod tests { }); let _projection = vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("first_name".to_string()), - LogicalExpr::column("last_name".to_string()), - LogicalExpr::column("state".to_string()), - LogicalExpr::column("salary".to_string()), + LogicalExpr::column(None, "id".to_string()), + LogicalExpr::column(None, "first_name".to_string()), + LogicalExpr::column(None, "last_name".to_string()), + LogicalExpr::column(None, "state".to_string()), + LogicalExpr::column(None, "salary".to_string()), ]; Ok(()) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 6828578..dcfc276 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -74,7 +74,7 @@ impl QueryPlanner { ) -> Result { match expr { LogicalExpr::Alias(_, _) => todo!(), - LogicalExpr::Column(Column(name)) => { + LogicalExpr::Column(Column { name, .. 
}) => { for (idx, field) in input.schema().fields().iter().enumerate() { if field.name() == name { return ColumnExpr::try_create(None, Some(idx)); @@ -120,9 +120,9 @@ mod tests { catalog.add_csv_table("t1", "data/test_data.csv")?; let source = catalog.get_table_df("t1")?; let exprs = vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("name".to_string()), - LogicalExpr::column("age".to_string()), + LogicalExpr::column(None, "id".to_string()), + LogicalExpr::column(None, "name".to_string()), + LogicalExpr::column(None, "age".to_string()), ]; let logical_plan = source.project(exprs).logical_plan(); let physical_plan = QueryPlanner::create_physical_plan(&logical_plan)?; diff --git a/src/sql/planner.rs b/src/sql/planner.rs index d1ae473..860cf24 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -156,7 +156,7 @@ impl<'a> SQLPlanner<'a> { } Expr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), Expr::Value(Value::Null) => Ok(LogicalExpr::Literal(ScalarValue::Null)), - Expr::Identifier(id) => Ok(LogicalExpr::column(normalize_ident(id))), + Expr::Identifier(id) => Ok(LogicalExpr::column(None, normalize_ident(id))), // TODO(veeupup): cast func Expr::BinaryOp { left, op, right } => self.parse_sql_binary_op(left, op, right), _ => todo!(), From f9cbc2721589211e53c4eaae9da0c4a0f95db4e8 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 20:18:42 +0800 Subject: [PATCH 03/13] inner join planner Signed-off-by: Veeupup --- src/error.rs | 1 + src/logical_plan/dataframe.rs | 17 +++- src/logical_plan/expression.rs | 13 +++ src/logical_plan/plan.rs | 2 +- src/sql/planner.rs | 170 +++++++++++++++++++++++++++++---- 5 files changed, 183 insertions(+), 20 deletions(-) diff --git a/src/error.rs b/src/error.rs index 1addd83..d9a1ee4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,6 +30,7 @@ pub enum ErrorCode { PlanError(String), + NotImplemented, #[allow(unused)] Others, } diff --git a/src/logical_plan/dataframe.rs 
b/src/logical_plan/dataframe.rs index e236da3..65e1f00 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -11,7 +11,9 @@ use arrow::datatypes::{Schema, SchemaRef}; use crate::logical_plan::expression::LogicalExpr; use crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; -use super::plan::Limit; +use super::expression::Column; +use super::plan::{JoinType, Limit}; +use crate::error::Result; #[derive(Clone)] pub struct DataFrame { @@ -19,6 +21,10 @@ pub struct DataFrame { } impl DataFrame { + pub fn new(plan: LogicalPlan) -> Self { + Self { plan } + } + pub fn project(self, exprs: Vec) -> Self { let fields = exprs .iter() @@ -73,6 +79,15 @@ impl DataFrame { } } + pub fn join( + &self, + _right: &LogicalPlan, + _join_type: JoinType, + _join_keys: (Vec, Vec), + ) -> Result { + todo!() + } + pub fn schema(&self) -> SchemaRef { self.plan.schema() } diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index 70142d6..ebcf99d 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -87,6 +87,19 @@ impl LogicalExpr { // LogicalExpression::Wildcard => , } } + + pub fn and(self, other: LogicalExpr) -> LogicalExpr { + binary_expr(self, Operator::And, other) + } +} + +/// return a new expression l r +pub fn binary_expr(l: LogicalExpr, op: Operator, r: LogicalExpr) -> LogicalExpr { + LogicalExpr::BinaryExpr(BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + }) } /// A named reference to a qualified field in a schema. 
diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index 2dbce5b..a4426eb 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -103,7 +103,7 @@ pub struct Aggregate { pub schema: SchemaRef, } -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub enum JoinType { Inner, Left, diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 860cf24..b0b9220 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -7,14 +7,19 @@ * */ -use log::debug; -use sqlparser::ast::{BinaryOperator, Expr, OrderByExpr, SetExpr, Statement, TableWithJoins}; + + + +use sqlparser::ast::{ + BinaryOperator, Expr, Join, JoinConstraint, JoinOperator, OrderByExpr, SetExpr, Statement, + TableWithJoins, +}; use sqlparser::ast::{Ident, ObjectName, SelectItem, TableFactor, Value}; use crate::error::ErrorCode; -use crate::logical_plan::expression::{BinaryExpr, LogicalExpr, Operator, ScalarValue}; +use crate::logical_plan::expression::{BinaryExpr, Column, LogicalExpr, Operator, ScalarValue}; use crate::logical_plan::literal::lit; -use crate::logical_plan::plan::TableScan; +use crate::logical_plan::plan::{JoinType, TableScan}; use crate::{ catalog::Catalog, error::Result, @@ -45,9 +50,9 @@ impl<'a> SQLPlanner<'a> { fn set_expr_to_plan(&self, set_expr: SetExpr) -> Result { match set_expr { SetExpr::Select(select) => { - let df = self.plan_from_tables(select.from)?; + let plans = self.plan_from_tables(select.from)?; - let df = self.plan_selection(select.selection, df)?; + let df = self.plan_selection(select.selection, plans)?; // process the SELECT expressions, with wildcards expanded let df = self.plan_from_projection(df, select.projection)?; @@ -82,21 +87,97 @@ impl<'a> SQLPlanner<'a> { } } - fn plan_from_tables(&self, from: Vec) -> Result { - // TODO(veeupup): support select with no from - // TODO(veeupup): support select with join, multi table - debug_assert!(!from.is_empty()); - match &from[0].relation { - TableFactor::Table { name, alias: _, .. 
} => { + fn plan_from_tables(&self, from: Vec) -> Result> { + match from.len() { + 0 => todo!("support select with no from"), + _ => from + .iter() + .map(|t| self.plan_table_with_joins(t)) + .collect::>>(), + } + } + + fn plan_table_with_joins(&self, t: &TableWithJoins) -> Result { + let left = self.parse_table(&t.relation)?; + match t.joins.len() { + 0 => Ok(left), + n => { + let mut left = self.parse_table_join(left, &t.joins[0])?; + for i in 1..n { + left = self.parse_table_join(left, &t.joins[i])?; + } + Ok(left) + } + } + } + + fn parse_table_join(&self, left: LogicalPlan, join: &Join) -> Result { + let right = self.parse_table(&join.relation)?; + match &join.join_operator { + JoinOperator::LeftOuter(constraint) => { + self.parse_join(left, right, constraint, JoinType::Left) + } + JoinOperator::RightOuter(constraint) => { + self.parse_join(left, right, constraint, JoinType::Right) + } + JoinOperator::Inner(constraint) => { + self.parse_join(left, right, constraint, JoinType::Inner) + } + // TODO(veeupup): cross join + _other => Err(ErrorCode::NotImplemented), + } + } + + fn parse_join( + &self, + left: LogicalPlan, + right: LogicalPlan, + constraint: &JoinConstraint, + join_type: JoinType, + ) -> Result { + match constraint { + JoinConstraint::On(sql_expr) => { + let mut keys: Vec<(Column, Column)> = vec![]; + let expr = self.sql_to_expr(sql_expr)?; + + let mut filters = vec![]; + extract_join_keys(&expr, &mut keys, &mut filters); + + let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); + + if filters.is_empty() { + let join = + DataFrame::new(left).join(&right, join_type, (left_keys, right_keys))?; + Ok(join.logical_plan()) + } else if join_type == JoinType::Inner { + let join = + DataFrame::new(left).join(&right, join_type, (left_keys, right_keys))?; + let join = join.filter( + filters + .iter() + .skip(1) + .fold(filters[0].clone(), |acc, e| acc.and(e.clone())), + ); + Ok(join.logical_plan()) + } else { + Err(ErrorCode::NotImplemented) + } + 
} + _ => Err(ErrorCode::NotImplemented), + } + } + + fn parse_table(&self, relation: &TableFactor) -> Result { + match &relation { + TableFactor::Table { name, .. } => { let table_name = Self::normalize_sql_object_name(name); let source = self.catalog.get_table(&table_name)?; - let plan = LogicalPlan::TableScan(TableScan { + Ok(LogicalPlan::TableScan(TableScan { source, projection: None, - }); - Ok(DataFrame { plan }) + })) } - _ => todo!(), + _ => unimplemented!(), } } @@ -119,11 +200,18 @@ impl<'a> SQLPlanner<'a> { Err(err) => Err(err), }) .collect::>(); - debug!("projection: {:?}", proj); Ok(df.project(proj)) } - fn plan_selection(&self, selection: Option, df: DataFrame) -> Result { + fn plan_selection( + &self, + selection: Option, + plans: Vec, + ) -> Result { + // TODO(veeupup): handle joins + let df = DataFrame { + plan: plans[0].clone(), + }; match selection { Some(predicate_expr) => { let filter_expr = self.sql_to_expr(&predicate_expr)?; @@ -201,6 +289,52 @@ fn normalize_ident(id: &Ident) -> String { } } +/// Come from apache/arrow-datafusion +/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs +/// Filters matching this pattern are added to `accum` +/// Filters that don't match this pattern are added to `accum_filter` +/// Examples: +/// +/// foo = bar => accum=[(foo, bar)] accum_filter=[] +/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] +/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] +/// +fn extract_join_keys( + expr: &LogicalExpr, + accum: &mut Vec<(Column, Column)>, + accum_filter: &mut Vec, +) { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + accum.push((l.clone(), r.clone())); + } + _other => { + accum_filter.push(expr.clone()); + } + }, + Operator::And => { + extract_join_keys(left, accum, accum_filter); + 
extract_join_keys(right, accum, accum_filter); + } + _other + if matches!(**left, LogicalExpr::Column(_)) + || matches!(**right, LogicalExpr::Column(_)) => + { + accum_filter.push(expr.clone()); + } + _other => { + extract_join_keys(left, accum, accum_filter); + extract_join_keys(right, accum, accum_filter); + } + }, + _other => { + accum_filter.push(expr.clone()); + } + } +} + #[cfg(test)] mod tests { use crate::db::NaiveDB; From f78feb076c0f90c6c37b215b1f550221a517a481 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 21:07:39 +0800 Subject: [PATCH 04/13] join planner save Signed-off-by: Veeupup --- src/logical_plan/dataframe.rs | 50 ++++++++++++++++++++++++++++++---- src/logical_plan/expression.rs | 2 +- src/main.rs | 13 +++++++-- src/sql/planner.rs | 27 +++++++++++++++++- 4 files changed, 81 insertions(+), 11 deletions(-) diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index 65e1f00..6ea8f27 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -12,8 +12,8 @@ use crate::logical_plan::expression::LogicalExpr; use crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; use super::expression::Column; -use super::plan::{JoinType, Limit}; -use crate::error::Result; +use super::plan::{JoinType, Limit, Join}; +use crate::error::{Result, ErrorCode}; #[derive(Clone)] pub struct DataFrame { @@ -81,11 +81,49 @@ impl DataFrame { pub fn join( &self, - _right: &LogicalPlan, - _join_type: JoinType, - _join_keys: (Vec, Vec), + right: &LogicalPlan, + join_type: JoinType, + join_keys: (Vec, Vec), ) -> Result { - todo!() + if join_keys.0.len() != join_keys.1.len() { + return Err(ErrorCode::PlanError("left_keys length must be equal to right_keys length".to_string())); + } + + // TODO(veeupup): we need judge which side os conditions on + // let (left_keys, right_keys) = + // join_keys + // .0 + // .into_iter() + // .zip(join_keys.1.into_iter()) + // .map(|(l, r)| { + // match (&l.table, &r.table) { 
+ // (Some(l), Some(r)) => { + // (Ok(l), Ok(r)) + // }, + // _ => unimplemented!() + // } + // }).collect::>(); + // let left_keys = left_keys.into_iter().collect::>>()?; + // let right_keys = right_keys.into_iter().collect::>>()?; + let left_keys = join_keys.0.clone(); + let right_keys = join_keys.1.clone(); + + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); + // join schema + let left_schema = self.plan.schema(); + let left_fields = left_schema.fields().iter(); + let right_schema = right.schema(); + let right_fields = right_schema.fields().iter(); + let fields = left_fields.chain(right_fields).cloned().collect(); + let join_schema = Arc::new(Schema::new(fields)); + + Ok(Self::new(LogicalPlan::Join(Join { + left: Arc::new(self.plan.clone()), + right: Arc::new(right.clone()), + on, + join_type, + schema: join_schema, + }))) } pub fn schema(&self) -> SchemaRef { diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index ebcf99d..a9d1a08 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -9,7 +9,7 @@ use std::iter::repeat; use arrow::array::StringArray; use arrow::array::{new_null_array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use std::sync::Arc; use crate::error::ErrorCode; diff --git a/src/main.rs b/src/main.rs index 69c51da..69e7fee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,18 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "data/test_data.csv")?; + // db.create_csv_table("t1", "data/test_data.csv")?; - let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; + // let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; - print_result(&ret)?; + // print_result(&ret)?; + // Ok(()) + db.create_csv_table("employee", 
"data/employee.csv"); + db.create_csv_table("rank", "data/rank.csv"); + + let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + + print_result(&ret?); Ok(()) } diff --git a/src/sql/planner.rs b/src/sql/planner.rs index b0b9220..fe00bbb 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -247,6 +247,22 @@ impl<'a> SQLPlanner<'a> { Expr::Identifier(id) => Ok(LogicalExpr::column(None, normalize_ident(id))), // TODO(veeupup): cast func Expr::BinaryOp { left, op, right } => self.parse_sql_binary_op(left, op, right), + Expr::CompoundIdentifier(ids) => { + let mut var_names = + ids.iter().map(|id| id.value.clone()).collect::>(); + + match (var_names.pop(), var_names.pop()) { + (Some(name), Some(table)) if var_names.is_empty() => { + // table.column identifier + Ok(LogicalExpr::Column(Column { + table: Some(table), + name, + })) + } + _ => Err(ErrorCode::NotImplemented), + } + + } _ => todo!(), } } @@ -337,7 +353,7 @@ fn extract_join_keys( #[cfg(test)] mod tests { - use crate::db::NaiveDB; + use crate::{db::NaiveDB, print_result}; use crate::error::Result; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use std::sync::Arc; @@ -380,6 +396,15 @@ mod tests { assert_eq!(batch.column(2), &age_excepted); } + { + db.create_csv_table("employee", "data/employee.csv"); + db.create_csv_table("rank", "data/rank.csv"); + + let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + + print_result(&ret?); + } + Ok(()) } } From 5d430236d073cf85a83b1eaddc002b551f0b20ec Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 21:30:54 +0800 Subject: [PATCH 05/13] planner join with select Signed-off-by: Veeupup --- src/logical_plan/dataframe.rs | 10 +- src/logical_plan/expression.rs | 2 +- src/main.rs | 14 +-- src/sql/planner.rs | 174 ++++++++++++++++++++++++++++----- 4 files changed, 163 insertions(+), 37 deletions(-) diff --git a/src/logical_plan/dataframe.rs 
b/src/logical_plan/dataframe.rs index 6ea8f27..6e6a486 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -12,8 +12,8 @@ use crate::logical_plan::expression::LogicalExpr; use crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; use super::expression::Column; -use super::plan::{JoinType, Limit, Join}; -use crate::error::{Result, ErrorCode}; +use super::plan::{Join, JoinType, Limit}; +use crate::error::{ErrorCode, Result}; #[derive(Clone)] pub struct DataFrame { @@ -86,9 +86,11 @@ impl DataFrame { join_keys: (Vec, Vec), ) -> Result { if join_keys.0.len() != join_keys.1.len() { - return Err(ErrorCode::PlanError("left_keys length must be equal to right_keys length".to_string())); + return Err(ErrorCode::PlanError( + "left_keys length must be equal to right_keys length".to_string(), + )); } - + // TODO(veeupup): we need judge which side os conditions on // let (left_keys, right_keys) = // join_keys diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index a9d1a08..ebcf99d 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -9,7 +9,7 @@ use std::iter::repeat; use arrow::array::StringArray; use arrow::array::{new_null_array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field}; use std::sync::Arc; use crate::error::ErrorCode; diff --git a/src/main.rs b/src/main.rs index 69e7fee..418f963 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,18 +5,18 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - // db.create_csv_table("t1", "data/test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; - // let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; + let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; - // print_result(&ret)?; + print_result(&ret)?; // Ok(()) - 
db.create_csv_table("employee", "data/employee.csv"); - db.create_csv_table("rank", "data/rank.csv"); + db.create_csv_table("employee", "data/employee.csv")?; + db.create_csv_table("rank", "data/rank.csv")?; - let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + // let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); - print_result(&ret?); + // print_result(&ret?); Ok(()) } diff --git a/src/sql/planner.rs b/src/sql/planner.rs index fe00bbb..efd9316 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -7,9 +7,9 @@ * */ +use std::collections::HashSet; - - +use arrow::datatypes::SchemaRef; use sqlparser::ast::{ BinaryOperator, Expr, Join, JoinConstraint, JoinOperator, OrderByExpr, SetExpr, Statement, TableWithJoins, @@ -209,16 +209,75 @@ impl<'a> SQLPlanner<'a> { plans: Vec, ) -> Result { // TODO(veeupup): handle joins - let df = DataFrame { - plan: plans[0].clone(), - }; + // let df = DataFrame { + // plan: plans[0].clone(), + // }; + // match selection { + // Some(predicate_expr) => { + // let filter_expr = self.sql_to_expr(&predicate_expr)?; + // let df = df.filter(filter_expr); + // Ok(df) + // } + // None => Ok(df), + // } match selection { - Some(predicate_expr) => { - let filter_expr = self.sql_to_expr(&predicate_expr)?; - let df = df.filter(filter_expr); - Ok(df) + Some(expr) => { + let mut fields = vec![]; + for plan in &plans { + fields.extend_from_slice(plan.schema().fields()); + } + let filter_expr = self.sql_to_expr(&expr)?; + + // look for expressions of the form ` = ` + let mut possible_join_keys = vec![]; + extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; + + let mut all_join_keys = HashSet::new(); + let mut left = plans[0].clone(); + for right in plans.iter().skip(1) { + let left_schema = left.schema(); + let right_schema = right.schema(); + let mut join_keys = vec![]; + for (l, r) in &possible_join_keys { + if 
find_column_in_schema(&left_schema, l).is_ok() + && find_column_in_schema(&right_schema, r).is_ok() + { + join_keys.push((l.clone(), r.clone())); + } else if find_column_in_schema(&left_schema, r).is_ok() + && find_column_in_schema(&right_schema, l).is_ok() + { + join_keys.push((r.clone(), l.clone())); + } + } + if !join_keys.is_empty() { + let left_keys: Vec = + join_keys.iter().map(|(l, _)| l.clone()).collect(); + let right_keys: Vec = + join_keys.iter().map(|(_, r)| r.clone()).collect(); + let df = DataFrame::new(left); + left = df + .join(right, JoinType::Inner, (left_keys, right_keys))? + .logical_plan(); + } else { + return Err(ErrorCode::NotImplemented); + } + + all_join_keys.extend(join_keys); + } + // remove join expressions from filter + match remove_join_expressions(&filter_expr, &all_join_keys)? { + Some(filter_expr) => Ok(DataFrame::new(left).filter(filter_expr)), + _ => Ok(DataFrame::new(left)), + } + } + None => { + if plans.len() == 1 { + Ok(DataFrame::new(plans[0].clone())) + } else { + // CROSS JOIN NOT SUPPORTED YET + Err(ErrorCode::NotImplemented) + } } - None => Ok(df), } } @@ -248,20 +307,18 @@ impl<'a> SQLPlanner<'a> { // TODO(veeupup): cast func Expr::BinaryOp { left, op, right } => self.parse_sql_binary_op(left, op, right), Expr::CompoundIdentifier(ids) => { - let mut var_names = - ids.iter().map(|id| id.value.clone()).collect::>(); - - match (var_names.pop(), var_names.pop()) { - (Some(name), Some(table)) if var_names.is_empty() => { - // table.column identifier - Ok(LogicalExpr::Column(Column { - table: Some(table), - name, - })) - } - _ => Err(ErrorCode::NotImplemented), + let mut var_names = ids.iter().map(|id| id.value.clone()).collect::>(); + + match (var_names.pop(), var_names.pop()) { + (Some(name), Some(table)) if var_names.is_empty() => { + // table.column identifier + Ok(LogicalExpr::Column(Column { + table: Some(table), + name, + })) } - + _ => Err(ErrorCode::NotImplemented), + } } _ => todo!(), } @@ -351,10 +408,76 @@ fn 
extract_join_keys( } } +/// Extract join keys from a WHERE clause +fn extract_possible_join_keys(expr: &LogicalExpr, accum: &mut Vec<(Column, Column)>) -> Result<()> { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + accum.push((l.clone(), r.clone())); + Ok(()) + } + _ => Ok(()), + }, + Operator::And => { + extract_possible_join_keys(left, accum)?; + extract_possible_join_keys(right, accum) + } + _ => Ok(()), + }, + _ => Ok(()), + } +} + +fn find_column_in_schema(schema: &SchemaRef, col: &Column) -> Result<()> { + let fields = schema.fields(); + for field in fields { + if field.name() == &col.name { + return Ok(()); + } + } + Err(ErrorCode::NoSuchField) +} + +/// Remove join expressions from a filter expression +fn remove_join_expressions( + expr: &LogicalExpr, + join_columns: &HashSet<(Column, Column)>, +) -> Result> { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + if join_columns.contains(&(l.clone(), r.clone())) + || join_columns.contains(&(r.clone(), l.clone())) + { + Ok(None) + } else { + Ok(Some(expr.clone())) + } + } + _ => Ok(Some(expr.clone())), + }, + Operator::And => { + let l = remove_join_expressions(left, join_columns)?; + let r = remove_join_expressions(right, join_columns)?; + match (l, r) { + (Some(ll), Some(rr)) => Ok(Some(LogicalExpr::and(ll, rr))), + (Some(ll), _) => Ok(Some(ll)), + (_, Some(rr)) => Ok(Some(rr)), + _ => Ok(None), + } + } + _ => Ok(Some(expr.clone())), + }, + _ => Ok(Some(expr.clone())), + } +} + #[cfg(test)] mod tests { - use crate::{db::NaiveDB, print_result}; use crate::error::Result; + use crate::{db::NaiveDB, print_result}; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use std::sync::Arc; @@ -400,7 +523,8 
@@ mod tests { db.create_csv_table("employee", "data/employee.csv"); db.create_csv_table("rank", "data/rank.csv"); - let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + let ret = db + .run_sql("select id, name from employee innner join rank on employee.id = rank.id"); print_result(&ret?); } From 08ad9f03fcfc19b637b4442d089236153d62a1ce Mon Sep 17 00:00:00 2001 From: Veeupup Date: Tue, 17 May 2022 21:55:39 +0800 Subject: [PATCH 06/13] join planner Signed-off-by: Veeupup --- data/department.csv | 8 ++++---- data/employee.csv | 12 ++++++------ data/rank.csv | 8 ++++---- src/logical_plan/dataframe.rs | 1 - src/logical_plan/plan.rs | 16 ++++++++-------- src/main.rs | 9 +++++---- src/planner/mod.rs | 1 - 7 files changed, 27 insertions(+), 28 deletions(-) diff --git a/data/department.csv b/data/department.csv index 0002fc6..2c6c7e0 100644 --- a/data/department.csv +++ b/data/department.csv @@ -1,4 +1,4 @@ -id, info -1, IT -2, Marketing -3, Human Resource \ No newline at end of file +id,info +1,IT +2,Marketing +3,Human Resource \ No newline at end of file diff --git a/data/employee.csv b/data/employee.csv index 15005d8..5acf7b0 100644 --- a/data/employee.csv +++ b/data/employee.csv @@ -1,6 +1,6 @@ -id, name, department_id, rank -1, vee, 1, 0 -2, lynne, 1, 0 -3, Alex, 2, 1 -4, jack, 2, 1 -5, mike, 3, 2 \ No newline at end of file +id,name,department_id,rank +1,vee,1,0 +2,lynne,1,0 +3,Alex,2,1 +4,jack,2,1 +5,mike,3,2 \ No newline at end of file diff --git a/data/rank.csv b/data/rank.csv index 3edf643..7976157 100644 --- a/data/rank.csv +++ b/data/rank.csv @@ -1,4 +1,4 @@ -id, rank_name -0, master -1, diamond -2, grandmaster \ No newline at end of file +id,rank_name +0,master +1,diamond +2,grandmaster \ No newline at end of file diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index 6e6a486..aa626b4 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -118,7 +118,6 @@ impl 
DataFrame { let right_fields = right_schema.fields().iter(); let fields = left_fields.chain(right_fields).cloned().collect(); let join_schema = Arc::new(Schema::new(fields)); - Ok(Self::new(LogicalPlan::Join(Join { left: Arc::new(self.plan.clone()), right: Arc::new(right.clone()), diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index a4426eb..32e218e 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -9,7 +9,7 @@ use crate::logical_plan::expression::{Column, LogicalExpr}; use arrow::datatypes::SchemaRef; use std::sync::Arc; -#[derive(Clone)] +#[derive(Debug, Clone)] pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. @@ -63,7 +63,7 @@ impl LogicalPlan { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Projection { /// The list of expressions pub exprs: Vec, @@ -73,7 +73,7 @@ pub struct Projection { pub schema: SchemaRef, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Filter { /// The predicate expression, which must have Boolean type. pub predicate: LogicalExpr, @@ -81,7 +81,7 @@ pub struct Filter { pub input: Arc, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct TableScan { /// The source of the table pub source: TableRef, @@ -91,7 +91,7 @@ pub struct TableScan { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). 
-#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Aggregate { /// The incoming logical plan pub input: Arc, @@ -103,7 +103,7 @@ pub struct Aggregate { pub schema: SchemaRef, } -#[derive(Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum JoinType { Inner, Left, @@ -111,7 +111,7 @@ pub enum JoinType { } /// Join two logical plans on one or more join columns -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Join { /// Left input pub left: Arc, @@ -125,7 +125,7 @@ pub struct Join { pub schema: SchemaRef, } -#[derive(Clone)] +#[derive(Debug, Clone)] /// Produces the first `n` tuples from its input and discards the rest. pub struct Limit { diff --git a/src/main.rs b/src/main.rs index 418f963..631afd2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,17 +5,18 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "data/test_data.csv")?; + // db.create_csv_table("t1", "data/test_data.csv")?; - let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; + // let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; - print_result(&ret)?; + // print_result(&ret)?; // Ok(()) db.create_csv_table("employee", "data/employee.csv")?; db.create_csv_table("rank", "data/rank.csv")?; - // let ret = db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + let ret = + db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); // print_result(&ret?); Ok(()) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index dcfc276..9d82d43 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -128,7 +128,6 @@ mod tests { let physical_plan = QueryPlanner::create_physical_plan(&logical_plan)?; let batches = physical_plan.execute()?; - println!("{:?}", batches); // test assert_eq!(batches.len(), 1); let batch = &batches[0]; From e261bd1c48d587d0d7d030ef6e75956c1f3a5c36 Mon Sep 17 00:00:00 2001 From: Veeupup Date: 
Wed, 18 May 2022 14:10:44 +0800 Subject: [PATCH 07/13] add NaiveSchema & NaiveField Signed-off-by: Veeupup --- src/logical_plan/mod.rs | 1 + src/logical_plan/schema.rs | 175 +++++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 src/logical_plan/schema.rs diff --git a/src/logical_plan/mod.rs b/src/logical_plan/mod.rs index 1ed95ae..e437ced 100644 --- a/src/logical_plan/mod.rs +++ b/src/logical_plan/mod.rs @@ -13,5 +13,6 @@ mod dataframe; pub mod expression; pub mod literal; pub mod plan; +pub mod schema; pub use dataframe::DataFrame; diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs new file mode 100644 index 0000000..a30acc5 --- /dev/null +++ b/src/logical_plan/schema.rs @@ -0,0 +1,175 @@ +/* + * @Author: Veeupup + * @Date: 2022-05-18 13:45:10 + * @Last Modified by: Veeupup + * @Last Modified time: 2022-05-18 14:10:09 + * + * Arrow Field does not have table/relation name as its proroties + * So we need a Schema to define inner schema with table name + * + * Code Ideas come from https://github.com/apache/arrow-datafusion/ + * + */ + +use std::ptr::null; + +use arrow::datatypes::DataType; +use arrow::datatypes::{Field, Schema}; + +use crate::error::ErrorCode; +use crate::error::Result; + +pub struct NaiveSchema { + pub fields: Vec, +} + +impl NaiveSchema { + pub fn empty() -> Self { + Self { fields: vec![] } + } + + pub fn new(fields: Vec) -> Self { + // TODO(veeupup): check if we have duplicated name field + Self { fields } + } + + pub fn from_qualified(qualifier: &str, schema: &Schema) -> Self { + Self::new( + schema + .fields() + .iter() + .map(|field| NaiveField { + field: field.clone(), + qualifier: Some(qualifier.to_owned()), + }) + .collect(), + ) + } + + /// join two schema + pub fn join(&self, schema: &NaiveSchema) -> Self { + let mut fields = self.fields.clone(); + fields.extend_from_slice(schema.fields().as_slice()); + Self::new(fields) + } + + pub fn fields(&self) -> &Vec { + &self.fields + } + + 
pub fn field(&self, i: usize) -> &NaiveField { + &self.fields[i] + } + + pub fn index_of(&self, name: &str) -> Result { + for i in 0..self.fields().len() { + if self.fields[i].name() == name { + return Ok(i); + } + } + Err(ErrorCode::NoSuchField) + } + + /// Find the field with the given name + pub fn field_with_name(&self, relation_name: Option<&str>, name: &str) -> Result { + if let Some(relation_name) = relation_name { + self.field_with_qualified_name(relation_name, name) + } else { + self.field_with_unqualified_name(name) + } + } + + pub fn field_with_unqualified_name(&self, name: &str) -> Result { + let matches = self + .fields + .iter() + .filter(|field| field.name() == name) + .collect::>(); + match matches.len() { + 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), + 1 => Ok(matches[0].to_owned()), + _ => Err(ErrorCode::PlanError(format!( + "Ambiguous reference to field named '{}'", + name + ))), + } + } + + pub fn field_with_qualified_name(&self, relation_name: &str, name: &str) -> Result { + let matches = self + .fields + .iter() + .filter(|field| { + field.qualifier == Some(relation_name.to_owned()) && field.name() == name + }) + .collect::>(); + match matches.len() { + 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), + 1 => Ok(matches[0].to_owned()), + _ => Err(ErrorCode::PlanError(format!( + "Ambiguous reference to field named '{}'", + name + ))), + } + } +} + +/// DFField wraps an Arrow field and adds an optional qualifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NaiveField { + /// Optional qualifier (usually a table or relation name) + qualifier: Option, + /// Arrow field definition + field: Field, +} + +impl NaiveField { + pub fn new(qualifier: Option<&str>, name: &str, data_type: DataType, nullable: bool) -> Self { + Self { + qualifier: qualifier.map(|s| s.to_owned()), + field: Field::new(name, data_type, nullable), + } + } + + pub fn from(field: Field) -> Self { + Self { + qualifier: None, + 
field, + } + } + + pub fn from_qualified(qualifier: &str, field: Field) -> Self { + Self { + qualifier: Some(qualifier.to_owned()), + field, + } + } + + pub fn name(&self) -> &String { + self.field.name() + } + + /// Returns an immutable reference to the `DFField`'s data-type + pub fn data_type(&self) -> &DataType { + &self.field.data_type() + } + + /// Indicates whether this `DFField` supports null values + pub fn is_nullable(&self) -> bool { + self.field.is_nullable() + } + + /// Returns a reference to the `DFField`'s qualified name + pub fn qualified_name(&self) -> String { + if let Some(relation_name) = &self.qualifier { + format!("{}.{}", relation_name, self.field.name()) + } else { + self.field.name().to_owned() + } + } + + /// Get the optional qualifier + pub fn qualifier(&self) -> Option<&String> { + self.qualifier.as_ref() + } +} From 6190f73c52600dcf308281be420861f00ba803eb Mon Sep 17 00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 15:10:34 +0800 Subject: [PATCH 08/13] use NaiveSchema & NaiveField Signed-off-by: Veeupup --- src/catalog.rs | 5 +- src/datasource/csv.rs | 17 +++--- src/datasource/empty.rs | 14 +++-- src/datasource/memory.rs | 12 ++-- src/datasource/mod.rs | 3 +- src/logical_plan/dataframe.rs | 9 +-- src/logical_plan/expression.rs | 103 +++++++++++++++++++------------- src/logical_plan/plan.rs | 16 ++--- src/logical_plan/schema.rs | 71 ++++++++++++++++++++-- src/main.rs | 14 ++--- src/physical_plan/limit.rs | 5 +- src/physical_plan/plan.rs | 4 +- src/physical_plan/projection.rs | 12 ++-- src/physical_plan/scan.rs | 3 +- src/physical_plan/selection.rs | 6 +- src/planner/mod.rs | 3 +- src/sql/planner.rs | 3 +- 17 files changed, 200 insertions(+), 100 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 6ce58cc..c30f690 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -11,6 +11,7 @@ use arrow::datatypes::SchemaRef; use crate::datasource::{EmptyTable, MemTable}; use crate::error::ErrorCode; use 
crate::logical_plan::plan::{LogicalPlan, TableScan}; +use crate::logical_plan::schema::NaiveSchema; use crate::logical_plan::DataFrame; use crate::{ datasource::{CsvConfig, CsvTable, TableRef}, @@ -35,7 +36,7 @@ impl Catalog { pub fn add_memory_table( &mut self, table: &str, - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, ) -> Result<()> { let source = MemTable::try_create(schema, batches)?; @@ -44,7 +45,7 @@ impl Catalog { } /// add empty table - pub fn add_empty_table(&mut self, table: &str, schema: SchemaRef) -> Result<()> { + pub fn add_empty_table(&mut self, table: &str, schema: NaiveSchema) -> Result<()> { let source = EmptyTable::try_create(schema)?; self.tables.insert(table.to_string(), source); Ok(()) diff --git a/src/datasource/csv.rs b/src/datasource/csv.rs index f81e363..ba19fe6 100644 --- a/src/datasource/csv.rs +++ b/src/datasource/csv.rs @@ -11,8 +11,10 @@ use std::path::Path; use std::sync::Arc; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::csv; +use arrow::datatypes::Schema; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use super::TableSource; @@ -42,19 +44,20 @@ impl Default for CsvConfig { #[derive(Debug, Clone)] pub struct CsvTable { - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, } impl CsvTable { #[allow(unused, clippy::iter_next_loop)] pub fn try_create(filename: &str, csv_config: CsvConfig) -> Result { - let schema = Self::infer_schema_from_csv(filename, &csv_config)?; + let orig_schema = Self::infer_schema_from_csv(filename, &csv_config)?; + let schema = NaiveSchema::from_unqualified(&orig_schema); let mut file = File::open(env::current_dir()?.join(Path::new(filename)))?; let mut reader = csv::Reader::new( file, - Arc::clone(&schema), + Arc::new(orig_schema), csv_config.has_header, Some(csv_config.delimiter), csv_config.batch_size, @@ -71,7 +74,7 @@ impl CsvTable { Ok(Arc::new(Self { schema, batches })) } - fn infer_schema_from_csv(filename: &str, csv_config: 
&CsvConfig) -> Result { + fn infer_schema_from_csv(filename: &str, csv_config: &CsvConfig) -> Result { let mut file = File::open(env::current_dir()?.join(Path::new(filename)))?; let (schema, _) = arrow::csv::reader::infer_reader_schema( &mut file, @@ -79,13 +82,13 @@ impl CsvTable { csv_config.max_read_records, csv_config.has_header, )?; - Ok(Arc::new(schema)) + Ok(schema) } } impl TableSource for CsvTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, _projection: Option>) -> Result> { diff --git a/src/datasource/empty.rs b/src/datasource/empty.rs index e282bba..1069530 100644 --- a/src/datasource/empty.rs +++ b/src/datasource/empty.rs @@ -7,6 +7,7 @@ use super::TableSource; use crate::datasource::TableRef; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use std::sync::Arc; @@ -14,19 +15,19 @@ use std::sync::Arc; /// Empty Table with schema but no data #[derive(Debug, Clone)] pub struct EmptyTable { - schema: SchemaRef, + schema: NaiveSchema, } impl EmptyTable { #[allow(unused)] - pub fn try_create(schema: SchemaRef) -> Result { + pub fn try_create(schema: NaiveSchema) -> Result { Ok(Arc::new(Self { schema })) } } impl TableSource for EmptyTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, _projection: Option>) -> Result> { @@ -42,10 +43,11 @@ mod tests { #[test] fn test_empty_table() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), - ])); + ]); + let schema = NaiveSchema::from_qualified("t1", &schema); let table = EmptyTable::try_create(schema)?; let batches = table.scan(None)?; diff --git a/src/datasource/memory.rs b/src/datasource/memory.rs index f631d01..ded1077 100644 --- 
a/src/datasource/memory.rs +++ b/src/datasource/memory.rs @@ -9,24 +9,24 @@ use arrow::record_batch::RecordBatch; use std::sync::Arc; use super::{TableRef, TableSource}; -use crate::error::Result; +use crate::{error::Result, logical_plan::schema::NaiveSchema}; #[derive(Debug, Clone)] pub struct MemTable { - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, } impl MemTable { #[allow(unused)] - pub fn try_create(schema: SchemaRef, batches: Vec) -> Result { + pub fn try_create(schema: NaiveSchema, batches: Vec) -> Result { Ok(Arc::new(Self { schema, batches })) } } impl TableSource for MemTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, projection: Option>) -> Result> { @@ -47,6 +47,7 @@ mod tests { use super::MemTable; use crate::datasource::TableSource; use crate::error::Result; + use crate::logical_plan::schema::NaiveSchema; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -60,6 +61,7 @@ mod tests { Field::new("c", DataType::Int32, false), Field::new("d", DataType::Int32, true), ])); + let schema = NaiveSchema::from_qualified("t1", &schema); let batch = RecordBatch::try_new( schema.clone(), diff --git a/src/datasource/mod.rs b/src/datasource/mod.rs index 4840426..d3fa81b 100644 --- a/src/datasource/mod.rs +++ b/src/datasource/mod.rs @@ -12,12 +12,13 @@ use std::fmt::Debug; use std::sync::Arc; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; pub type TableRef = Arc; pub trait TableSource: Debug { - fn schema(&self) -> SchemaRef; + fn schema(&self) -> &NaiveSchema; // TODO(veeupup): return async stream record batch /// for scan diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index aa626b4..b0ceeff 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -13,6 +13,7 @@ use 
crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; use super::expression::Column; use super::plan::{Join, JoinType, Limit}; +use super::schema::NaiveSchema; use crate::error::{ErrorCode, Result}; #[derive(Clone)] @@ -30,7 +31,7 @@ impl DataFrame { .iter() .map(|expr| expr.data_field(&self.plan).unwrap()) .collect::>(); - let schema = Arc::new(Schema::new(fields)); + let schema = NaiveSchema::new(fields); Self { plan: LogicalPlan::Projection(Projection { input: Arc::new(self.plan), @@ -59,7 +60,7 @@ impl DataFrame { .map(|expr| expr.data_field(&self.plan).unwrap()) .collect::>(); group_fields.append(&mut aggr_fields); - let schema = Arc::new(Schema::new(group_fields)); + let schema = NaiveSchema::new(group_fields); Self { plan: LogicalPlan::Aggregate(Aggregate { input: Arc::new(self.plan), @@ -117,7 +118,7 @@ impl DataFrame { let right_schema = right.schema(); let right_fields = right_schema.fields().iter(); let fields = left_fields.chain(right_fields).cloned().collect(); - let join_schema = Arc::new(Schema::new(fields)); + let join_schema = NaiveSchema::new(fields); Ok(Self::new(LogicalPlan::Join(Join { left: Arc::new(self.plan.clone()), right: Arc::new(right.clone()), @@ -127,7 +128,7 @@ impl DataFrame { }))) } - pub fn schema(&self) -> SchemaRef { + pub fn schema(&self) -> &NaiveSchema { self.plan.schema() } diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index ebcf99d..c9caf04 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -17,6 +17,8 @@ use crate::error::Result; use crate::logical_plan::plan::LogicalPlan; +use super::schema::{self, NaiveField, NaiveSchema}; + #[derive(Clone, Debug)] pub enum LogicalExpr { /// An expression with a specific name. 
@@ -52,32 +54,30 @@ impl LogicalExpr { } /// TODO(veeupup): consider return Vec - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { match self { LogicalExpr::Alias(expr, alias) => { let field = expr.data_field(input)?; - Ok(Field::new( + Ok(NaiveField::new( + None, alias, field.data_type().clone(), field.is_nullable(), )) } LogicalExpr::Column(Column { name, .. }) => { - for field in input.schema().fields() { - if field.name() == name.as_str() { - return Ok(field.clone()); - } - } - Err(ErrorCode::NoSuchField) + input.schema().field_with_unqualified_name(name) } LogicalExpr::Literal(scalar_val) => Ok(scalar_val.data_field()), LogicalExpr::BinaryExpr(expr) => expr.data_field(input), - LogicalExpr::Not(expr) => Ok(Field::new( + LogicalExpr::Not(expr) => Ok(NaiveField::new( + None, format!("Not {}", expr.data_field(input)?.name()).as_str(), DataType::Boolean, true, )), - LogicalExpr::Cast { expr, data_type } => Ok(Field::new( + LogicalExpr::Cast { expr, data_type } => Ok(NaiveField::new( + None, expr.data_field(input)?.name(), data_type.clone(), true, @@ -136,14 +136,14 @@ macro_rules! 
build_array_from_option { } impl ScalarValue { - pub fn data_field(&self) -> Field { + pub fn data_field(&self) -> NaiveField { match self { - ScalarValue::Null => Field::new("Null", DataType::Null, true), - ScalarValue::Boolean(_) => Field::new("bool", DataType::Boolean, true), - ScalarValue::Float64(_) => Field::new("f64", DataType::Float64, true), - ScalarValue::Int64(_) => Field::new("i64", DataType::Int64, true), - ScalarValue::UInt64(_) => Field::new("u64", DataType::UInt64, true), - ScalarValue::Utf8(_) => Field::new("string", DataType::Utf8, true), + ScalarValue::Null => NaiveField::new(None, "Null", DataType::Null, true), + ScalarValue::Boolean(_) => NaiveField::new(None, "bool", DataType::Boolean, true), + ScalarValue::Float64(_) => NaiveField::new(None, "f64", DataType::Float64, true), + ScalarValue::Int64(_) => NaiveField::new(None, "i64", DataType::Int64, true), + ScalarValue::UInt64(_) => NaiveField::new(None, "u64", DataType::UInt64, true), + ScalarValue::Utf8(_) => NaiveField::new(None, "string", DataType::Utf8, true), } } @@ -173,7 +173,7 @@ pub struct BinaryExpr { } impl BinaryExpr { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { let left = self.left.data_field(input)?; let left = left.name(); let right = match &*self.right { @@ -188,67 +188,80 @@ impl BinaryExpr { _ => self.right.data_field(input)?.name().clone(), }; let field = match self.op { - Operator::Eq => Field::new( + Operator::Eq => NaiveField::new( + None, format!("{} = {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::NotEq => Field::new( + Operator::NotEq => NaiveField::new( + None, format!("{} != {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Lt => Field::new( + Operator::Lt => NaiveField::new( + None, format!("{} < {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::LtEq => Field::new( + Operator::LtEq => NaiveField::new( + None, format!("{} <= {}", left, 
right).as_str(), DataType::Boolean, true, ), - Operator::Gt => Field::new( + Operator::Gt => NaiveField::new( + None, format!("{} > {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::GtEq => Field::new( + Operator::GtEq => NaiveField::new( + None, format!("{} >= {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Plus => Field::new( + Operator::Plus => NaiveField::new( + None, format!("{} + {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Minus => Field::new( + Operator::Minus => NaiveField::new( + None, format!("{} - {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Multiply => Field::new( + Operator::Multiply => NaiveField::new( + None, format!("{} * {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Divide => Field::new( + Operator::Divide => NaiveField::new( + None, format!("{} / {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Modulo => Field::new( + Operator::Modulo => NaiveField::new( + None, format!("{} % {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::And => Field::new( + Operator::And => NaiveField::new( + None, format!("{} and {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Or => Field::new( + Operator::Or => NaiveField::new( + None, format!("{} or {}", left, right).as_str(), DataType::Boolean, true, @@ -298,21 +311,24 @@ pub struct ScalarFunction { } impl ScalarFunction { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { // TODO(veeupup): we should make scalar func more specific and should check if valid before creating them let field = self.args[0].data_field(input)?; let field = match self.fun { - ScalarFunc::Abs => Field::new( + ScalarFunc::Abs => NaiveField::new( + None, 
format!("abs({})", field.name()).as_str(), DataType::Int64, true, ), - ScalarFunc::Add => Field::new( + ScalarFunc::Add => NaiveField::new( + None, format!("add({})", field.name()).as_str(), DataType::Int64, true, ), - ScalarFunc::Sub => Field::new( + ScalarFunc::Sub => NaiveField::new( + None, format!("sub({})", field.name()).as_str(), DataType::Int64, true, @@ -339,30 +355,35 @@ pub struct AggregateFunction { } impl AggregateFunction { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { let dt = self.args.data_field(input)?; let field = match self.fun { - AggregateFunc::Count => Field::new( + AggregateFunc::Count => NaiveField::new( + None, format!("count({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Sum => Field::new( + AggregateFunc::Sum => NaiveField::new( + None, format!("sum({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Min => Field::new( + AggregateFunc::Min => NaiveField::new( + None, format!("min({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Max => Field::new( + AggregateFunc::Max => NaiveField::new( + None, format!("max({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Avg => Field::new( + AggregateFunc::Avg => NaiveField::new( + None, format!("avg({})", dt.name()).as_str(), dt.data_type().clone(), true, diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index 32e218e..cdcbbb1 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -9,6 +9,8 @@ use crate::logical_plan::expression::{Column, LogicalExpr}; use arrow::datatypes::SchemaRef; use std::sync::Arc; +use super::schema::NaiveSchema; + #[derive(Debug, Clone)] pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a @@ -40,12 +42,12 @@ pub enum LogicalPlan { } impl LogicalPlan { - pub fn schema(&self) -> SchemaRef { + pub fn schema(&self) -> &NaiveSchema { 
match self { - LogicalPlan::Projection(Projection { schema, .. }) => schema.clone(), + LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema.clone(), - LogicalPlan::Join(Join { schema, .. }) => schema.clone(), + LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, + LogicalPlan::Join(Join { schema, .. }) => schema, LogicalPlan::Limit(Limit { input, .. }) => input.schema(), LogicalPlan::TableScan(TableScan { source, .. }) => source.schema(), } @@ -70,7 +72,7 @@ pub struct Projection { /// The incoming logical plan pub input: Arc, /// The schema description of the output - pub schema: SchemaRef, + pub schema: NaiveSchema, } #[derive(Debug, Clone)] @@ -100,7 +102,7 @@ pub struct Aggregate { /// Aggregate expressions pub aggr_expr: Vec, /// The schema description of the aggregate output - pub schema: SchemaRef, + pub schema: NaiveSchema, } #[derive(Debug, Clone, PartialEq)] @@ -122,7 +124,7 @@ pub struct Join { /// Join type pub join_type: JoinType, /// The output schema, containing fields from the left and right inputs - pub schema: SchemaRef, + pub schema: NaiveSchema, } #[derive(Debug, Clone)] diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs index a30acc5..af5b5e0 100644 --- a/src/logical_plan/schema.rs +++ b/src/logical_plan/schema.rs @@ -1,8 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + /* * @Author: Veeupup * @Date: 2022-05-18 13:45:10 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-18 14:10:09 + * @Last Modified time: 2022-05-18 14:54:46 * * Arrow Field does not have table/relation name as its proroties * So we need a Schema to define inner schema with table name @@ -13,12 +30,13 @@ use std::ptr::null; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::datatypes::{Field, Schema}; use crate::error::ErrorCode; use crate::error::Result; +#[derive(Debug, Clone)] pub struct NaiveSchema { pub fields: Vec, } @@ -46,6 +64,19 @@ impl NaiveSchema { ) } + pub fn from_unqualified(schema: &Schema) -> Self { + Self::new( + schema + .fields() + .iter() + .map(|field| NaiveField { + field: field.clone(), + qualifier: None, + }) + .collect(), + ) + } + /// join two schema pub fn join(&self, schema: &NaiveSchema) -> Self { let mut fields = self.fields.clone(); @@ -114,7 +145,35 @@ impl NaiveSchema { } } -/// DFField wraps an Arrow field and adds an optional qualifier +impl Into for NaiveSchema { + /// Convert a schema into a DFSchema + fn into(self) -> Schema { + Schema::new( + self.fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + Field::new( + f.qualified_name().as_str(), + f.data_type().to_owned(), + f.is_nullable(), + ) + } else { + f.field + } + }) + .collect(), + ) + } +} + +impl Into for NaiveSchema { + fn into(self) -> SchemaRef { + SchemaRef::new(self.into()) + } +} + +/// NaiveField wraps an Arrow field and adds an optional qualifier #[derive(Debug, Clone, 
PartialEq, Eq)] pub struct NaiveField { /// Optional qualifier (usually a table or relation name) @@ -149,17 +208,17 @@ impl NaiveField { self.field.name() } - /// Returns an immutable reference to the `DFField`'s data-type + /// Returns an immutable reference to the `NaiveField`'s data-type pub fn data_type(&self) -> &DataType { &self.field.data_type() } - /// Indicates whether this `DFField` supports null values + /// Indicates whether this `NaiveField` supports null values pub fn is_nullable(&self) -> bool { self.field.is_nullable() } - /// Returns a reference to the `DFField`'s qualified name + /// Returns a reference to the `NaiveField`'s qualified name pub fn qualified_name(&self) -> String { if let Some(relation_name) = &self.qualifier { format!("{}.{}", relation_name, self.field.name()) diff --git a/src/main.rs b/src/main.rs index 631afd2..9d8f152 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,18 +5,18 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - // db.create_csv_table("t1", "data/test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; - // let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; + let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; - // print_result(&ret)?; + print_result(&ret)?; // Ok(()) - db.create_csv_table("employee", "data/employee.csv")?; - db.create_csv_table("rank", "data/rank.csv")?; + // db.create_csv_table("employee", "data/employee.csv")?; + // db.create_csv_table("rank", "data/rank.csv")?; - let ret = - db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + // let ret = + // db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); // print_result(&ret?); Ok(()) diff --git a/src/physical_plan/limit.rs b/src/physical_plan/limit.rs index 6ad6db7..e41fa48 100644 --- a/src/physical_plan/limit.rs +++ b/src/physical_plan/limit.rs @@ -2,11 +2,12 @@ * @Author: 
Veeupup * @Date: 2022-05-17 11:27:29 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-17 11:57:13 + * @Last Modified time: 2022-05-18 14:45:03 */ use super::{PhysicalPlan, PhysicalPlanRef}; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use std::sync::Arc; @@ -24,7 +25,7 @@ impl PhysicalLimitPlan { } impl PhysicalPlan for PhysicalLimitPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.input.schema() } diff --git a/src/physical_plan/plan.rs b/src/physical_plan/plan.rs index aa916ac..2af3852 100644 --- a/src/physical_plan/plan.rs +++ b/src/physical_plan/plan.rs @@ -9,10 +9,10 @@ use std::sync::Arc; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use crate::error::Result; +use crate::{error::Result, logical_plan::schema::NaiveSchema}; pub trait PhysicalPlan: Debug { - fn schema(&self) -> SchemaRef; + fn schema(&self) -> &NaiveSchema; // TODO(veeupup): return by using streaming mode fn execute(&self) -> Result>; diff --git a/src/physical_plan/projection.rs b/src/physical_plan/projection.rs index 6e3a728..5859c07 100644 --- a/src/physical_plan/projection.rs +++ b/src/physical_plan/projection.rs @@ -9,20 +9,21 @@ use std::sync::Arc; use super::{expression::PhysicalExpr, plan::PhysicalPlan}; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalPlanRef; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; #[derive(Debug, Clone)] pub struct ProjectionPlan { input: PhysicalPlanRef, - schema: SchemaRef, + schema: NaiveSchema, expr: Vec, } impl ProjectionPlan { pub fn create( input: PhysicalPlanRef, - schema: SchemaRef, + schema: NaiveSchema, expr: Vec, ) -> PhysicalPlanRef { Arc::new(Self { @@ -34,8 +35,8 @@ impl ProjectionPlan { } impl PhysicalPlan for ProjectionPlan { - fn schema(&self) -> SchemaRef { - 
self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn execute(&self) -> Result> { @@ -54,7 +55,8 @@ impl PhysicalPlan for ProjectionPlan { .map(|column| column.clone().into_array()) .collect::>(); // TODO(veeupup): remove unwrap - RecordBatch::try_new(self.schema.clone(), columns).unwrap() + // let projection_schema = self.schema.into(); + RecordBatch::try_new(self.schema.clone().into(), columns).unwrap() }) .collect::>(); Ok(batches) diff --git a/src/physical_plan/scan.rs b/src/physical_plan/scan.rs index 2b68600..f7d7250 100644 --- a/src/physical_plan/scan.rs +++ b/src/physical_plan/scan.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use crate::datasource::TableRef; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use crate::physical_plan::PhysicalPlan; @@ -26,7 +27,7 @@ impl ScanPlan { } impl PhysicalPlan for ScanPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.source.schema() } diff --git a/src/physical_plan/selection.rs b/src/physical_plan/selection.rs index d7f7766..6b3a865 100644 --- a/src/physical_plan/selection.rs +++ b/src/physical_plan/selection.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use super::{PhysicalExpr, PhysicalExprRef, PhysicalPlan, PhysicalPlanRef}; +use crate::logical_plan::schema::NaiveSchema; use crate::Result; use arrow::array::{ Float64Array, Float64Builder, Int64Array, Int64Builder, StringArray, StringBuilder, @@ -50,7 +51,7 @@ macro_rules! 
build_array_by_predicate { } impl PhysicalPlan for SelectionPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.input.schema() } @@ -98,7 +99,8 @@ impl PhysicalPlan for SelectionPlan { }; columns.push(column); } - let record_batch = RecordBatch::try_new(self.schema(), columns)?; + let record_batch = + RecordBatch::try_new(Arc::new(self.schema().clone().into()), columns)?; batches.push(record_batch); } Ok(batches) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 9d82d43..4ff3209 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use crate::logical_plan::schema::NaiveSchema; use crate::physical_plan::PhysicalBinaryExpr; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalLimitPlan; @@ -47,7 +48,7 @@ impl QueryPlanner { .iter() .map(|expr| expr.data_field(proj.input.as_ref()).unwrap()) .collect::>(); - let proj_schema = Arc::new(Schema::new(fields)); + let proj_schema = NaiveSchema::new(fields); Ok(ProjectionPlan::create(input, proj_schema, proj_expr)) } LogicalPlan::Limit(limit) => { diff --git a/src/sql/planner.rs b/src/sql/planner.rs index efd9316..1b3bfa6 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -20,6 +20,7 @@ use crate::error::ErrorCode; use crate::logical_plan::expression::{BinaryExpr, Column, LogicalExpr, Operator, ScalarValue}; use crate::logical_plan::literal::lit; use crate::logical_plan::plan::{JoinType, TableScan}; +use crate::logical_plan::schema::NaiveSchema; use crate::{ catalog::Catalog, error::Result, @@ -429,7 +430,7 @@ fn extract_possible_join_keys(expr: &LogicalExpr, accum: &mut Vec<(Column, Colum } } -fn find_column_in_schema(schema: &SchemaRef, col: &Column) -> Result<()> { +fn find_column_in_schema(schema: &NaiveSchema, col: &Column) -> Result<()> { let fields = schema.fields(); for field in fields { if field.name() == &col.name { From b3bcf3cddbdf2ad9854e0d5686783676ad643196 Mon Sep 17 
00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 15:30:24 +0800 Subject: [PATCH 09/13] save Signed-off-by: Veeupup --- src/logical_plan/dataframe.rs | 2 ++ src/logical_plan/expression.rs | 7 +++++-- src/logical_plan/schema.rs | 2 +- src/main.rs | 8 ++++---- src/sql/planner.rs | 29 ++++------------------------- 5 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index b0ceeff..d2c0560 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -27,6 +27,8 @@ impl DataFrame { } pub fn project(self, exprs: Vec) -> Self { + // TODO(veeupup): Ambiguous reference of field + println!("plan: {:?}", &self.plan.schema()); let fields = exprs .iter() .map(|expr| expr.data_field(&self.plan).unwrap()) diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index c9caf04..f3b1083 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -65,8 +65,11 @@ impl LogicalExpr { field.is_nullable(), )) } - LogicalExpr::Column(Column { name, .. 
}) => { - input.schema().field_with_unqualified_name(name) + LogicalExpr::Column(Column { name, table }) => { + match table { + Some(table) => input.schema().field_with_qualified_name(table, name), + None => input.schema().field_with_unqualified_name(name) + } } LogicalExpr::Literal(scalar_val) => Ok(scalar_val.data_field()), LogicalExpr::BinaryExpr(expr) => expr.data_field(input), diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs index af5b5e0..a92a54f 100644 --- a/src/logical_plan/schema.rs +++ b/src/logical_plan/schema.rs @@ -19,7 +19,7 @@ * @Author: Veeupup * @Date: 2022-05-18 13:45:10 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-18 14:54:46 + * @Last Modified time: 2022-05-18 15:28:57 * * Arrow Field does not have table/relation name as its proroties * So we need a Schema to define inner schema with table name diff --git a/src/main.rs b/src/main.rs index 9d8f152..92b8ec3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,11 +12,11 @@ fn main() -> Result<()> { print_result(&ret)?; // Ok(()) - // db.create_csv_table("employee", "data/employee.csv")?; - // db.create_csv_table("rank", "data/rank.csv")?; + db.create_csv_table("employee", "data/employee.csv")?; + db.create_csv_table("rank", "data/rank.csv")?; - // let ret = - // db.run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + let ret = + db.run_sql("select name, rank from employee innner join rank on employee.id = rank.id"); // print_result(&ret?); Ok(()) diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 1b3bfa6..3928bc3 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -210,17 +210,6 @@ impl<'a> SQLPlanner<'a> { plans: Vec, ) -> Result { // TODO(veeupup): handle joins - // let df = DataFrame { - // plan: plans[0].clone(), - // }; - // match selection { - // Some(predicate_expr) => { - // let filter_expr = self.sql_to_expr(&predicate_expr)?; - // let df = df.filter(filter_expr); - // Ok(df) - // } - // None => Ok(df), - // 
} match selection { Some(expr) => { let mut fields = vec![]; @@ -240,12 +229,12 @@ impl<'a> SQLPlanner<'a> { let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if find_column_in_schema(&left_schema, l).is_ok() - && find_column_in_schema(&right_schema, r).is_ok() + if left_schema.field_with_unqualified_name(l.name.as_str()).is_ok() + && right_schema.field_with_unqualified_name(r.name.as_str()).is_ok() { join_keys.push((l.clone(), r.clone())); - } else if find_column_in_schema(&left_schema, r).is_ok() - && find_column_in_schema(&right_schema, l).is_ok() + } else if left_schema.field_with_unqualified_name(r.name.as_str()).is_ok() + && right_schema.field_with_unqualified_name(l.name.as_str()).is_ok() { join_keys.push((r.clone(), l.clone())); } @@ -430,16 +419,6 @@ fn extract_possible_join_keys(expr: &LogicalExpr, accum: &mut Vec<(Column, Colum } } -fn find_column_in_schema(schema: &NaiveSchema, col: &Column) -> Result<()> { - let fields = schema.fields(); - for field in fields { - if field.name() == &col.name { - return Ok(()); - } - } - Err(ErrorCode::NoSuchField) -} - /// Remove join expressions from a filter expression fn remove_join_expressions( expr: &LogicalExpr, From c9597f8b82ec1dad10c970fd15aee7972613e083 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 15:59:43 +0800 Subject: [PATCH 10/13] save Signed-off-by: Veeupup --- src/logical_plan/dataframe.rs | 30 ++++-------------------------- src/logical_plan/schema.rs | 3 ++- src/sql/planner.rs | 5 ++++- 3 files changed, 10 insertions(+), 28 deletions(-) diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index d2c0560..e5947c7 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -28,7 +28,6 @@ impl DataFrame { pub fn project(self, exprs: Vec) -> Self { // TODO(veeupup): Ambiguous reference of field - println!("plan: {:?}", &self.plan.schema()); let fields = exprs .iter() .map(|expr| 
expr.data_field(&self.plan).unwrap()) @@ -94,33 +93,12 @@ impl DataFrame { )); } - // TODO(veeupup): we need judge which side os conditions on - // let (left_keys, right_keys) = - // join_keys - // .0 - // .into_iter() - // .zip(join_keys.1.into_iter()) - // .map(|(l, r)| { - // match (&l.table, &r.table) { - // (Some(l), Some(r)) => { - // (Ok(l), Ok(r)) - // }, - // _ => unimplemented!() - // } - // }).collect::>(); - // let left_keys = left_keys.into_iter().collect::>>()?; - // let right_keys = right_keys.into_iter().collect::>>()?; - let left_keys = join_keys.0.clone(); - let right_keys = join_keys.1.clone(); - + let (left_keys, right_keys) = join_keys; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - // join schema + let left_schema = self.plan.schema(); - let left_fields = left_schema.fields().iter(); - let right_schema = right.schema(); - let right_fields = right_schema.fields().iter(); - let fields = left_fields.chain(right_fields).cloned().collect(); - let join_schema = NaiveSchema::new(fields); + let join_schema = left_schema.join(right.schema()); + Ok(Self::new(LogicalPlan::Join(Join { left: Arc::new(self.plan.clone()), right: Arc::new(right.clone()), diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs index a92a54f..38b055e 100644 --- a/src/logical_plan/schema.rs +++ b/src/logical_plan/schema.rs @@ -19,7 +19,7 @@ * @Author: Veeupup * @Date: 2022-05-18 13:45:10 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-18 15:28:57 + * @Last Modified time: 2022-05-18 15:57:32 * * Arrow Field does not have table/relation name as its proroties * So we need a Schema to define inner schema with table name @@ -28,6 +28,7 @@ * */ +use std::fmt::Display; use std::ptr::null; use arrow::datatypes::{DataType, SchemaRef}; diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 3928bc3..76a0bfc 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -144,7 +144,10 @@ impl<'a> SQLPlanner<'a> { let mut 
filters = vec![]; extract_join_keys(&expr, &mut keys, &mut filters); - let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); + let left_keys = + keys.iter().map(|pair| pair.0.clone()).collect(); + let right_keys = + keys.iter().map(|pair| pair.1.clone()).collect(); if filters.is_empty() { let join = From 68b2c5a2f54df02063b428ccfe63072ad1abc1cd Mon Sep 17 00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 17:31:46 +0800 Subject: [PATCH 11/13] Join Nested Loop Join finished Signed-off-by: Veeupup --- data/employee.csv | 6 +- src/catalog.rs | 2 +- src/datasource/csv.rs | 2 +- src/datasource/empty.rs | 2 +- src/datasource/memory.rs | 2 +- src/datasource/mod.rs | 2 +- src/logical_plan/dataframe.rs | 2 +- src/logical_plan/expression.rs | 16 ++- src/logical_plan/plan.rs | 4 +- src/logical_plan/schema.rs | 28 ++--- src/main.rs | 7 +- src/physical_plan/limit.rs | 2 +- src/physical_plan/mod.rs | 2 + src/physical_plan/nested_loop_join.rs | 147 ++++++++++++++++++++++++++ src/physical_plan/plan.rs | 2 +- src/physical_plan/projection.rs | 2 +- src/physical_plan/scan.rs | 2 +- src/physical_plan/selection.rs | 2 +- src/planner/mod.rs | 17 ++- src/sql/planner.rs | 26 +++-- 20 files changed, 220 insertions(+), 55 deletions(-) create mode 100644 src/physical_plan/nested_loop_join.rs diff --git a/data/employee.csv b/data/employee.csv index 5acf7b0..119e883 100644 --- a/data/employee.csv +++ b/data/employee.csv @@ -1,6 +1,6 @@ id,name,department_id,rank -1,vee,1,0 -2,lynne,1,0 -3,Alex,2,1 +1,vee,1,1 +2,lynne,1,0 +3,Alex,2,0 4,jack,2,1 5,mike,3,2 \ No newline at end of file diff --git a/src/catalog.rs b/src/catalog.rs index c30f690..bdfa74f 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; -use arrow::datatypes::SchemaRef; + use crate::datasource::{EmptyTable, MemTable}; use crate::error::ErrorCode; diff --git a/src/datasource/csv.rs b/src/datasource/csv.rs index ba19fe6..9e2a8a7 100644 --- a/src/datasource/csv.rs +++ 
b/src/datasource/csv.rs @@ -15,7 +15,7 @@ use crate::logical_plan::schema::NaiveSchema; use arrow::csv; use arrow::datatypes::Schema; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; use super::TableSource; use crate::datasource::TableRef; diff --git a/src/datasource/empty.rs b/src/datasource/empty.rs index 1069530..7106d87 100644 --- a/src/datasource/empty.rs +++ b/src/datasource/empty.rs @@ -8,7 +8,7 @@ use super::TableSource; use crate::datasource::TableRef; use crate::error::Result; use crate::logical_plan::schema::NaiveSchema; -use arrow::datatypes::SchemaRef; + use arrow::record_batch::RecordBatch; use std::sync::Arc; diff --git a/src/datasource/memory.rs b/src/datasource/memory.rs index ded1077..686ff85 100644 --- a/src/datasource/memory.rs +++ b/src/datasource/memory.rs @@ -4,7 +4,7 @@ * @Email: code@tanweime.com */ -use arrow::datatypes::SchemaRef; + use arrow::record_batch::RecordBatch; use std::sync::Arc; diff --git a/src/datasource/mod.rs b/src/datasource/mod.rs index d3fa81b..8dbcfd5 100644 --- a/src/datasource/mod.rs +++ b/src/datasource/mod.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use crate::error::Result; use crate::logical_plan::schema::NaiveSchema; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; pub type TableRef = Arc; diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index e5947c7..3c66362 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -6,7 +6,7 @@ use std::sync::Arc; -use arrow::datatypes::{Schema, SchemaRef}; + use crate::logical_plan::expression::LogicalExpr; use crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index f3b1083..aeebf5c 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -9,15 +9,15 @@ use std::iter::repeat; use 
arrow::array::StringArray; use arrow::array::{new_null_array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType}; use std::sync::Arc; -use crate::error::ErrorCode; + use crate::error::Result; use crate::logical_plan::plan::LogicalPlan; -use super::schema::{self, NaiveField, NaiveSchema}; +use super::schema::{NaiveField}; #[derive(Clone, Debug)] pub enum LogicalExpr { @@ -65,12 +65,10 @@ impl LogicalExpr { field.is_nullable(), )) } - LogicalExpr::Column(Column { name, table }) => { - match table { - Some(table) => input.schema().field_with_qualified_name(table, name), - None => input.schema().field_with_unqualified_name(name) - } - } + LogicalExpr::Column(Column { name, table }) => match table { + Some(table) => input.schema().field_with_qualified_name(table, name), + None => input.schema().field_with_unqualified_name(name), + }, LogicalExpr::Literal(scalar_val) => Ok(scalar_val.data_field()), LogicalExpr::BinaryExpr(expr) => expr.data_field(input), LogicalExpr::Not(expr) => Ok(NaiveField::new( diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index cdcbbb1..fcd7d35 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -6,7 +6,7 @@ use crate::datasource::TableRef; use crate::logical_plan::expression::{Column, LogicalExpr}; -use arrow::datatypes::SchemaRef; + use std::sync::Arc; use super::schema::NaiveSchema; @@ -105,7 +105,7 @@ pub struct Aggregate { pub schema: NaiveSchema, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum JoinType { Inner, Left, diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs index 38b055e..18ff7d2 100644 --- a/src/logical_plan/schema.rs +++ b/src/logical_plan/schema.rs @@ -19,7 +19,7 @@ * @Author: Veeupup * @Date: 2022-05-18 13:45:10 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-18 15:57:32 + * @Last Modified time: 2022-05-18 17:30:21 * * Arrow Field 
does not have table/relation name as its proroties * So we need a Schema to define inner schema with table name @@ -28,8 +28,8 @@ * */ -use std::fmt::Display; -use std::ptr::null; + + use arrow::datatypes::{DataType, SchemaRef}; use arrow::datatypes::{Field, Schema}; @@ -119,11 +119,12 @@ impl NaiveSchema { .collect::>(); match matches.len() { 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), - 1 => Ok(matches[0].to_owned()), - _ => Err(ErrorCode::PlanError(format!( - "Ambiguous reference to field named '{}'", - name - ))), + _ => Ok(matches[0].to_owned()), + // TODO(veeupup): multi same name, and we need to return Error + // _ => Err(ErrorCode::PlanError(format!( + // "Ambiguous reference to field named '{}'", + // name + // ))), } } @@ -137,11 +138,12 @@ impl NaiveSchema { .collect::>(); match matches.len() { 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), - 1 => Ok(matches[0].to_owned()), - _ => Err(ErrorCode::PlanError(format!( - "Ambiguous reference to field named '{}'", - name - ))), + _ => Ok(matches[0].to_owned()), + // TODO(veeupup): multi same name, and we need to return Error + // _ => Err(ErrorCode::PlanError(format!( + // "Ambiguous reference to field named '{}'", + // name + // ))), } } } diff --git a/src/main.rs b/src/main.rs index 92b8ec3..196e11f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,9 +15,10 @@ fn main() -> Result<()> { db.create_csv_table("employee", "data/employee.csv")?; db.create_csv_table("rank", "data/rank.csv")?; - let ret = - db.run_sql("select name, rank from employee innner join rank on employee.id = rank.id"); + let ret = db.run_sql( + "select id, name, rank_name from employee innner join rank on employee.rank = rank.id", + )?; - // print_result(&ret?); + print_result(&ret); Ok(()) } diff --git a/src/physical_plan/limit.rs b/src/physical_plan/limit.rs index e41fa48..893ad87 100644 --- a/src/physical_plan/limit.rs +++ b/src/physical_plan/limit.rs @@ -8,7 +8,7 @@ use 
super::{PhysicalPlan, PhysicalPlanRef}; use crate::error::Result; use crate::logical_plan::schema::NaiveSchema; -use arrow::datatypes::SchemaRef; + use arrow::record_batch::RecordBatch; use std::sync::Arc; diff --git a/src/physical_plan/mod.rs b/src/physical_plan/mod.rs index 5a27a6c..ac0c537 100644 --- a/src/physical_plan/mod.rs +++ b/src/physical_plan/mod.rs @@ -8,12 +8,14 @@ mod expression; mod plan; mod limit; +mod nested_loop_join; mod projection; mod scan; mod selection; pub use expression::*; pub use limit::*; +pub use nested_loop_join::*; pub use plan::*; pub use projection::*; pub use scan::*; diff --git a/src/physical_plan/nested_loop_join.rs b/src/physical_plan/nested_loop_join.rs new file mode 100644 index 0000000..276bcd2 --- /dev/null +++ b/src/physical_plan/nested_loop_join.rs @@ -0,0 +1,147 @@ +/* + * @Author: Veeupup + * @Date: 2022-05-18 16:00:13 + * @Last Modified by: Veeupup + * @Last Modified time: 2022-05-18 17:28:58 + */ +use super::PhysicalPlan; +use super::PhysicalPlanRef; +use crate::error::ErrorCode; +use crate::logical_plan::expression::Column; +use crate::logical_plan::plan::JoinType; +use crate::logical_plan::schema::NaiveSchema; +use crate::physical_plan::ColumnExpr; +use crate::physical_plan::PhysicalExpr; + +use crate::Result; +use std::sync::Arc; + +use arrow::array::Array; +use arrow::array::Int64Builder; +use arrow::array::PrimitiveArray; +use arrow::compute; +use arrow::datatypes::DataType; +use arrow::datatypes::Int64Type; +use arrow::record_batch::RecordBatch; + +#[derive(Debug, Clone)] +pub struct NestedLoopJoin { + left: PhysicalPlanRef, + right: PhysicalPlanRef, + on: Vec<(Column, Column)>, + join_type: JoinType, + schema: NaiveSchema, +} + +impl NestedLoopJoin { + pub fn new( + left: PhysicalPlanRef, + right: PhysicalPlanRef, + on: Vec<(Column, Column)>, + join_type: JoinType, + schema: NaiveSchema, + ) -> PhysicalPlanRef { + Arc::new(Self { + left, + right, + on, + join_type, + schema, + }) + } +} + +impl PhysicalPlan for 
NestedLoopJoin { + fn schema(&self) -> &NaiveSchema { + &self.schema + } + + fn execute(&self) -> Result> { + let outer_table = self.left.execute()?; + let inner_table = self.right.execute()?; + + let mut batches: Vec = vec![]; + // TODO(veeupup): support multi on conditions + for (left_col, right_col) in &self.on { + let left_col = ColumnExpr::try_create(Some(left_col.name.clone()), None)?; + let right_col = ColumnExpr::try_create(Some(right_col.name.clone()), None)?; + + for outer in &outer_table { + let left_col = left_col.evaluate(outer)?.into_array(); + + let dt = left_col.data_type(); + for inner in &inner_table { + let right_col = right_col.evaluate(inner)?.into_array(); + + // check if ok + if left_col.data_type() != right_col.data_type() { + return Err(ErrorCode::PlanError(format!( + "Join on left and right data type should be same: left: {:?}, right: {:?}", + left_col.data_type(), + right_col.data_type() + ))); + } + + let mut outer_pos = Int64Builder::new(left_col.len()); + let mut inner_pos = Int64Builder::new(right_col.len()); + match dt { + DataType::Int64 => { + let left_col = left_col + .as_any() + .downcast_ref::>() + .unwrap(); + let right_col = right_col + .as_any() + .downcast_ref::>() + .unwrap(); + + for (x_pos, x) in left_col.iter().enumerate() { + for (y_pos, y) in right_col.iter().enumerate() { + match (x, y) { + (Some(x), Some(y)) => { + if x == y { + // equal and we should + outer_pos.append_value(x_pos as i64); + inner_pos.append_value(y_pos as i64); + } + } + _ => {} + } + } + } + } + DataType::UInt64 => todo!(), + DataType::Float64 => todo!(), + DataType::Utf8 => todo!(), + _ => unimplemented!(), + } + let mut columns = vec![]; + + let outer_pos = outer_pos.finish(); + let inner_pos = inner_pos.finish(); + + // add left columns + for i in 0..self.left.schema().fields().len() { + let array = outer.column(i); + columns.push(compute::take(array.as_ref(), &outer_pos, None)?); + } + + // add right columns + for i in 
0..self.right.schema().fields().len() { + let array = inner.column(i); + columns.push(compute::take(array.as_ref(), &inner_pos, None)?); + } + + let batch = RecordBatch::try_new(self.schema.clone().into(), columns)?; + batches.push(batch); + } + } + } + + return Ok(batches); + } + + fn children(&self) -> Result> { + Ok(vec![self.left.clone(), self.right.clone()]) + } +} diff --git a/src/physical_plan/plan.rs b/src/physical_plan/plan.rs index 2af3852..0bbfbf0 100644 --- a/src/physical_plan/plan.rs +++ b/src/physical_plan/plan.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use std::sync::Arc; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; use crate::{error::Result, logical_plan::schema::NaiveSchema}; diff --git a/src/physical_plan/projection.rs b/src/physical_plan/projection.rs index 5859c07..3b51075 100644 --- a/src/physical_plan/projection.rs +++ b/src/physical_plan/projection.rs @@ -12,7 +12,7 @@ use crate::error::Result; use crate::logical_plan::schema::NaiveSchema; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalPlanRef; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; #[derive(Debug, Clone)] pub struct ProjectionPlan { input: PhysicalPlanRef, diff --git a/src/physical_plan/scan.rs b/src/physical_plan/scan.rs index f7d7250..9965c18 100644 --- a/src/physical_plan/scan.rs +++ b/src/physical_plan/scan.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::datasource::TableRef; use crate::error::Result; use crate::logical_plan::schema::NaiveSchema; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; use crate::physical_plan::PhysicalPlan; use crate::physical_plan::PhysicalPlanRef; diff --git a/src/physical_plan/selection.rs b/src/physical_plan/selection.rs index 6b3a865..a624691 100644 --- a/src/physical_plan/selection.rs +++ b/src/physical_plan/selection.rs @@ -16,7 +16,7 @@ use 
arrow::array::{ use arrow::record_batch::RecordBatch; use arrow::{ array::{Array, BooleanArray, BooleanBuilder}, - datatypes::{DataType, SchemaRef}, + datatypes::{DataType}, }; #[derive(Debug)] diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 4ff3209..c13671c 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -7,11 +7,12 @@ * */ -use std::sync::Arc; -use arrow::datatypes::Schema; + + use crate::logical_plan::schema::NaiveSchema; +use crate::physical_plan::NestedLoopJoin; use crate::physical_plan::PhysicalBinaryExpr; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalLimitPlan; @@ -55,8 +56,16 @@ impl QueryPlanner { let plan = Self::create_physical_plan(&limit.input)?; Ok(PhysicalLimitPlan::create(plan, limit.n)) } - LogicalPlan::Join(_join) => { - todo!() + LogicalPlan::Join(join) => { + let left = Self::create_physical_plan(&join.left)?; + let right = Self::create_physical_plan(&join.right)?; + Ok(NestedLoopJoin::new( + left, + right, + join.on.clone(), + join.join_type, + join.schema.clone(), + )) } LogicalPlan::Filter(filter) => { let predicate = Self::create_physical_expression(&filter.predicate, plan)?; diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 76a0bfc..52468e9 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -9,7 +9,7 @@ use std::collections::HashSet; -use arrow::datatypes::SchemaRef; + use sqlparser::ast::{ BinaryOperator, Expr, Join, JoinConstraint, JoinOperator, OrderByExpr, SetExpr, Statement, TableWithJoins, @@ -20,7 +20,7 @@ use crate::error::ErrorCode; use crate::logical_plan::expression::{BinaryExpr, Column, LogicalExpr, Operator, ScalarValue}; use crate::logical_plan::literal::lit; use crate::logical_plan::plan::{JoinType, TableScan}; -use crate::logical_plan::schema::NaiveSchema; + use crate::{ catalog::Catalog, error::Result, @@ -144,10 +144,8 @@ impl<'a> SQLPlanner<'a> { let mut filters = vec![]; extract_join_keys(&expr, &mut keys, &mut filters); - let left_keys = - 
keys.iter().map(|pair| pair.0.clone()).collect(); - let right_keys = - keys.iter().map(|pair| pair.1.clone()).collect(); + let left_keys = keys.iter().map(|pair| pair.0.clone()).collect(); + let right_keys = keys.iter().map(|pair| pair.1.clone()).collect(); if filters.is_empty() { let join = @@ -232,12 +230,20 @@ impl<'a> SQLPlanner<'a> { let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if left_schema.field_with_unqualified_name(l.name.as_str()).is_ok() - && right_schema.field_with_unqualified_name(r.name.as_str()).is_ok() + if left_schema + .field_with_unqualified_name(l.name.as_str()) + .is_ok() + && right_schema + .field_with_unqualified_name(r.name.as_str()) + .is_ok() { join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_with_unqualified_name(r.name.as_str()).is_ok() - && right_schema.field_with_unqualified_name(l.name.as_str()).is_ok() + } else if left_schema + .field_with_unqualified_name(r.name.as_str()) + .is_ok() + && right_schema + .field_with_unqualified_name(l.name.as_str()) + .is_ok() { join_keys.push((r.clone(), l.clone())); } From aa5b2d49fc2648894980fc8261148271b75c6b21 Mon Sep 17 00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 17:37:15 +0800 Subject: [PATCH 12/13] fix test Signed-off-by: Veeupup --- src/datasource/empty.rs | 2 +- src/datasource/memory.rs | 6 +++--- src/logical_plan/dataframe.rs | 18 +++++++++--------- src/logical_plan/plan.rs | 4 ++-- src/physical_plan/projection.rs | 5 ++--- src/physical_plan/selection.rs | 5 ++--- src/planner/mod.rs | 1 + 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/datasource/empty.rs b/src/datasource/empty.rs index 7106d87..30aff83 100644 --- a/src/datasource/empty.rs +++ b/src/datasource/empty.rs @@ -39,7 +39,7 @@ impl TableSource for EmptyTable { mod tests { use super::*; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + #[test] fn test_empty_table() -> Result<()> { diff --git 
a/src/datasource/memory.rs b/src/datasource/memory.rs index 686ff85..a1a6060 100644 --- a/src/datasource/memory.rs +++ b/src/datasource/memory.rs @@ -64,7 +64,7 @@ mod tests { let schema = NaiveSchema::from_qualified("t1", &schema); let batch = RecordBatch::try_new( - schema.clone(), + schema.clone().into(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), @@ -80,8 +80,8 @@ mod tests { let batch2 = &batches[0]; assert_eq!(2, batch2.schema().fields().len()); - assert_eq!("c", batch2.schema().field(0).name()); - assert_eq!("b", batch2.schema().field(1).name()); + assert_eq!("t1.c", batch2.schema().field(0).name()); + assert_eq!("t1.b", batch2.schema().field(1).name()); assert_eq!(2, batch2.num_columns()); Ok(()) diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index 3c66362..d085768 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -121,21 +121,21 @@ impl DataFrame { mod tests { use super::*; - use crate::catalog::Catalog; + use crate::{catalog::Catalog, logical_plan::schema::NaiveField}; use crate::error::Result; use crate::logical_plan::expression::*; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::datatypes::{DataType}; #[test] fn create_logical_plan() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("state", DataType::Int64, true), - Field::new("id", DataType::Int64, true), - Field::new("first_name", DataType::Utf8, true), - Field::new("last_name", DataType::Utf8, true), - Field::new("salary", DataType::Int64, true), - ])); + let schema = NaiveSchema::new(vec![ + NaiveField::new(None, "state", DataType::Int64, true), + NaiveField::new(None, "id", DataType::Int64, true), + NaiveField::new(None, "first_name", DataType::Utf8, true), + NaiveField::new(None, "last_name", DataType::Utf8, true), + NaiveField::new(None, "salary", DataType::Int64, true), + ]); let mut catalog = Catalog::default(); catalog.add_empty_table("empty", 
schema)?; diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index fcd7d35..744d120 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -144,12 +144,12 @@ mod tests { use crate::error::Result; use crate::logical_plan::expression::*; - use arrow::datatypes::{DataType, Field, Schema}; + /// Create LogicalPlan #[test] fn create_logical_plan() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let schema = NaiveSchema::empty(); let source = EmptyTable::try_create(schema)?; let scan = LogicalPlan::TableScan(TableScan { diff --git a/src/physical_plan/projection.rs b/src/physical_plan/projection.rs index 3b51075..fd5b2da 100644 --- a/src/physical_plan/projection.rs +++ b/src/physical_plan/projection.rs @@ -75,16 +75,15 @@ mod tests { use crate::physical_plan::scan::ScanPlan; use arrow::{ array::{Array, ArrayRef, Int64Array, StringArray}, - datatypes::Schema, }; #[test] fn test_projection() -> Result<()> { let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; - let schema = Arc::new(Schema::new(vec![ + let schema = NaiveSchema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), - ])); + ]); let scan_plan = ScanPlan::create(source, None); let expr = vec![ diff --git a/src/physical_plan/selection.rs b/src/physical_plan/selection.rs index a624691..53d9427 100644 --- a/src/physical_plan/selection.rs +++ b/src/physical_plan/selection.rs @@ -122,17 +122,16 @@ mod tests { use crate::print_result; use arrow::{ array::{Array, ArrayRef, Int64Array, StringArray}, - datatypes::Schema, }; #[test] fn test_selection() -> Result<()> { let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; - let schema = Arc::new(Schema::new(vec![ + let schema = NaiveSchema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), source.schema().field(2).clone(), - ])); + ]); let scan_plan = ScanPlan::create(source, None); 
let expr = vec![ diff --git a/src/planner/mod.rs b/src/planner/mod.rs index c13671c..6a4040d 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -115,6 +115,7 @@ impl QueryPlanner { #[cfg(test)] mod tests { + use std::sync::Arc; use arrow::array::ArrayRef; use arrow::array::Int64Array; use arrow::array::StringArray; From b9a9b80406729103e65db67f2ab972aa3eec3dac Mon Sep 17 00:00:00 2001 From: Veeupup Date: Wed, 18 May 2022 17:42:28 +0800 Subject: [PATCH 13/13] update readme Signed-off-by: Veeupup --- README.md | 31 ++++++++++++++++++++++++++++--- src/main.rs | 2 +- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0578c82..f21188d 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,18 @@ fn main() -> Result<()> { print_result(&ret)?; + // Join + db.create_csv_table("employee", "data/employee.csv")?; + db.create_csv_table("rank", "data/rank.csv")?; + + let ret = db.run_sql( + "select id, name, rank_name from employee inner join rank on employee.rank = rank.id", + )?; + + print_result(&ret); Ok(()) } + ``` output will be: @@ -40,6 +50,15 @@ output will be: | 2 | alex | 120 | | 4 | lynne | 118 | +----+---------+-----------+ ++----+-------+-------------+ +| id | name | rank_name | ++----+-------+-------------+ +| 1 | vee | diamond | +| 2 | lynne | master | +| 3 | Alex | master | +| 4 | jack | diamond | +| 5 | mike | grandmaster | ++----+-------+-------------+ ``` ## architecture @@ -79,12 +98,16 @@ impl NaiveDB { - [x] filter - [x] aggregate - [x] limit - - [ ] join and more... + - [x] join - [x] physical plan & expressions - [x] physical scan - [x] physical projection - [x] physical filter - [x] physical limit + - [x] join + - [x] (dumb😊) nested loop join + - [ ] hash join + - [ ] sort-merge join - [ ] physical expression - [x] column expr - [x] binary operation expr(add/sub/mul/div/and/or...) @@ -93,8 +116,8 @@ impl NaiveDB { - [ ] query planner - [x] scan - [x] limit + - [x] join - [ ] aggregate - - [ ] join - [ ] ...
- [ ] query optimization - [ ] more rules needed @@ -105,4 +128,6 @@ impl NaiveDB { - [x] projection - [x] selection - [x] limit - - [ ] join and more... + - [x] join + - [ ] aggregate + - [ ] scalar function diff --git a/src/main.rs b/src/main.rs index 196e11f..9d4e72f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ fn main() -> Result<()> { print_result(&ret)?; - // Ok(()) + // Join db.create_csv_table("employee", "data/employee.csv")?; db.create_csv_table("rank", "data/rank.csv")?;