From 52a9d93cf5eca14e63baaac435ff3205667edf77 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 25 Mar 2024 15:55:00 +0800 Subject: [PATCH] feat: support 'Expression' on `Default` --- Cargo.toml | 1 + src/binder/create_table.rs | 18 +++++++++++------- src/binder/insert.rs | 9 ++++----- src/binder/update.rs | 9 ++++----- src/catalog/column.rs | 15 +++++++++++---- src/db.rs | 12 +++++++----- src/errors.rs | 2 ++ src/execution/volcano/ddl/add_column.rs | 2 +- src/execution/volcano/dml/insert.rs | 12 ++++++++---- src/execution/volcano/dql/describe.rs | 10 +++++++--- src/expression/function.rs | 6 ++++-- src/expression/mod.rs | 4 ++-- src/marcos/mod.rs | 9 ++++++--- src/storage/kip.rs | 2 +- src/types/tuple.rs | 10 ++++++++++ tests/slt/create.slt | 5 ++++- 16 files changed, 83 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3d089da5..6d8125f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ strum_macros = { version = "0.26.2" } thiserror = { version = "1.0.58" } tokio = { version = "1.36.0", features = ["full"] } tracing = { version = "0.1.40" } +typetag = { version = "0.2" } [dev-dependencies] cargo-tarpaulin = { version = "0.27.1" } diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 5fabf9ec..70717fcc 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -12,7 +12,6 @@ use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::value::DataValue; use crate::types::LogicalType; impl<'a, T: Transaction> Binder<'a, T> { @@ -116,15 +115,20 @@ impl<'a, T: Transaction> Binder<'a, T> { } } ColumnOption::Default(expr) => { - if let ScalarExpression::Constant(value) = self.bind_expr(expr)? { - let cast_value = - DataValue::clone(&value).cast(&column_desc.column_datatype)?; - column_desc.default = Some(Arc::new(cast_value)); - } else { + let mut expr = self.bind_expr(expr)?; + + if !expr.referenced_columns(true).is_empty() { return Err(DatabaseError::UnsupportedStmt( - "'DEFAULT' only with constant now".to_string(), + "column is not allowed to exist in `default`".to_string(), )); } + if expr.return_type() != column_desc.column_datatype { + expr = ScalarExpression::TypeCast { + expr: Box::new(expr), + ty: column_desc.column_datatype, + } + } + column_desc.default = Some(expr); } _ => todo!(), } diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 65f5ca79..699a2aa5 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -78,11 +78,10 @@ impl<'a, T: Transaction> Binder<'a, T> { row.push(value); } ScalarExpression::Empty => { - row.push(schema_ref[i].default_value().ok_or_else(|| { - DatabaseError::InvalidColumn( - "column does not exist default".to_string(), - ) - })?); + let default_value = schema_ref[i] + .default_value()? + .ok_or(DatabaseError::DefaultNotExist)?; + row.push(default_value); } _ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())), } diff --git a/src/binder/update.rs b/src/binder/update.rs index 7a0b2fe6..33253d1a 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -53,11 +53,10 @@ impl<'a, T: Transaction> Binder<'a, T> { row.push(value.clone()); } ScalarExpression::Empty => { - row.push(column.default_value().ok_or_else(|| { - DatabaseError::InvalidColumn( - "column does not exist default".to_string(), - ) - })?); + let default_value = column + .default_value()? + .ok_or(DatabaseError::DefaultNotExist)?; + row.push(default_value); } _ => return Err(DatabaseError::UnsupportedStmt(value.to_string())), } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 8533180d..32d7653f 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,8 +1,11 @@ use crate::catalog::TableName; +use crate::errors::DatabaseError; +use crate::expression::ScalarExpression; use serde::{Deserialize, Serialize}; use std::hash::Hash; use std::sync::Arc; +use crate::types::tuple::EMPTY_TUPLE; use crate::types::value::ValueRef; use crate::types::{ColumnId, LogicalType}; @@ -82,8 +85,12 @@ impl ColumnCatalog { &self.desc.column_datatype } - pub(crate) fn default_value(&self) -> Option { - self.desc.default.clone() + pub(crate) fn default_value(&self) -> Result, DatabaseError> { + self.desc + .default + .as_ref() + .map(|expr| expr.eval(&EMPTY_TUPLE, &[])) + .transpose() } #[allow(dead_code)] @@ -98,7 +105,7 @@ pub struct ColumnDesc { pub(crate) column_datatype: LogicalType, pub(crate) is_primary: bool, pub(crate) is_unique: bool, - pub(crate) default: Option, + pub(crate) default: Option, } impl ColumnDesc { @@ -106,7 +113,7 @@ impl ColumnDesc { column_datatype: LogicalType, is_primary: bool, is_unique: bool, - default: Option, + default: Option, ) -> ColumnDesc { ColumnDesc { column_datatype, diff --git a/src/db.rs b/src/db.rs index 98a9f8b0..b76e54fb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -263,6 +263,8 @@ mod test { use crate::types::tuple::{create_table, Tuple}; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; + use serde::Deserialize; + use serde::Serialize; use std::sync::Arc; use tempfile::TempDir; @@ -298,10 +300,10 @@ mod test { Ok(()) } - function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| { + function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { let value = DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus)?; DataValue::unary_op(&value, &UnaryOperator::Minus) - }); + })); #[tokio::test] async fn test_udf() -> Result<(), DatabaseError> { @@ -311,12 +313,12 @@ mod test { .build() .await?; let _ = fnck_sql - .run("CREATE TABLE test (id int primary key, c1 int, c2 int);") + .run("CREATE TABLE test (id int primary key, c1 int, c2 int default test(1, 2));") .await?; let _ = fnck_sql - .run("INSERT INTO test VALUES (1, 2, 2), (0, 1, 1), (2, 1, 1), (3, 3, 3);") + .run("INSERT INTO test VALUES (1, 2, 2), (0, 1, 1), (2, 1, 1), (3, 3, default);") .await?; - let (schema, tuples) = fnck_sql.run("select test(c1, 1) from test").await?; + let (schema, tuples) = fnck_sql.run("select test(c1, 1), c2 from test").await?; println!("{}", create_table(&schema, &tuples)); Ok(()) diff --git a/src/errors.rs b/src/errors.rs index 51cb4964..fc06f38f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -29,6 +29,8 @@ pub enum DatabaseError { #[source] csv::Error, ), + #[error("default does not exist")] + DefaultNotExist, #[error("column: {0} already exists")] DuplicateColumn(String), #[error("index: {0} already exists")] diff --git a/src/execution/volcano/ddl/add_column.rs b/src/execution/volcano/ddl/add_column.rs index 01d7c22a..3c228035 100644 --- a/src/execution/volcano/ddl/add_column.rs +++ b/src/execution/volcano/ddl/add_column.rs @@ -50,7 +50,7 @@ impl AddColumn { for tuple in build_read(self.input, transaction) { let mut tuple: Tuple = tuple?; - if let Some(value) = column.default_value() { + if let Some(value) = column.default_value()? { if let Some(unique_values) = &mut unique_values { unique_values.push((tuple.id.clone().unwrap(), value.clone())); } diff --git a/src/execution/volcano/dml/insert.rs b/src/execution/volcano/dml/insert.rs index 70145974..31cdee23 100644 --- a/src/execution/volcano/dml/insert.rs +++ b/src/execution/volcano/dml/insert.rs @@ -76,10 +76,14 @@ impl Insert { let mut values = Vec::with_capacity(table_catalog.columns_len()); for col in table_catalog.columns() { - let value = tuple_map - .remove(&col.id()) - .or_else(|| col.default_value()) - .unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))); + let value = { + let mut value = tuple_map.remove(&col.id()); + + if value.is_none() { + value = col.default_value()?; + } + value.unwrap_or_else(|| Arc::new(DataValue::none(col.datatype()))) + }; if value.is_null() && !col.nullable { return Err(DatabaseError::NotNull); } diff --git a/src/execution/volcano/dql/describe.rs b/src/execution/volcano/dql/describe.rs index c24c60b2..247e449c 100644 --- a/src/execution/volcano/dql/describe.rs +++ b/src/execution/volcano/dql/describe.rs @@ -52,6 +52,12 @@ impl Describe { for column in table.columns() { let datatype = column.datatype(); + let default = column + .desc + .default + .as_ref() + .map(|expr| format!("{}", expr)) + .unwrap_or_else(|| "null".to_string()); let values = vec![ Arc::new(DataValue::Utf8(Some(column.name().to_string()))), Arc::new(DataValue::Utf8(Some(datatype.to_string()))), @@ -63,9 +69,7 @@ impl Describe { ))), Arc::new(DataValue::Utf8(Some(column.nullable.to_string()))), key_fn(column), - column - .default_value() - .unwrap_or_else(|| Arc::new(DataValue::none(datatype))), + Arc::new(DataValue::Utf8(Some(default))), ]; yield Tuple { id: None, values }; } diff --git a/src/expression/function.rs b/src/expression/function.rs index 1922583d..3236d132 100644 --- a/src/expression/function.rs +++ b/src/expression/function.rs @@ -4,6 +4,7 @@ use crate::expression::ScalarExpression; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -14,7 +15,7 @@ use std::sync::Arc; /// - `Some(false)` monotonically decreasing pub type FuncMonotonicity = Vec>; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ScalarFunction { pub(crate) args: Vec, pub(crate) inner: Arc, @@ -34,12 +35,13 @@ impl Hash for ScalarFunction { } } -#[derive(Debug, Eq, PartialEq, Hash, Clone)] +#[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize, Deserialize)] pub struct FunctionSummary { pub(crate) name: String, pub(crate) arg_types: Vec, } +#[typetag::serde(tag = "type")] pub trait ScalarFunctionImpl: Debug + Send + Sync { fn eval( &self, diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 78672a4c..ab97c43c 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -20,7 +20,7 @@ pub mod range_detacher; pub mod simplify; pub mod value_compute; -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum AliasType { Name(String), Expr(Box), @@ -30,7 +30,7 @@ pub enum AliasType { /// SELECT a+1, b FROM t1. /// a+1 -> ScalarExpression::Unary(a + 1) /// b -> ScalarExpression::ColumnRef() -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum ScalarExpression { Constant(ValueRef), ColumnRef(ColumnRef), diff --git a/src/marcos/mod.rs b/src/marcos/mod.rs index 9b48cee3..98379465 100644 --- a/src/marcos/mod.rs +++ b/src/marcos/mod.rs @@ -68,7 +68,7 @@ macro_rules! implement_from_tuple { #[macro_export] macro_rules! function { ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> $return_ty:expr => $closure:expr) => { - #[derive(Debug)] + #[derive(Debug, Serialize, Deserialize)] pub(crate) struct $struct_name { summary: FunctionSummary } @@ -91,6 +91,7 @@ macro_rules! function { } } + #[typetag::serde] impl ScalarFunctionImpl for $struct_name { fn eval(&self, args: &[ScalarExpression], tuple: &Tuple, schema: &[ColumnRef]) -> Result { let mut _index = 0; @@ -132,6 +133,8 @@ mod test { use crate::types::tuple::{SchemaRef, Tuple}; use crate::types::value::{DataValue, ValueRef}; use crate::types::LogicalType; + use serde::Deserialize; + use serde::Serialize; use std::sync::Arc; fn build_tuple() -> (Tuple, SchemaRef) { @@ -187,9 +190,9 @@ mod test { assert_eq!(my_struct.c2, "LOL"); } - function!(MyFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| { + function!(MyFunction::sum(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => (|v1: ValueRef, v2: ValueRef| { DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus) - }); + })); #[test] fn test_function() -> Result<(), DatabaseError> { diff --git a/src/storage/kip.rs b/src/storage/kip.rs index d71ee85e..50d27239 100644 --- a/src/storage/kip.rs +++ b/src/storage/kip.rs @@ -232,7 +232,7 @@ impl Transaction for KipTransaction { if_not_exists: bool, ) -> Result { if let Some(mut table) = self.table(table_name.clone()).cloned() { - if !column.nullable && column.default_value().is_none() { + if !column.nullable && column.default_value()?.is_none() { return Err(DatabaseError::NeedNullAbleOrDefault); } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index ff484a22..0cced1ae 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -4,8 +4,18 @@ use crate::types::LogicalType; use comfy_table::{Cell, Table}; use integer_encoding::FixedInt; use itertools::Itertools; +use lazy_static::lazy_static; use std::sync::Arc; +lazy_static! { + pub static ref EMPTY_TUPLE: Tuple = { + Tuple { + id: None, + values: vec![], + } + }; +} + const BITS_MAX_INDEX: usize = 8; pub type TupleId = ValueRef; diff --git a/tests/slt/create.slt b/tests/slt/create.slt index bf017236..046c3031 100644 --- a/tests/slt/create.slt +++ b/tests/slt/create.slt @@ -20,4 +20,7 @@ statement ok create table if not exists t(id int primary key, v1 int, v2 int, v3 int) statement ok -create table if not exists t(id int primary key, v1 int, v2 int, v3 int) \ No newline at end of file +create table if not exists t(id int primary key, v1 int, v2 int, v3 int) + +statement error +create table test_default_expr(id int primary key, v1 int, v2 int, v3 int default (v1 + 1)) \ No newline at end of file