diff --git a/README.md b/README.md index 9e796cf..f21188d 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,24 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; print_result(&ret)?; + // Join + db.create_csv_table("employee", "data/employee.csv")?; + db.create_csv_table("rank", "data/rank.csv")?; + + let ret = db.run_sql( + "select id, name, rank_name from employee inner join rank on employee.rank = rank.id", + )?; + + print_result(&ret)?; Ok(()) } + ``` output will be: @@ -40,6 +50,15 @@ output will be: | 2 | alex | 120 | | 4 | lynne | 118 | +----+---------+-----------+ ++----+-------+-------------+ +| id | name | rank_name | ++----+-------+-------------+ +| 1 | vee | diamond | +| 2 | lynne | master | +| 3 | Alex | master | +| 4 | jack | diamond | +| 5 | mike | grandmaster | ++----+-------+-------------+ ``` ## architecture @@ -79,12 +98,16 @@ impl NaiveDB { - [x] filter - [x] aggregate - [x] limit - - [ ] join and more... + - [x] join - [x] physical plan & expressions - [x] physical scan - [x] physical projection - [x] physical filter - [x] physical limit + - [x] join + - [x] (dumb😊) nested loop join + - [ ] hash join + - [ ] sort-merge join - [ ] physical expression - [x] column expr - [x] binary operation expr(add/sub/mul/div/and/or...) @@ -93,8 +116,8 @@ impl NaiveDB { - [ ] query planner - [x] scan - [x] limit + - [x] join - [ ] aggregate - - [ ] join - [ ] ... - [ ] query optimization - [ ] more rules needed @@ -105,4 +128,6 @@ impl NaiveDB { - [x] projection - [x] selection - [x] limit - - [ ] join and more... 
+ - [x] join + - [ ] aggregate + - [ ] scalar function diff --git a/data/department.csv b/data/department.csv new file mode 100644 index 0000000..2c6c7e0 --- /dev/null +++ b/data/department.csv @@ -0,0 +1,4 @@ +id,info +1,IT +2,Marketing +3,Human Resource \ No newline at end of file diff --git a/data/employee.csv b/data/employee.csv new file mode 100644 index 0000000..119e883 --- /dev/null +++ b/data/employee.csv @@ -0,0 +1,6 @@ +id,name,department_id,rank +1,vee,1,1 +2,lynne,1,0 +3,Alex,2,0 +4,jack,2,1 +5,mike,3,2 \ No newline at end of file diff --git a/data/rank.csv b/data/rank.csv new file mode 100644 index 0000000..7976157 --- /dev/null +++ b/data/rank.csv @@ -0,0 +1,4 @@ +id,rank_name +0,master +1,diamond +2,grandmaster \ No newline at end of file diff --git a/test_data.csv b/data/test_data.csv similarity index 100% rename from test_data.csv rename to data/test_data.csv diff --git a/src/catalog.rs b/src/catalog.rs index 6ce58cc..bdfa74f 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -6,11 +6,12 @@ use std::collections::HashMap; -use arrow::datatypes::SchemaRef; + use crate::datasource::{EmptyTable, MemTable}; use crate::error::ErrorCode; use crate::logical_plan::plan::{LogicalPlan, TableScan}; +use crate::logical_plan::schema::NaiveSchema; use crate::logical_plan::DataFrame; use crate::{ datasource::{CsvConfig, CsvTable, TableRef}, @@ -35,7 +36,7 @@ impl Catalog { pub fn add_memory_table( &mut self, table: &str, - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, ) -> Result<()> { let source = MemTable::try_create(schema, batches)?; @@ -44,7 +45,7 @@ impl Catalog { } /// add empty table - pub fn add_empty_table(&mut self, table: &str, schema: SchemaRef) -> Result<()> { + pub fn add_empty_table(&mut self, table: &str, schema: NaiveSchema) -> Result<()> { let source = EmptyTable::try_create(schema)?; self.tables.insert(table.to_string(), source); Ok(()) diff --git a/src/datasource/csv.rs b/src/datasource/csv.rs index 8bc2814..9e2a8a7 100644 --- 
a/src/datasource/csv.rs +++ b/src/datasource/csv.rs @@ -11,9 +11,11 @@ use std::path::Path; use std::sync::Arc; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use arrow::csv; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::datatypes::Schema; +use arrow::{record_batch::RecordBatch}; use super::TableSource; use crate::datasource::TableRef; @@ -42,19 +44,20 @@ impl Default for CsvConfig { #[derive(Debug, Clone)] pub struct CsvTable { - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, } impl CsvTable { #[allow(unused, clippy::iter_next_loop)] pub fn try_create(filename: &str, csv_config: CsvConfig) -> Result { - let schema = Self::infer_schema_from_csv(filename, &csv_config)?; + let orig_schema = Self::infer_schema_from_csv(filename, &csv_config)?; + let schema = NaiveSchema::from_unqualified(&orig_schema); let mut file = File::open(env::current_dir()?.join(Path::new(filename)))?; let mut reader = csv::Reader::new( file, - Arc::clone(&schema), + Arc::new(orig_schema), csv_config.has_header, Some(csv_config.delimiter), csv_config.batch_size, @@ -71,7 +74,7 @@ impl CsvTable { Ok(Arc::new(Self { schema, batches })) } - fn infer_schema_from_csv(filename: &str, csv_config: &CsvConfig) -> Result { + fn infer_schema_from_csv(filename: &str, csv_config: &CsvConfig) -> Result { let mut file = File::open(env::current_dir()?.join(Path::new(filename)))?; let (schema, _) = arrow::csv::reader::infer_reader_schema( &mut file, @@ -79,13 +82,13 @@ impl CsvTable { csv_config.max_read_records, csv_config.has_header, )?; - Ok(Arc::new(schema)) + Ok(schema) } } impl TableSource for CsvTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, _projection: Option>) -> Result> { @@ -103,7 +106,7 @@ mod tests { #[test] fn test_infer_schema() -> Result<()> { - let table = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let table = 
CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let schema = table.schema(); let excepted = Arc::new(Schema::new(vec![ @@ -127,7 +130,7 @@ mod tests { #[test] fn test_read_from_csv() -> Result<()> { - let table = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let table = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let batches = table.scan(None)?; diff --git a/src/datasource/empty.rs b/src/datasource/empty.rs index e282bba..30aff83 100644 --- a/src/datasource/empty.rs +++ b/src/datasource/empty.rs @@ -7,26 +7,27 @@ use super::TableSource; use crate::datasource::TableRef; use crate::error::Result; -use arrow::datatypes::SchemaRef; +use crate::logical_plan::schema::NaiveSchema; + use arrow::record_batch::RecordBatch; use std::sync::Arc; /// Empty Table with schema but no data #[derive(Debug, Clone)] pub struct EmptyTable { - schema: SchemaRef, + schema: NaiveSchema, } impl EmptyTable { #[allow(unused)] - pub fn try_create(schema: SchemaRef) -> Result { + pub fn try_create(schema: NaiveSchema) -> Result { Ok(Arc::new(Self { schema })) } } impl TableSource for EmptyTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, _projection: Option>) -> Result> { @@ -38,14 +39,15 @@ impl TableSource for EmptyTable { mod tests { use super::*; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + #[test] fn test_empty_table() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), - ])); + ]); + let schema = NaiveSchema::from_qualified("t1", &schema); let table = EmptyTable::try_create(schema)?; let batches = table.scan(None)?; diff --git a/src/datasource/memory.rs b/src/datasource/memory.rs index f631d01..a1a6060 100644 --- a/src/datasource/memory.rs +++ b/src/datasource/memory.rs @@ -4,29 +4,29 @@ * @Email: 
code@tanweime.com */ -use arrow::datatypes::SchemaRef; + use arrow::record_batch::RecordBatch; use std::sync::Arc; use super::{TableRef, TableSource}; -use crate::error::Result; +use crate::{error::Result, logical_plan::schema::NaiveSchema}; #[derive(Debug, Clone)] pub struct MemTable { - schema: SchemaRef, + schema: NaiveSchema, batches: Vec, } impl MemTable { #[allow(unused)] - pub fn try_create(schema: SchemaRef, batches: Vec) -> Result { + pub fn try_create(schema: NaiveSchema, batches: Vec) -> Result { Ok(Arc::new(Self { schema, batches })) } } impl TableSource for MemTable { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn scan(&self, projection: Option>) -> Result> { @@ -47,6 +47,7 @@ mod tests { use super::MemTable; use crate::datasource::TableSource; use crate::error::Result; + use crate::logical_plan::schema::NaiveSchema; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -60,9 +61,10 @@ mod tests { Field::new("c", DataType::Int32, false), Field::new("d", DataType::Int32, true), ])); + let schema = NaiveSchema::from_qualified("t1", &schema); let batch = RecordBatch::try_new( - schema.clone(), + schema.clone().into(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), @@ -78,8 +80,8 @@ mod tests { let batch2 = &batches[0]; assert_eq!(2, batch2.schema().fields().len()); - assert_eq!("c", batch2.schema().field(0).name()); - assert_eq!("b", batch2.schema().field(1).name()); + assert_eq!("t1.c", batch2.schema().field(0).name()); + assert_eq!("t1.b", batch2.schema().field(1).name()); assert_eq!(2, batch2.num_columns()); Ok(()) diff --git a/src/datasource/mod.rs b/src/datasource/mod.rs index 4840426..8dbcfd5 100644 --- a/src/datasource/mod.rs +++ b/src/datasource/mod.rs @@ -12,12 +12,13 @@ use std::fmt::Debug; use std::sync::Arc; use crate::error::Result; -use arrow::{datatypes::SchemaRef, 
record_batch::RecordBatch}; +use crate::logical_plan::schema::NaiveSchema; +use arrow::{record_batch::RecordBatch}; pub type TableRef = Arc; pub trait TableSource: Debug { - fn schema(&self) -> SchemaRef; + fn schema(&self) -> &NaiveSchema; // TODO(veeupup): return async stream record batch /// for scan diff --git a/src/error.rs b/src/error.rs index 1addd83..d9a1ee4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,6 +30,7 @@ pub enum ErrorCode { PlanError(String), + NotImplemented, #[allow(unused)] Others, } diff --git a/src/logical_plan/dataframe.rs b/src/logical_plan/dataframe.rs index 7230e68..d085768 100644 --- a/src/logical_plan/dataframe.rs +++ b/src/logical_plan/dataframe.rs @@ -6,12 +6,15 @@ use std::sync::Arc; -use arrow::datatypes::{Schema, SchemaRef}; + use crate::logical_plan::expression::LogicalExpr; use crate::logical_plan::plan::{Aggregate, Filter, LogicalPlan, Projection}; -use super::plan::Limit; +use super::expression::Column; +use super::plan::{Join, JoinType, Limit}; +use super::schema::NaiveSchema; +use crate::error::{ErrorCode, Result}; #[derive(Clone)] pub struct DataFrame { @@ -19,12 +22,17 @@ pub struct DataFrame { } impl DataFrame { + pub fn new(plan: LogicalPlan) -> Self { + Self { plan } + } + pub fn project(self, exprs: Vec) -> Self { + // TODO(veeupup): Ambiguous reference of field let fields = exprs .iter() .map(|expr| expr.data_field(&self.plan).unwrap()) .collect::>(); - let schema = Arc::new(Schema::new(fields)); + let schema = NaiveSchema::new(fields); Self { plan: LogicalPlan::Projection(Projection { input: Arc::new(self.plan), @@ -53,7 +61,7 @@ impl DataFrame { .map(|expr| expr.data_field(&self.plan).unwrap()) .collect::>(); group_fields.append(&mut aggr_fields); - let schema = Arc::new(Schema::new(group_fields)); + let schema = NaiveSchema::new(group_fields); Self { plan: LogicalPlan::Aggregate(Aggregate { input: Arc::new(self.plan), @@ -73,7 +81,34 @@ impl DataFrame { } } - pub fn schema(&self) -> SchemaRef { + pub fn join( 
+ &self, + right: &LogicalPlan, + join_type: JoinType, + join_keys: (Vec, Vec), + ) -> Result { + if join_keys.0.len() != join_keys.1.len() { + return Err(ErrorCode::PlanError( + "left_keys length must be equal to right_keys length".to_string(), + )); + } + + let (left_keys, right_keys) = join_keys; + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); + + let left_schema = self.plan.schema(); + let join_schema = left_schema.join(right.schema()); + + Ok(Self::new(LogicalPlan::Join(Join { + left: Arc::new(self.plan.clone()), + right: Arc::new(right.clone()), + on, + join_type, + schema: join_schema, + }))) + } + + pub fn schema(&self) -> &NaiveSchema { self.plan.schema() } @@ -86,39 +121,39 @@ impl DataFrame { mod tests { use super::*; - use crate::catalog::Catalog; + use crate::{catalog::Catalog, logical_plan::schema::NaiveField}; use crate::error::Result; use crate::logical_plan::expression::*; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::datatypes::{DataType}; #[test] fn create_logical_plan() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("state", DataType::Int64, true), - Field::new("id", DataType::Int64, true), - Field::new("first_name", DataType::Utf8, true), - Field::new("last_name", DataType::Utf8, true), - Field::new("salary", DataType::Int64, true), - ])); + let schema = NaiveSchema::new(vec![ + NaiveField::new(None, "state", DataType::Int64, true), + NaiveField::new(None, "id", DataType::Int64, true), + NaiveField::new(None, "first_name", DataType::Utf8, true), + NaiveField::new(None, "last_name", DataType::Utf8, true), + NaiveField::new(None, "salary", DataType::Int64, true), + ]); let mut catalog = Catalog::default(); catalog.add_empty_table("empty", schema)?; let _plan = catalog .get_table_df("empty")? 
.filter(LogicalExpr::BinaryExpr(BinaryExpr { - left: Box::new(LogicalExpr::column("state".to_string())), + left: Box::new(LogicalExpr::column(None, "state".to_string())), op: Operator::Eq, right: Box::new(LogicalExpr::Literal(ScalarValue::Utf8(Some( "CO".to_string(), )))), })) .project(vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("first_name".to_string()), - LogicalExpr::column("last_name".to_string()), - LogicalExpr::column("state".to_string()), - LogicalExpr::column("salary".to_string()), + LogicalExpr::column(None, "id".to_string()), + LogicalExpr::column(None, "first_name".to_string()), + LogicalExpr::column(None, "last_name".to_string()), + LogicalExpr::column(None, "state".to_string()), + LogicalExpr::column(None, "salary".to_string()), ]); Ok(()) diff --git a/src/logical_plan/expression.rs b/src/logical_plan/expression.rs index 8a8d67a..aeebf5c 100644 --- a/src/logical_plan/expression.rs +++ b/src/logical_plan/expression.rs @@ -9,14 +9,16 @@ use std::iter::repeat; use arrow::array::StringArray; use arrow::array::{new_null_array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType}; use std::sync::Arc; -use crate::error::ErrorCode; + use crate::error::Result; use crate::logical_plan::plan::LogicalPlan; +use super::schema::{NaiveField}; + #[derive(Clone, Debug)] pub enum LogicalExpr { /// An expression with a specific name. 
@@ -47,37 +49,36 @@ pub enum LogicalExpr { } impl LogicalExpr { - pub fn column(name: String) -> LogicalExpr { - LogicalExpr::Column(Column(name)) + pub fn column(table: Option, name: String) -> LogicalExpr { + LogicalExpr::Column(Column { table, name }) } /// TODO(veeupup): consider return Vec - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { match self { LogicalExpr::Alias(expr, alias) => { let field = expr.data_field(input)?; - Ok(Field::new( + Ok(NaiveField::new( + None, alias, field.data_type().clone(), field.is_nullable(), )) } - LogicalExpr::Column(col) => { - for field in input.schema().fields() { - if field.name() == col.0.as_str() { - return Ok(field.clone()); - } - } - Err(ErrorCode::NoSuchField) - } + LogicalExpr::Column(Column { name, table }) => match table { + Some(table) => input.schema().field_with_qualified_name(table, name), + None => input.schema().field_with_unqualified_name(name), + }, LogicalExpr::Literal(scalar_val) => Ok(scalar_val.data_field()), LogicalExpr::BinaryExpr(expr) => expr.data_field(input), - LogicalExpr::Not(expr) => Ok(Field::new( + LogicalExpr::Not(expr) => Ok(NaiveField::new( + None, format!("Not {}", expr.data_field(input)?.name()).as_str(), DataType::Boolean, true, )), - LogicalExpr::Cast { expr, data_type } => Ok(Field::new( + LogicalExpr::Cast { expr, data_type } => Ok(NaiveField::new( + None, expr.data_field(input)?.name(), data_type.clone(), true, @@ -87,11 +88,27 @@ impl LogicalExpr { // LogicalExpression::Wildcard => , } } + + pub fn and(self, other: LogicalExpr) -> LogicalExpr { + binary_expr(self, Operator::And, other) + } +} + +/// return a new expression l r +pub fn binary_expr(l: LogicalExpr, op: Operator, r: LogicalExpr) -> LogicalExpr { + LogicalExpr::BinaryExpr(BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + }) } /// A named reference to a qualified field in a schema. 
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Column(pub String); +pub struct Column { + pub table: Option, + pub name: String, +} #[derive(Debug, Clone)] @@ -120,14 +137,14 @@ macro_rules! build_array_from_option { } impl ScalarValue { - pub fn data_field(&self) -> Field { + pub fn data_field(&self) -> NaiveField { match self { - ScalarValue::Null => Field::new("Null", DataType::Null, true), - ScalarValue::Boolean(_) => Field::new("bool", DataType::Boolean, true), - ScalarValue::Float64(_) => Field::new("f64", DataType::Float64, true), - ScalarValue::Int64(_) => Field::new("i64", DataType::Int64, true), - ScalarValue::UInt64(_) => Field::new("u64", DataType::UInt64, true), - ScalarValue::Utf8(_) => Field::new("string", DataType::Utf8, true), + ScalarValue::Null => NaiveField::new(None, "Null", DataType::Null, true), + ScalarValue::Boolean(_) => NaiveField::new(None, "bool", DataType::Boolean, true), + ScalarValue::Float64(_) => NaiveField::new(None, "f64", DataType::Float64, true), + ScalarValue::Int64(_) => NaiveField::new(None, "i64", DataType::Int64, true), + ScalarValue::UInt64(_) => NaiveField::new(None, "u64", DataType::UInt64, true), + ScalarValue::Utf8(_) => NaiveField::new(None, "string", DataType::Utf8, true), } } @@ -157,7 +174,7 @@ pub struct BinaryExpr { } impl BinaryExpr { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { let left = self.left.data_field(input)?; let left = left.name(); let right = match &*self.right { @@ -172,67 +189,80 @@ impl BinaryExpr { _ => self.right.data_field(input)?.name().clone(), }; let field = match self.op { - Operator::Eq => Field::new( + Operator::Eq => NaiveField::new( + None, format!("{} = {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::NotEq => Field::new( + Operator::NotEq => NaiveField::new( + None, format!("{} != {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Lt => Field::new( 
+ Operator::Lt => NaiveField::new( + None, format!("{} < {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::LtEq => Field::new( + Operator::LtEq => NaiveField::new( + None, format!("{} <= {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Gt => Field::new( + Operator::Gt => NaiveField::new( + None, format!("{} > {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::GtEq => Field::new( + Operator::GtEq => NaiveField::new( + None, format!("{} >= {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Plus => Field::new( + Operator::Plus => NaiveField::new( + None, format!("{} + {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Minus => Field::new( + Operator::Minus => NaiveField::new( + None, format!("{} - {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Multiply => Field::new( + Operator::Multiply => NaiveField::new( + None, format!("{} * {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Divide => Field::new( + Operator::Divide => NaiveField::new( + None, format!("{} / {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::Modulo => Field::new( + Operator::Modulo => NaiveField::new( + None, format!("{} % {}", left, right).as_str(), self.left.data_field(input)?.data_type().clone(), true, ), - Operator::And => Field::new( + Operator::And => NaiveField::new( + None, format!("{} and {}", left, right).as_str(), DataType::Boolean, true, ), - Operator::Or => Field::new( + Operator::Or => NaiveField::new( + None, format!("{} or {}", left, right).as_str(), DataType::Boolean, true, @@ -282,21 +312,24 @@ pub struct ScalarFunction { } impl ScalarFunction { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { // TODO(veeupup): we should make scalar func more 
specific and should check if valid before creating them let field = self.args[0].data_field(input)?; let field = match self.fun { - ScalarFunc::Abs => Field::new( + ScalarFunc::Abs => NaiveField::new( + None, format!("abs({})", field.name()).as_str(), DataType::Int64, true, ), - ScalarFunc::Add => Field::new( + ScalarFunc::Add => NaiveField::new( + None, format!("add({})", field.name()).as_str(), DataType::Int64, true, ), - ScalarFunc::Sub => Field::new( + ScalarFunc::Sub => NaiveField::new( + None, format!("sub({})", field.name()).as_str(), DataType::Int64, true, @@ -323,30 +356,35 @@ pub struct AggregateFunction { } impl AggregateFunction { - pub fn data_field(&self, input: &LogicalPlan) -> Result { + pub fn data_field(&self, input: &LogicalPlan) -> Result { let dt = self.args.data_field(input)?; let field = match self.fun { - AggregateFunc::Count => Field::new( + AggregateFunc::Count => NaiveField::new( + None, format!("count({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Sum => Field::new( + AggregateFunc::Sum => NaiveField::new( + None, format!("sum({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Min => Field::new( + AggregateFunc::Min => NaiveField::new( + None, format!("min({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Max => Field::new( + AggregateFunc::Max => NaiveField::new( + None, format!("max({})", dt.name()).as_str(), dt.data_type().clone(), true, ), - AggregateFunc::Avg => Field::new( + AggregateFunc::Avg => NaiveField::new( + None, format!("avg({})", dt.name()).as_str(), dt.data_type().clone(), true, diff --git a/src/logical_plan/mod.rs b/src/logical_plan/mod.rs index 1ed95ae..e437ced 100644 --- a/src/logical_plan/mod.rs +++ b/src/logical_plan/mod.rs @@ -13,5 +13,6 @@ mod dataframe; pub mod expression; pub mod literal; pub mod plan; +pub mod schema; pub use dataframe::DataFrame; diff --git a/src/logical_plan/plan.rs b/src/logical_plan/plan.rs index 
2dc96e4..744d120 100644 --- a/src/logical_plan/plan.rs +++ b/src/logical_plan/plan.rs @@ -6,10 +6,12 @@ use crate::datasource::TableRef; use crate::logical_plan::expression::{Column, LogicalExpr}; -use arrow::datatypes::SchemaRef; + use std::sync::Arc; -#[derive(Clone)] +use super::schema::NaiveSchema; + +#[derive(Debug, Clone)] pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. @@ -40,12 +42,12 @@ pub enum LogicalPlan { } impl LogicalPlan { - pub fn schema(&self) -> SchemaRef { + pub fn schema(&self) -> &NaiveSchema { match self { - LogicalPlan::Projection(Projection { schema, .. }) => schema.clone(), + LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema.clone(), - LogicalPlan::Join(Join { schema, .. }) => schema.clone(), + LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, + LogicalPlan::Join(Join { schema, .. }) => schema, LogicalPlan::Limit(Limit { input, .. }) => input.schema(), LogicalPlan::TableScan(TableScan { source, .. }) => source.schema(), } @@ -63,17 +65,17 @@ impl LogicalPlan { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Projection { /// The list of expressions pub exprs: Vec, /// The incoming logical plan pub input: Arc, /// The schema description of the output - pub schema: SchemaRef, + pub schema: NaiveSchema, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Filter { /// The predicate expression, which must have Boolean type. pub predicate: LogicalExpr, @@ -81,7 +83,7 @@ pub struct Filter { pub input: Arc, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct TableScan { /// The source of the table pub source: TableRef, @@ -91,7 +93,7 @@ pub struct TableScan { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). 
-#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Aggregate { /// The incoming logical plan pub input: Arc, @@ -100,10 +102,10 @@ pub struct Aggregate { /// Aggregate expressions pub aggr_expr: Vec, /// The schema description of the aggregate output - pub schema: SchemaRef, + pub schema: NaiveSchema, } -#[derive(Clone)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum JoinType { Inner, Left, @@ -111,7 +113,7 @@ pub enum JoinType { } /// Join two logical plans on one or more join columns -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Join { /// Left input pub left: Arc, @@ -122,10 +124,10 @@ pub struct Join { /// Join type pub join_type: JoinType, /// The output schema, containing fields from the left and right inputs - pub schema: SchemaRef, + pub schema: NaiveSchema, } -#[derive(Clone)] +#[derive(Debug, Clone)] /// Produces the first `n` tuples from its input and discards the rest. pub struct Limit { @@ -142,12 +144,12 @@ mod tests { use crate::error::Result; use crate::logical_plan::expression::*; - use arrow::datatypes::{DataType, Field, Schema}; + /// Create LogicalPlan #[test] fn create_logical_plan() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let schema = NaiveSchema::empty(); let source = EmptyTable::try_create(schema)?; let scan = LogicalPlan::TableScan(TableScan { @@ -156,7 +158,7 @@ mod tests { }); let filter_expr = LogicalExpr::BinaryExpr(BinaryExpr { - left: Box::new(LogicalExpr::column("state".to_string())), + left: Box::new(LogicalExpr::column(None, "state".to_string())), op: Operator::Eq, right: Box::new(LogicalExpr::Literal(ScalarValue::Utf8(Some( "CO".to_string(), @@ -169,11 +171,11 @@ mod tests { }); let _projection = vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("first_name".to_string()), - LogicalExpr::column("last_name".to_string()), - LogicalExpr::column("state".to_string()), - LogicalExpr::column("salary".to_string()), + LogicalExpr::column(None, 
"id".to_string()), + LogicalExpr::column(None, "first_name".to_string()), + LogicalExpr::column(None, "last_name".to_string()), + LogicalExpr::column(None, "state".to_string()), + LogicalExpr::column(None, "salary".to_string()), ]; Ok(()) diff --git a/src/logical_plan/schema.rs b/src/logical_plan/schema.rs new file mode 100644 index 0000000..18ff7d2 --- /dev/null +++ b/src/logical_plan/schema.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ + /* + * @Author: Veeupup + * @Date: 2022-05-18 13:45:10 + * @Last Modified by: Veeupup + * @Last Modified time: 2022-05-18 17:30:21 + * + * Arrow Field does not have table/relation name as its properties + * So we need a Schema to define inner schema with table name + * + * Code Ideas come from https://github.com/apache/arrow-datafusion/ + * + */ + + + + +use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{Field, Schema}; + +use crate::error::ErrorCode; +use crate::error::Result; + +#[derive(Debug, Clone)] +pub struct NaiveSchema { + pub fields: Vec, +} + +impl NaiveSchema { + pub fn empty() -> Self { + Self { fields: vec![] } + } + + pub fn new(fields: Vec) -> Self { + // TODO(veeupup): check if we have duplicated name field + Self { fields } + } + + pub fn from_qualified(qualifier: &str, schema: &Schema) -> Self { + Self::new( + schema + .fields() + .iter() + .map(|field| NaiveField { + field: field.clone(), + qualifier: Some(qualifier.to_owned()), + }) + .collect(), + ) + } + + pub fn from_unqualified(schema: &Schema) -> Self { + Self::new( + schema + .fields() + .iter() + .map(|field| NaiveField { + field: field.clone(), + qualifier: None, + }) + .collect(), + ) + } + + /// join two schema + pub fn join(&self, schema: &NaiveSchema) -> Self { + let mut fields = self.fields.clone(); + fields.extend_from_slice(schema.fields().as_slice()); + Self::new(fields) + } + + pub fn fields(&self) -> &Vec { + &self.fields + } + + pub fn field(&self, i: usize) -> &NaiveField { + &self.fields[i] + } + + pub fn index_of(&self, name: &str) -> Result { + for i in 0..self.fields().len() { + if self.fields[i].name() == name { + return Ok(i); + } + } + Err(ErrorCode::NoSuchField) + } + + /// Find the field with the given name + pub fn field_with_name(&self, relation_name: Option<&str>, name: &str) -> Result { + if let Some(relation_name) = relation_name { + self.field_with_qualified_name(relation_name, name) + } else { + self.field_with_unqualified_name(name) + } + 
} + + pub fn field_with_unqualified_name(&self, name: &str) -> Result { + let matches = self + .fields + .iter() + .filter(|field| field.name() == name) + .collect::>(); + match matches.len() { + 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), + _ => Ok(matches[0].to_owned()), + // TODO(veeupup): multi same name, and we need to return Error + // _ => Err(ErrorCode::PlanError(format!( + // "Ambiguous reference to field named '{}'", + // name + // ))), + } + } + + pub fn field_with_qualified_name(&self, relation_name: &str, name: &str) -> Result { + let matches = self + .fields + .iter() + .filter(|field| { + field.qualifier == Some(relation_name.to_owned()) && field.name() == name + }) + .collect::>(); + match matches.len() { + 0 => Err(ErrorCode::PlanError(format!("No field named '{}'", name))), + _ => Ok(matches[0].to_owned()), + // TODO(veeupup): multi same name, and we need to return Error + // _ => Err(ErrorCode::PlanError(format!( + // "Ambiguous reference to field named '{}'", + // name + // ))), + } + } +} + +impl Into for NaiveSchema { + /// Convert a schema into a DFSchema + fn into(self) -> Schema { + Schema::new( + self.fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + Field::new( + f.qualified_name().as_str(), + f.data_type().to_owned(), + f.is_nullable(), + ) + } else { + f.field + } + }) + .collect(), + ) + } +} + +impl Into for NaiveSchema { + fn into(self) -> SchemaRef { + SchemaRef::new(self.into()) + } +} + +/// NaiveField wraps an Arrow field and adds an optional qualifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NaiveField { + /// Optional qualifier (usually a table or relation name) + qualifier: Option, + /// Arrow field definition + field: Field, +} + +impl NaiveField { + pub fn new(qualifier: Option<&str>, name: &str, data_type: DataType, nullable: bool) -> Self { + Self { + qualifier: qualifier.map(|s| s.to_owned()), + field: Field::new(name, data_type, nullable), + } + } + + pub fn 
from(field: Field) -> Self { + Self { + qualifier: None, + field, + } + } + + pub fn from_qualified(qualifier: &str, field: Field) -> Self { + Self { + qualifier: Some(qualifier.to_owned()), + field, + } + } + + pub fn name(&self) -> &String { + self.field.name() + } + + /// Returns an immutable reference to the `NaiveField`'s data-type + pub fn data_type(&self) -> &DataType { + &self.field.data_type() + } + + /// Indicates whether this `NaiveField` supports null values + pub fn is_nullable(&self) -> bool { + self.field.is_nullable() + } + + /// Returns a reference to the `NaiveField`'s qualified name + pub fn qualified_name(&self) -> String { + if let Some(relation_name) = &self.qualifier { + format!("{}.{}", relation_name, self.field.name()) + } else { + self.field.name().to_owned() + } + } + + /// Get the optional qualifier + pub fn qualifier(&self) -> Option<&String> { + self.qualifier.as_ref() + } +} diff --git a/src/main.rs b/src/main.rs index c15f9cd..9d4e72f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,20 @@ use naive_db::Result; fn main() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; let ret = db.run_sql("select id, name, age + 100 from t1 where id < 6 limit 3")?; print_result(&ret)?; + // Join + db.create_csv_table("employee", "data/employee.csv")?; + db.create_csv_table("rank", "data/rank.csv")?; + + let ret = db.run_sql( + "select id, name, rank_name from employee innner join rank on employee.rank = rank.id", + )?; + + print_result(&ret); Ok(()) } diff --git a/src/physical_plan/limit.rs b/src/physical_plan/limit.rs index 0a5cca5..893ad87 100644 --- a/src/physical_plan/limit.rs +++ b/src/physical_plan/limit.rs @@ -2,12 +2,13 @@ * @Author: Veeupup * @Date: 2022-05-17 11:27:29 * @Last Modified by: Veeupup - * @Last Modified time: 2022-05-17 11:57:13 + * @Last Modified time: 2022-05-18 14:45:03 */ use super::{PhysicalPlan, PhysicalPlanRef}; use 
crate::error::Result; -use arrow::datatypes::SchemaRef; +use crate::logical_plan::schema::NaiveSchema; + use arrow::record_batch::RecordBatch; use std::sync::Arc; @@ -24,7 +25,7 @@ impl PhysicalLimitPlan { } impl PhysicalPlan for PhysicalLimitPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.input.schema() } @@ -64,7 +65,7 @@ mod tests { #[test] fn test_physical_scan() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let scan_plan = ScanPlan::create(source, None); let limit_plan = PhysicalLimitPlan::create(scan_plan, 2); diff --git a/src/physical_plan/mod.rs b/src/physical_plan/mod.rs index 5a27a6c..ac0c537 100644 --- a/src/physical_plan/mod.rs +++ b/src/physical_plan/mod.rs @@ -8,12 +8,14 @@ mod expression; mod plan; mod limit; +mod nested_loop_join; mod projection; mod scan; mod selection; pub use expression::*; pub use limit::*; +pub use nested_loop_join::*; pub use plan::*; pub use projection::*; pub use scan::*; diff --git a/src/physical_plan/nested_loop_join.rs b/src/physical_plan/nested_loop_join.rs new file mode 100644 index 0000000..276bcd2 --- /dev/null +++ b/src/physical_plan/nested_loop_join.rs @@ -0,0 +1,147 @@ +/* + * @Author: Veeupup + * @Date: 2022-05-18 16:00:13 + * @Last Modified by: Veeupup + * @Last Modified time: 2022-05-18 17:28:58 + */ +use super::PhysicalPlan; +use super::PhysicalPlanRef; +use crate::error::ErrorCode; +use crate::logical_plan::expression::Column; +use crate::logical_plan::plan::JoinType; +use crate::logical_plan::schema::NaiveSchema; +use crate::physical_plan::ColumnExpr; +use crate::physical_plan::PhysicalExpr; + +use crate::Result; +use std::sync::Arc; + +use arrow::array::Array; +use arrow::array::Int64Builder; +use arrow::array::PrimitiveArray; +use arrow::compute; +use arrow::datatypes::DataType; +use arrow::datatypes::Int64Type; +use 
arrow::record_batch::RecordBatch; + +#[derive(Debug, Clone)] +pub struct NestedLoopJoin { + left: PhysicalPlanRef, + right: PhysicalPlanRef, + on: Vec<(Column, Column)>, + join_type: JoinType, + schema: NaiveSchema, +} + +impl NestedLoopJoin { + pub fn new( + left: PhysicalPlanRef, + right: PhysicalPlanRef, + on: Vec<(Column, Column)>, + join_type: JoinType, + schema: NaiveSchema, + ) -> PhysicalPlanRef { + Arc::new(Self { + left, + right, + on, + join_type, + schema, + }) + } +} + +impl PhysicalPlan for NestedLoopJoin { + fn schema(&self) -> &NaiveSchema { + &self.schema + } + + fn execute(&self) -> Result> { + let outer_table = self.left.execute()?; + let inner_table = self.right.execute()?; + + let mut batches: Vec = vec![]; + // TODO(veeupup): support multi on conditions + for (left_col, right_col) in &self.on { + let left_col = ColumnExpr::try_create(Some(left_col.name.clone()), None)?; + let right_col = ColumnExpr::try_create(Some(right_col.name.clone()), None)?; + + for outer in &outer_table { + let left_col = left_col.evaluate(outer)?.into_array(); + + let dt = left_col.data_type(); + for inner in &inner_table { + let right_col = right_col.evaluate(inner)?.into_array(); + + // check if ok + if left_col.data_type() != right_col.data_type() { + return Err(ErrorCode::PlanError(format!( + "Join on left and right data type should be same: left: {:?}, right: {:?}", + left_col.data_type(), + right_col.data_type() + ))); + } + + let mut outer_pos = Int64Builder::new(left_col.len()); + let mut inner_pos = Int64Builder::new(right_col.len()); + match dt { + DataType::Int64 => { + let left_col = left_col + .as_any() + .downcast_ref::>() + .unwrap(); + let right_col = right_col + .as_any() + .downcast_ref::>() + .unwrap(); + + for (x_pos, x) in left_col.iter().enumerate() { + for (y_pos, y) in right_col.iter().enumerate() { + match (x, y) { + (Some(x), Some(y)) => { + if x == y { + // equal and we should + outer_pos.append_value(x_pos as i64); + 
inner_pos.append_value(y_pos as i64); + } + } + _ => {} + } + } + } + } + DataType::UInt64 => todo!(), + DataType::Float64 => todo!(), + DataType::Utf8 => todo!(), + _ => unimplemented!(), + } + let mut columns = vec![]; + + let outer_pos = outer_pos.finish(); + let inner_pos = inner_pos.finish(); + + // add left columns + for i in 0..self.left.schema().fields().len() { + let array = outer.column(i); + columns.push(compute::take(array.as_ref(), &outer_pos, None)?); + } + + // add right columns + for i in 0..self.right.schema().fields().len() { + let array = inner.column(i); + columns.push(compute::take(array.as_ref(), &inner_pos, None)?); + } + + let batch = RecordBatch::try_new(self.schema.clone().into(), columns)?; + batches.push(batch); + } + } + } + + return Ok(batches); + } + + fn children(&self) -> Result> { + Ok(vec![self.left.clone(), self.right.clone()]) + } +} diff --git a/src/physical_plan/plan.rs b/src/physical_plan/plan.rs index aa916ac..0bbfbf0 100644 --- a/src/physical_plan/plan.rs +++ b/src/physical_plan/plan.rs @@ -7,12 +7,12 @@ use std::fmt::Debug; use std::sync::Arc; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{record_batch::RecordBatch}; -use crate::error::Result; +use crate::{error::Result, logical_plan::schema::NaiveSchema}; pub trait PhysicalPlan: Debug { - fn schema(&self) -> SchemaRef; + fn schema(&self) -> &NaiveSchema; // TODO(veeupup): return by using streaming mode fn execute(&self) -> Result>; diff --git a/src/physical_plan/projection.rs b/src/physical_plan/projection.rs index 9bdcdbb..fd5b2da 100644 --- a/src/physical_plan/projection.rs +++ b/src/physical_plan/projection.rs @@ -9,20 +9,21 @@ use std::sync::Arc; use super::{expression::PhysicalExpr, plan::PhysicalPlan}; use crate::error::Result; +use crate::logical_plan::schema::NaiveSchema; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalPlanRef; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use 
arrow::{record_batch::RecordBatch}; #[derive(Debug, Clone)] pub struct ProjectionPlan { input: PhysicalPlanRef, - schema: SchemaRef, + schema: NaiveSchema, expr: Vec, } impl ProjectionPlan { pub fn create( input: PhysicalPlanRef, - schema: SchemaRef, + schema: NaiveSchema, expr: Vec, ) -> PhysicalPlanRef { Arc::new(Self { @@ -34,8 +35,8 @@ impl ProjectionPlan { } impl PhysicalPlan for ProjectionPlan { - fn schema(&self) -> SchemaRef { - self.schema.clone() + fn schema(&self) -> &NaiveSchema { + &self.schema } fn execute(&self) -> Result> { @@ -54,7 +55,8 @@ impl PhysicalPlan for ProjectionPlan { .map(|column| column.clone().into_array()) .collect::>(); // TODO(veeupup): remove unwrap - RecordBatch::try_new(self.schema.clone(), columns).unwrap() + // let projection_schema = self.schema.into(); + RecordBatch::try_new(self.schema.clone().into(), columns).unwrap() }) .collect::>(); Ok(batches) @@ -73,16 +75,15 @@ mod tests { use crate::physical_plan::scan::ScanPlan; use arrow::{ array::{Array, ArrayRef, Int64Array, StringArray}, - datatypes::Schema, }; #[test] fn test_projection() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; - let schema = Arc::new(Schema::new(vec![ + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; + let schema = NaiveSchema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), - ])); + ]); let scan_plan = ScanPlan::create(source, None); let expr = vec![ diff --git a/src/physical_plan/scan.rs b/src/physical_plan/scan.rs index e2e81e8..9965c18 100644 --- a/src/physical_plan/scan.rs +++ b/src/physical_plan/scan.rs @@ -8,7 +8,8 @@ use std::sync::Arc; use crate::datasource::TableRef; use crate::error::Result; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use crate::logical_plan::schema::NaiveSchema; +use arrow::{record_batch::RecordBatch}; use crate::physical_plan::PhysicalPlan; use crate::physical_plan::PhysicalPlanRef; @@ -26,7 +27,7 
@@ impl ScanPlan { } impl PhysicalPlan for ScanPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.source.schema() } @@ -48,7 +49,7 @@ mod tests { #[test] fn test_physical_scan() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; let scan_plan = ScanPlan::create(source, None); diff --git a/src/physical_plan/selection.rs b/src/physical_plan/selection.rs index 664113e..53d9427 100644 --- a/src/physical_plan/selection.rs +++ b/src/physical_plan/selection.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use super::{PhysicalExpr, PhysicalExprRef, PhysicalPlan, PhysicalPlanRef}; +use crate::logical_plan::schema::NaiveSchema; use crate::Result; use arrow::array::{ Float64Array, Float64Builder, Int64Array, Int64Builder, StringArray, StringBuilder, @@ -15,7 +16,7 @@ use arrow::array::{ use arrow::record_batch::RecordBatch; use arrow::{ array::{Array, BooleanArray, BooleanBuilder}, - datatypes::{DataType, SchemaRef}, + datatypes::{DataType}, }; #[derive(Debug)] @@ -50,7 +51,7 @@ macro_rules! 
build_array_by_predicate { } impl PhysicalPlan for SelectionPlan { - fn schema(&self) -> SchemaRef { + fn schema(&self) -> &NaiveSchema { self.input.schema() } @@ -98,7 +99,8 @@ impl PhysicalPlan for SelectionPlan { }; columns.push(column); } - let record_batch = RecordBatch::try_new(self.schema(), columns)?; + let record_batch = + RecordBatch::try_new(Arc::new(self.schema().clone().into()), columns)?; batches.push(record_batch); } Ok(batches) @@ -120,17 +122,16 @@ mod tests { use crate::print_result; use arrow::{ array::{Array, ArrayRef, Int64Array, StringArray}, - datatypes::Schema, }; #[test] fn test_selection() -> Result<()> { - let source = CsvTable::try_create("test_data.csv", CsvConfig::default())?; - let schema = Arc::new(Schema::new(vec![ + let source = CsvTable::try_create("data/test_data.csv", CsvConfig::default())?; + let schema = NaiveSchema::new(vec![ source.schema().field(0).clone(), source.schema().field(1).clone(), source.schema().field(2).clone(), - ])); + ]); let scan_plan = ScanPlan::create(source, None); let expr = vec![ diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 6dab6fb..6a4040d 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -7,10 +7,12 @@ * */ -use std::sync::Arc; -use arrow::datatypes::Schema; + + +use crate::logical_plan::schema::NaiveSchema; +use crate::physical_plan::NestedLoopJoin; use crate::physical_plan::PhysicalBinaryExpr; use crate::physical_plan::PhysicalExprRef; use crate::physical_plan::PhysicalLimitPlan; @@ -47,15 +49,23 @@ impl QueryPlanner { .iter() .map(|expr| expr.data_field(proj.input.as_ref()).unwrap()) .collect::>(); - let proj_schema = Arc::new(Schema::new(fields)); + let proj_schema = NaiveSchema::new(fields); Ok(ProjectionPlan::create(input, proj_schema, proj_expr)) } LogicalPlan::Limit(limit) => { let plan = Self::create_physical_plan(&limit.input)?; Ok(PhysicalLimitPlan::create(plan, limit.n)) } - LogicalPlan::Join(_join) => { - todo!() + LogicalPlan::Join(join) => { + let left = 
Self::create_physical_plan(&join.left)?; + let right = Self::create_physical_plan(&join.right)?; + Ok(NestedLoopJoin::new( + left, + right, + join.on.clone(), + join.join_type, + join.schema.clone(), + )) } LogicalPlan::Filter(filter) => { let predicate = Self::create_physical_expression(&filter.predicate, plan)?; @@ -74,7 +84,7 @@ impl QueryPlanner { ) -> Result { match expr { LogicalExpr::Alias(_, _) => todo!(), - LogicalExpr::Column(Column(name)) => { + LogicalExpr::Column(Column { name, .. }) => { for (idx, field) in input.schema().fields().iter().enumerate() { if field.name() == name { return ColumnExpr::try_create(None, Some(idx)); @@ -105,6 +115,7 @@ impl QueryPlanner { #[cfg(test)] mod tests { + use std::sync::Arc; use arrow::array::ArrayRef; use arrow::array::Int64Array; use arrow::array::StringArray; @@ -117,18 +128,17 @@ mod tests { fn test_scan_projection() -> Result<()> { // construct let mut catalog = Catalog::default(); - catalog.add_csv_table("t1", "test_data.csv")?; + catalog.add_csv_table("t1", "data/test_data.csv")?; let source = catalog.get_table_df("t1")?; let exprs = vec![ - LogicalExpr::column("id".to_string()), - LogicalExpr::column("name".to_string()), - LogicalExpr::column("age".to_string()), + LogicalExpr::column(None, "id".to_string()), + LogicalExpr::column(None, "name".to_string()), + LogicalExpr::column(None, "age".to_string()), ]; let logical_plan = source.project(exprs).logical_plan(); let physical_plan = QueryPlanner::create_physical_plan(&logical_plan)?; let batches = physical_plan.execute()?; - println!("{:?}", batches); // test assert_eq!(batches.len(), 1); let batch = &batches[0]; diff --git a/src/sql/planner.rs b/src/sql/planner.rs index 491ad72..52468e9 100644 --- a/src/sql/planner.rs +++ b/src/sql/planner.rs @@ -7,14 +7,20 @@ * */ -use log::debug; -use sqlparser::ast::{BinaryOperator, Expr, OrderByExpr, SetExpr, Statement, TableWithJoins}; +use std::collections::HashSet; + + +use sqlparser::ast::{ + BinaryOperator, Expr, 
Join, JoinConstraint, JoinOperator, OrderByExpr, SetExpr, Statement, + TableWithJoins, +}; use sqlparser::ast::{Ident, ObjectName, SelectItem, TableFactor, Value}; use crate::error::ErrorCode; -use crate::logical_plan::expression::{BinaryExpr, LogicalExpr, Operator, ScalarValue}; +use crate::logical_plan::expression::{BinaryExpr, Column, LogicalExpr, Operator, ScalarValue}; use crate::logical_plan::literal::lit; -use crate::logical_plan::plan::TableScan; +use crate::logical_plan::plan::{JoinType, TableScan}; + use crate::{ catalog::Catalog, error::Result, @@ -45,9 +51,9 @@ impl<'a> SQLPlanner<'a> { fn set_expr_to_plan(&self, set_expr: SetExpr) -> Result { match set_expr { SetExpr::Select(select) => { - let df = self.plan_from_tables(select.from)?; + let plans = self.plan_from_tables(select.from)?; - let df = self.plan_selection(select.selection, df)?; + let df = self.plan_selection(select.selection, plans)?; // process the SELECT expressions, with wildcards expanded let df = self.plan_from_projection(df, select.projection)?; @@ -82,21 +88,98 @@ impl<'a> SQLPlanner<'a> { } } - fn plan_from_tables(&self, from: Vec) -> Result { - // TODO(veeupup): support select with no from - // TODO(veeupup): support select with join, multi table - debug_assert!(!from.is_empty()); - match &from[0].relation { - TableFactor::Table { name, alias: _, .. 
} => { + fn plan_from_tables(&self, from: Vec) -> Result> { + match from.len() { + 0 => todo!("support select with no from"), + _ => from + .iter() + .map(|t| self.plan_table_with_joins(t)) + .collect::>>(), + } + } + + fn plan_table_with_joins(&self, t: &TableWithJoins) -> Result { + let left = self.parse_table(&t.relation)?; + match t.joins.len() { + 0 => Ok(left), + n => { + let mut left = self.parse_table_join(left, &t.joins[0])?; + for i in 1..n { + left = self.parse_table_join(left, &t.joins[i])?; + } + Ok(left) + } + } + } + + fn parse_table_join(&self, left: LogicalPlan, join: &Join) -> Result { + let right = self.parse_table(&join.relation)?; + match &join.join_operator { + JoinOperator::LeftOuter(constraint) => { + self.parse_join(left, right, constraint, JoinType::Left) + } + JoinOperator::RightOuter(constraint) => { + self.parse_join(left, right, constraint, JoinType::Right) + } + JoinOperator::Inner(constraint) => { + self.parse_join(left, right, constraint, JoinType::Inner) + } + // TODO(veeupup): cross join + _other => Err(ErrorCode::NotImplemented), + } + } + + fn parse_join( + &self, + left: LogicalPlan, + right: LogicalPlan, + constraint: &JoinConstraint, + join_type: JoinType, + ) -> Result { + match constraint { + JoinConstraint::On(sql_expr) => { + let mut keys: Vec<(Column, Column)> = vec![]; + let expr = self.sql_to_expr(sql_expr)?; + + let mut filters = vec![]; + extract_join_keys(&expr, &mut keys, &mut filters); + + let left_keys = keys.iter().map(|pair| pair.0.clone()).collect(); + let right_keys = keys.iter().map(|pair| pair.1.clone()).collect(); + + if filters.is_empty() { + let join = + DataFrame::new(left).join(&right, join_type, (left_keys, right_keys))?; + Ok(join.logical_plan()) + } else if join_type == JoinType::Inner { + let join = + DataFrame::new(left).join(&right, join_type, (left_keys, right_keys))?; + let join = join.filter( + filters + .iter() + .skip(1) + .fold(filters[0].clone(), |acc, e| acc.and(e.clone())), + ); + 
Ok(join.logical_plan()) + } else { + Err(ErrorCode::NotImplemented) + } + } + _ => Err(ErrorCode::NotImplemented), + } + } + + fn parse_table(&self, relation: &TableFactor) -> Result { + match &relation { + TableFactor::Table { name, .. } => { let table_name = Self::normalize_sql_object_name(name); let source = self.catalog.get_table(&table_name)?; - let plan = LogicalPlan::TableScan(TableScan { + Ok(LogicalPlan::TableScan(TableScan { source, projection: None, - }); - Ok(DataFrame { plan }) + })) } - _ => todo!(), + _ => unimplemented!(), } } @@ -119,18 +202,81 @@ impl<'a> SQLPlanner<'a> { Err(err) => Err(err), }) .collect::>(); - debug!("projection: {:?}", proj); Ok(df.project(proj)) } - fn plan_selection(&self, selection: Option, df: DataFrame) -> Result { + fn plan_selection( + &self, + selection: Option, + plans: Vec, + ) -> Result { + // TODO(veeupup): handle joins match selection { - Some(predicate_expr) => { - let filter_expr = self.sql_to_expr(&predicate_expr)?; - let df = df.filter(filter_expr); - Ok(df) + Some(expr) => { + let mut fields = vec![]; + for plan in &plans { + fields.extend_from_slice(plan.schema().fields()); + } + let filter_expr = self.sql_to_expr(&expr)?; + + // look for expressions of the form ` = ` + let mut possible_join_keys = vec![]; + extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; + + let mut all_join_keys = HashSet::new(); + let mut left = plans[0].clone(); + for right in plans.iter().skip(1) { + let left_schema = left.schema(); + let right_schema = right.schema(); + let mut join_keys = vec![]; + for (l, r) in &possible_join_keys { + if left_schema + .field_with_unqualified_name(l.name.as_str()) + .is_ok() + && right_schema + .field_with_unqualified_name(r.name.as_str()) + .is_ok() + { + join_keys.push((l.clone(), r.clone())); + } else if left_schema + .field_with_unqualified_name(r.name.as_str()) + .is_ok() + && right_schema + .field_with_unqualified_name(l.name.as_str()) + .is_ok() + { + 
join_keys.push((r.clone(), l.clone())); + } + } + if !join_keys.is_empty() { + let left_keys: Vec = + join_keys.iter().map(|(l, _)| l.clone()).collect(); + let right_keys: Vec = + join_keys.iter().map(|(_, r)| r.clone()).collect(); + let df = DataFrame::new(left); + left = df + .join(right, JoinType::Inner, (left_keys, right_keys))? + .logical_plan(); + } else { + return Err(ErrorCode::NotImplemented); + } + + all_join_keys.extend(join_keys); + } + // remove join expressions from filter + match remove_join_expressions(&filter_expr, &all_join_keys)? { + Some(filter_expr) => Ok(DataFrame::new(left).filter(filter_expr)), + _ => Ok(DataFrame::new(left)), + } + } + None => { + if plans.len() == 1 { + Ok(DataFrame::new(plans[0].clone())) + } else { + // CROSS JOIN NOT SUPPORTED YET + Err(ErrorCode::NotImplemented) + } } - None => Ok(df), } } @@ -156,9 +302,23 @@ impl<'a> SQLPlanner<'a> { } Expr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), Expr::Value(Value::Null) => Ok(LogicalExpr::Literal(ScalarValue::Null)), - Expr::Identifier(id) => Ok(LogicalExpr::column(normalize_ident(id))), + Expr::Identifier(id) => Ok(LogicalExpr::column(None, normalize_ident(id))), // TODO(veeupup): cast func Expr::BinaryOp { left, op, right } => self.parse_sql_binary_op(left, op, right), + Expr::CompoundIdentifier(ids) => { + let mut var_names = ids.iter().map(|id| id.value.clone()).collect::>(); + + match (var_names.pop(), var_names.pop()) { + (Some(name), Some(table)) if var_names.is_empty() => { + // table.column identifier + Ok(LogicalExpr::Column(Column { + table: Some(table), + name, + })) + } + _ => Err(ErrorCode::NotImplemented), + } + } _ => todo!(), } } @@ -201,17 +361,119 @@ fn normalize_ident(id: &Ident) -> String { } } +/// Come from apache/arrow-datafusion +/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs +/// Filters matching this pattern are added to `accum` +/// Filters that don't match this pattern are added to `accum_filter` 
+/// Examples: +/// +/// foo = bar => accum=[(foo, bar)] accum_filter=[] +/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] +/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] +/// +fn extract_join_keys( + expr: &LogicalExpr, + accum: &mut Vec<(Column, Column)>, + accum_filter: &mut Vec, +) { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + accum.push((l.clone(), r.clone())); + } + _other => { + accum_filter.push(expr.clone()); + } + }, + Operator::And => { + extract_join_keys(left, accum, accum_filter); + extract_join_keys(right, accum, accum_filter); + } + _other + if matches!(**left, LogicalExpr::Column(_)) + || matches!(**right, LogicalExpr::Column(_)) => + { + accum_filter.push(expr.clone()); + } + _other => { + extract_join_keys(left, accum, accum_filter); + extract_join_keys(right, accum, accum_filter); + } + }, + _other => { + accum_filter.push(expr.clone()); + } + } +} + +/// Extract join keys from a WHERE clause +fn extract_possible_join_keys(expr: &LogicalExpr, accum: &mut Vec<(Column, Column)>) -> Result<()> { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + accum.push((l.clone(), r.clone())); + Ok(()) + } + _ => Ok(()), + }, + Operator::And => { + extract_possible_join_keys(left, accum)?; + extract_possible_join_keys(right, accum) + } + _ => Ok(()), + }, + _ => Ok(()), + } +} + +/// Remove join expressions from a filter expression +fn remove_join_expressions( + expr: &LogicalExpr, + join_columns: &HashSet<(Column, Column)>, +) -> Result> { + match expr { + LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + 
(LogicalExpr::Column(l), LogicalExpr::Column(r)) => { + if join_columns.contains(&(l.clone(), r.clone())) + || join_columns.contains(&(r.clone(), l.clone())) + { + Ok(None) + } else { + Ok(Some(expr.clone())) + } + } + _ => Ok(Some(expr.clone())), + }, + Operator::And => { + let l = remove_join_expressions(left, join_columns)?; + let r = remove_join_expressions(right, join_columns)?; + match (l, r) { + (Some(ll), Some(rr)) => Ok(Some(LogicalExpr::and(ll, rr))), + (Some(ll), _) => Ok(Some(ll)), + (_, Some(rr)) => Ok(Some(rr)), + _ => Ok(None), + } + } + _ => Ok(Some(expr.clone())), + }, + _ => Ok(Some(expr.clone())), + } +} + #[cfg(test)] mod tests { - use crate::db::NaiveDB; use crate::error::Result; + use crate::{db::NaiveDB, print_result}; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use std::sync::Arc; #[test] fn select_with_projection_filter() -> Result<()> { let mut db = NaiveDB::default(); - db.create_csv_table("t1", "test_data.csv")?; + db.create_csv_table("t1", "data/test_data.csv")?; { let ret = db.run_sql("select id, name from t1")?; @@ -246,6 +508,16 @@ mod tests { assert_eq!(batch.column(2), &age_excepted); } + { + db.create_csv_table("employee", "data/employee.csv"); + db.create_csv_table("rank", "data/rank.csv"); + + let ret = db + .run_sql("select id, name from employee innner join rank on employee.id = rank.id"); + + print_result(&ret?); + } + Ok(()) } }