From fe8832261f5365cdc137ce084688deede4b9989c Mon Sep 17 00:00:00 2001
From: Kould
Date: Thu, 7 Nov 2024 01:22:17 +0800
Subject: [PATCH] feat: impl multiple primary keys

---
 src/binder/create_table.rs                  |  14 +--
 src/binder/delete.rs                        |   2 +-
 src/binder/expr.rs                          |   2 +-
 src/catalog/column.rs                       |  20 +++-
 src/catalog/table.rs                        |  38 +++++--
 src/errors.rs                               |   2 +
 src/execution/ddl/add_column.rs             |   6 +-
 src/execution/ddl/drop_column.rs            |   4 +-
 src/execution/dml/insert.rs                 |  28 +++--
 src/execution/dml/update.rs                 |  24 +++-
 src/execution/dql/aggregate/min_max.rs      |   2 +-
 src/execution/dql/aggregate/sum.rs          |   2 +-
 src/execution/dql/describe.rs               |   4 +-
 src/expression/evaluator.rs                 |   3 +-
 src/expression/mod.rs                       |  18 ++-
 src/expression/simplify.rs                  |  24 ++--
 src/optimizer/core/histogram.rs             |   2 +-
 .../rule/implementation/dql/table_scan.rs   |   6 +-
 .../rule/normalization/simplification.rs    |  14 +--
 src/planner/operator/table_scan.rs          |  18 ++-
 src/serdes/column.rs                        |  31 +++--
 src/serdes/data_value.rs                    | 106 +++++++++++++-----
 src/storage/mod.rs                          |  21 ++--
 src/storage/table_codec.rs                  |  45 ++++++--
 src/types/evaluator/mod.rs                  |   2 +-
 src/types/mod.rs                            |  36 +++---
 src/types/tuple.rs                          |  32 ++++--
 src/types/tuple_builder.rs                  |  18 +--
 src/types/value.rs                          |  39 +++++--
 tests/slt/create.slt                        |   3 +
 tests/slt/delete.slt                        |   2 +-
 tests/slt/delete_multiple_primary_keys.slt  |  28 +++++
 tests/slt/describe.slt                      |  15 ++-
 tests/slt/insert.slt                        |   2 -
 tests/slt/insert_multiple_primary_keys.slt  | 101 +++++++++++++++++
 tests/slt/sql_2016/E141_03.slt              |  10 +-
 tests/slt/sql_2016/E141_08.slt              |  10 +-
 tests/slt/update.slt                        |   2 +-
 tests/slt/update_multiple_primary_keys.slt  |  47 ++++++++
 39 files changed, 569 insertions(+), 214 deletions(-)
 create mode 100644 tests/slt/delete_multiple_primary_keys.slt
 create mode 100644 tests/slt/insert_multiple_primary_keys.slt
 create mode 100644 tests/slt/update_multiple_primary_keys.slt

diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs
index 2ba98d13..29af3d5a 100644
--- a/src/binder/create_table.rs
+++ b/src/binder/create_table.rs
@@ -62,9 +62,9 @@ impl Binder<'_, '_, T> {
                 .find(|column| column.name() == column_name)
             {
                 if *is_primary {
-                    column.desc_mut().is_primary = true;
+                    column.desc_mut().set_primary(true);
                 } else {
-                    column.desc_mut().is_unique = true;
+                    column.desc_mut().set_unique(true);
                 }
             }
         }
@@ -73,9 +73,9 @@ impl Binder<'_, '_, T> {
             }
         }

-        if columns.iter().filter(|col| col.desc().is_primary).count() != 1 {
+        if columns.iter().filter(|col| col.desc().is_primary()).count() == 0 {
             return Err(DatabaseError::InvalidTable(
-                "The primary key field must exist and have at least one".to_string(),
+                "the table must contain at least one primary key column".to_string(),
             ));
         }
@@ -106,12 +106,12 @@ impl Binder<'_, '_, T> {
                     ColumnOption::NotNull => nullable = false,
                     ColumnOption::Unique { is_primary, ..
} => { if *is_primary { - column_desc.is_primary = true; + column_desc.set_primary(true); nullable = false; // Skip other options when using primary key break; } else { - column_desc.is_unique = true; + column_desc.set_unique(true); } } ColumnOption::Default(expr) => { @@ -125,7 +125,7 @@ impl Binder<'_, '_, T> { if expr.return_type() != column_desc.column_datatype { expr = ScalarExpression::TypeCast { expr: Box::new(expr), - ty: column_desc.column_datatype, + ty: column_desc.column_datatype.clone(), } } column_desc.default = Some(expr); diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 192e21e6..4e1e1e94 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -30,7 +30,7 @@ impl Binder<'_, '_, T> { let schema_buf = self.table_schema_buf.entry(table_name.clone()).or_default(); let primary_key_column = source .columns(schema_buf) - .find(|column| column.desc().is_primary) + .find(|column| column.desc().is_primary()) .cloned() .unwrap(); let mut plan = match source { diff --git a/src/binder/expr.rs b/src/binder/expr.rs index e8f67f2e..590ca25d 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -186,7 +186,7 @@ impl<'a, T: Transaction> Binder<'a, '_, T> { if ty == &LogicalType::SqlNull { *ty = result_ty; } else if ty != &result_ty { - return Err(DatabaseError::Incomparable(*ty, result_ty)); + return Err(DatabaseError::Incomparable(ty.clone(), result_ty)); } } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index e7c537ad..599ced6d 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -187,8 +187,8 @@ impl ColumnCatalog { #[derive(Debug, Clone, PartialEq, Eq, Hash, ReferenceSerialization)] pub struct ColumnDesc { pub(crate) column_datatype: LogicalType, - pub(crate) is_primary: bool, - pub(crate) is_unique: bool, + is_primary: bool, + is_unique: bool, pub(crate) default: Option, } @@ -212,4 +212,20 @@ impl ColumnDesc { default, }) } + + pub(crate) fn is_primary(&self) -> bool { + self.is_primary + } + + pub(crate) fn set_primary(&mut self, is_primary: bool) { + self.is_primary = is_primary + } + + pub(crate) fn is_unique(&self) -> bool { + self.is_unique + } + + pub(crate) fn set_unique(&mut self, is_unique: bool) { + self.is_unique = is_unique + } } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 49b695fd..5966fc51 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -21,6 +21,7 @@ pub struct TableCatalog { pub(crate) indexes: Vec, schema_ref: SchemaRef, + primary_keys: Vec<(usize, ColumnRef)>, } //TODO: can add some like Table description and other information as attributes @@ -73,17 +74,13 @@ impl TableCatalog { self.columns.len() } - pub(crate) fn primary_key(&self) -> Result<(usize, &ColumnRef), DatabaseError> { - self.schema_ref - .iter() - .enumerate() - .find(|(_, column)| column.desc().is_primary) - .ok_or(DatabaseError::PrimaryKeyNotFound) + pub(crate) fn primary_keys(&self) -> &[(usize, ColumnRef)] { + &self.primary_keys } pub(crate) fn types(&self) -> Vec { self.columns() - .map(|column| *column.datatype()) + .map(|column| column.datatype().clone()) .collect_vec() } @@ -128,7 +125,17 @@ impl TableCatalog { } let index_id = self.indexes.last().map(|index| index.id + 1).unwrap_or(0); - let pk_ty = *self.primary_key()?.1.datatype(); + let primary_keys = self.primary_keys(); + let pk_ty = if primary_keys.len() == 1 { + primary_keys[0].1.datatype().clone() + } else { + LogicalType::Tuple( + primary_keys + .iter() + .map(|(_, column)| column.datatype().clone()) + .collect_vec(), + ) + }; let index = IndexMeta 
{ id: index_id,
            column_ids,
@@ -154,6 +161,7 @@ impl TableCatalog {
             columns: BTreeMap::new(),
             indexes: vec![],
             schema_ref: Arc::new(vec![]),
+            primary_keys: vec![],
         };
         let mut generator = Generator::new();
         for col_catalog in columns.into_iter() {
                 .add_column(col_catalog, &mut generator)
                 .unwrap();
         }
+        table_catalog.primary_keys = table_catalog
+            .schema_ref
+            .iter()
+            .enumerate()
+            .filter(|&(_, column)| column.desc().is_primary())
+            .map(|(i, column)| (i, column.clone()))
+            .collect_vec();

         Ok(table_catalog)
     }
@@ -182,6 +197,12 @@ impl TableCatalog {
             columns.insert(column_id, i);
         }
         let schema_ref = Arc::new(column_refs.clone());
+        let primary_keys = schema_ref
+            .iter()
+            .enumerate()
+            .filter(|&(_, column)| column.desc().is_primary())
+            .map(|(i, column)| (i, column.clone()))
+            .collect_vec();

         Ok(TableCatalog {
             name,
@@ -189,6 +210,7 @@ impl TableCatalog {
             columns,
             indexes,
             schema_ref,
+            primary_keys,
         })
     }
 }
diff --git a/src/errors.rs b/src/errors.rs
index 1a784a68..27993afa 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -112,6 +112,8 @@ pub enum DatabaseError {
     ),
     #[error("must contain primary key!")]
     PrimaryKeyNotFound,
+    #[error("primary key must be a single value or a flat tuple of values; nested tuples are not allowed")]
+    PrimaryKeyTooManyLayers,
     #[error("rocksdb: {0}")]
     RocksDB(
         #[source]
diff --git a/src/execution/ddl/add_column.rs b/src/execution/ddl/add_column.rs
index 09bd0398..00933ad2 100644
--- a/src/execution/ddl/add_column.rs
+++ b/src/execution/ddl/add_column.rs
@@ -40,15 +40,15 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for AddColumn {
             if_not_exists,
         } = &self.op;

-        let mut unique_values = column.desc().is_unique.then(Vec::new);
+        let mut unique_values = column.desc().is_unique().then(Vec::new);
         let mut tuples = Vec::new();
         let schema = self.input.output_schema();
         let mut types = Vec::with_capacity(schema.len() + 1);

         for column_ref in schema.iter() {
-            types.push(*column_ref.datatype());
+            types.push(column_ref.datatype().clone());
         }
-        types.push(*column.datatype());
+        types.push(column.datatype().clone());

         let mut coroutine = build_read(self.input, cache, transaction);
diff --git a/src/execution/ddl/drop_column.rs b/src/execution/ddl/drop_column.rs
index 5aee0f16..d7f3ff0d 100644
--- a/src/execution/ddl/drop_column.rs
+++ b/src/execution/ddl/drop_column.rs
@@ -41,7 +41,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn {
             .iter()
             .enumerate()
             .find(|(_, column)| column.name() == column_name)
-            .map(|(i, column)| (i, column.desc().is_primary))
+            .map(|(i, column)| (i, column.desc().is_primary()))
         {
             if is_primary {
                 throw!(Err(DatabaseError::InvalidColumn(
@@ -55,7 +55,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn {
                 if i == column_index {
                     continue;
                 }
-                types.push(*column_ref.datatype());
+                types.push(column_ref.datatype().clone());
             }

             let mut coroutine = build_read(self.input, cache, transaction);
diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs
index fe81b67b..1b2e673d 100644
--- a/src/execution/dml/insert.rs
+++ b/src/execution/dml/insert.rs
@@ -11,6 +11,7 @@ use crate::types::tuple::Tuple;
 use crate::types::tuple_builder::TupleBuilder;
 use crate::types::value::DataValue;
 use crate::types::ColumnId;
+use itertools::Itertools;
 use std::collections::HashMap;
 use std::ops::Coroutine;
 use std::ops::CoroutineState;
 use std::pin::Pin;
@@ -79,11 +80,14 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert {
         let mut tuples = Vec::new();
         let schema = input.output_schema().clone();

-        let pk_key = throw!(schema
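+        // multiple primary keys may exist now; collect the key of every
+        // primary-key column so each row's tuple id can be assembled below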
+ let primary_keys = schema .iter() - .find(|col| col.desc().is_primary) + .filter(|&col| col.desc().is_primary()) .map(|col| col.key(is_mapping_by_name)) - .ok_or(DatabaseError::NotNull)); + .collect_vec(); + if primary_keys.is_empty() { + throw!(Err(DatabaseError::NotNull)) + } if let Some(table_catalog) = throw!(transaction.table(cache.0, table_name.clone())).cloned() @@ -94,14 +98,18 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { while let CoroutineState::Yielded(tuple) = Pin::new(&mut coroutine).resume(()) { let Tuple { values, .. } = throw!(tuple); + let mut tuple_id = Vec::with_capacity(primary_keys.len()); let mut tuple_map = HashMap::new(); for (i, value) in values.into_iter().enumerate() { tuple_map.insert(schema[i].key(is_mapping_by_name), value); } - let tuple_id = throw!(tuple_map - .get(&pk_key) - .cloned() - .ok_or(DatabaseError::NotNull)); + + for primary_key in primary_keys.iter() { + tuple_id.push(throw!(tuple_map + .get(primary_key) + .cloned() + .ok_or(DatabaseError::NotNull))); + } let mut values = Vec::with_capacity(table_catalog.columns_len()); for col in table_catalog.columns() { @@ -120,7 +128,11 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { values.push(value) } tuples.push(Tuple { - id: Some(tuple_id), + id: Some(if primary_keys.len() == 1 { + tuple_id.pop().unwrap() + } else { + Arc::new(DataValue::Tuple(Some(tuple_id))) + }), values, }); } diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index 93d2f966..3f09fee3 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -9,10 +9,12 @@ use crate::types::index::Index; use crate::types::tuple::types; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; +use crate::types::value::DataValue; use std::collections::HashMap; use std::ops::Coroutine; use std::ops::CoroutineState; use std::pin::Pin; +use std::sync::Arc; pub struct Update { table_name: TableName, @@ -93,18 +95,28 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { } for mut tuple in tuples { let mut is_overwrite = true; - + let mut primary_keys = Vec::new(); for (i, column) in input_schema.iter().enumerate() { if let Some(value) = value_map.get(&column.id()) { - if column.desc().is_primary { - let old_key = tuple.id.replace(value.clone()).unwrap(); - - throw!(transaction.remove_tuple(&table_name, &old_key)); - is_overwrite = false; + if column.desc().is_primary() { + primary_keys.push(value.clone()); } tuple.values[i] = value.clone(); } } + if !primary_keys.is_empty() { + let id = if primary_keys.len() == 1 { + primary_keys.pop().unwrap() + } else { + Arc::new(DataValue::Tuple(Some(primary_keys))) + }; + if &id != tuple.id.as_ref().unwrap() { + let old_key = tuple.id.replace(id).unwrap(); + + throw!(transaction.remove_tuple(&table_name, &old_key)); + is_overwrite = false; + } + } for (index_meta, exprs) in index_metas.iter() { let values = throw!(Projection::projection(&tuple, exprs, &input_schema)); diff --git a/src/execution/dql/aggregate/min_max.rs b/src/execution/dql/aggregate/min_max.rs index 39f64278..72b96e92 100644 --- a/src/execution/dql/aggregate/min_max.rs +++ b/src/execution/dql/aggregate/min_max.rs @@ -23,7 +23,7 @@ impl MinMaxAccumulator { Self { inner: None, op, - ty: *ty, + ty: ty.clone(), } } } diff --git a/src/execution/dql/aggregate/sum.rs b/src/execution/dql/aggregate/sum.rs index 3537ccf4..e3c6a304 100644 --- a/src/execution/dql/aggregate/sum.rs +++ b/src/execution/dql/aggregate/sum.rs @@ -19,7 +19,7 @@ impl 
SumAccumulator { Ok(Self { result: DataValue::none(ty), - evaluator: EvaluatorFactory::binary_create(*ty, BinaryOperator::Plus)?, + evaluator: EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Plus)?, }) } } diff --git a/src/execution/dql/describe.rs b/src/execution/dql/describe.rs index 1b8a0a17..a3840733 100644 --- a/src/execution/dql/describe.rs +++ b/src/execution/dql/describe.rs @@ -52,9 +52,9 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Describe { let table = throw!(throw!(transaction.table(cache.0, self.table_name.clone())) .ok_or(DatabaseError::TableNotFound)); let key_fn = |column: &ColumnCatalog| { - if column.desc().is_primary { + if column.desc().is_primary() { PRIMARY_KEY_TYPE.clone() - } else if column.desc().is_unique { + } else if column.desc().is_unique() { UNIQUE_KEY_TYPE.clone() } else { EMPTY_KEY_TYPE.clone() diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index d09d69cb..dff88a90 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -351,7 +351,8 @@ impl ScalarExpression { let mut when_value = when_expr.eval(tuple, schema)?; let is_true = if let Some(operand_value) = &operand_value { let ty = operand_value.logical_type(); - let evaluator = EvaluatorFactory::binary_create(ty, BinaryOperator::Eq)?; + let evaluator = + EvaluatorFactory::binary_create(ty.clone(), BinaryOperator::Eq)?; if when_value.logical_type() != ty { when_value = Arc::new(DataValue::clone(&when_value).cast(&ty)?); diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 1a6e051e..07fcbf4c 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -322,8 +322,8 @@ impl ScalarExpression { } } }; - fn_cast(left_expr, ty); - fn_cast(right_expr, ty); + fn_cast(left_expr, ty.clone()); + fn_cast(right_expr, ty.clone()); *evaluator = Some(EvaluatorFactory::binary_create(ty, *op)?); } @@ -567,7 +567,7 @@ impl ScalarExpression { pub fn return_type(&self) -> LogicalType { match self { ScalarExpression::Constant(v) => v.logical_type(), - ScalarExpression::ColumnRef(col) => *col.datatype(), + ScalarExpression::ColumnRef(col) => col.datatype().clone(), ScalarExpression::Binary { ty: return_type, .. } @@ -594,7 +594,7 @@ impl ScalarExpression { } | ScalarExpression::CaseWhen { ty: return_type, .. - } => *return_type, + } => return_type.clone(), ScalarExpression::IsNull { .. } | ScalarExpression::In { .. } | ScalarExpression::Between { .. } => LogicalType::Boolean, @@ -609,8 +609,14 @@ impl ScalarExpression { expr.return_type() } ScalarExpression::Empty | ScalarExpression::TableFunction(_) => unreachable!(), - ScalarExpression::Tuple(_) => LogicalType::Tuple, - ScalarExpression::ScalaFunction(ScalarFunction { inner, .. }) => *inner.return_type(), + ScalarExpression::Tuple(exprs) => { + let types = exprs.iter().map(|expr| expr.return_type()).collect_vec(); + + LogicalType::Tuple(types) + } + ScalarExpression::ScalaFunction(ScalarFunction { inner, .. }) => { + inner.return_type().clone() + } } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 08786c15..50c5228b 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -173,7 +173,7 @@ impl ScalarExpression { let unary_value = if let Some(evaluator) = evaluator { evaluator.0.unary_eval(&value) } else { - EvaluatorFactory::unary_create(*ty, *op) + EvaluatorFactory::unary_create(ty.clone(), *op) .ok()? 
.0 .unary_eval(&value) @@ -190,16 +190,16 @@ impl ScalarExpression { } => { let mut left = left_expr.unpack_val()?; let mut right = right_expr.unpack_val()?; - if left.logical_type() != *ty { + if &left.logical_type() != ty { left = Arc::new(DataValue::clone(&left).cast(ty).ok()?); } - if right.logical_type() != *ty { + if &right.logical_type() != ty { right = Arc::new(DataValue::clone(&right).cast(ty).ok()?); } let binary_value = if let Some(evaluator) = evaluator { evaluator.0.binary_eval(&left, &right) } else { - EvaluatorFactory::binary_create(*ty, *op) + EvaluatorFactory::binary_create(ty.clone(), *op) .ok()? .0 .binary_eval(&left, &right) @@ -251,7 +251,7 @@ impl ScalarExpression { let value = if let Some(evaluator) = evaluator { evaluator.0.unary_eval(unary_val) } else { - EvaluatorFactory::unary_create(*ty, *op)? + EvaluatorFactory::unary_create(ty.clone(), *op)? .0 .unary_eval(unary_val) }; @@ -276,7 +276,7 @@ impl ScalarExpression { ScalarExpression::Constant(right_val), ) = (left_expr.as_mut(), right_expr.as_mut()) { - let evaluator = EvaluatorFactory::binary_create(ty, *op)?; + let evaluator = EvaluatorFactory::binary_create(ty.clone(), *op)?; if left_val.logical_type() != ty { *left_val = Arc::new(DataValue::clone(left_val).cast(&ty)?); @@ -421,7 +421,7 @@ impl ScalarExpression { column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, - ty: *ty, + ty: ty.clone(), is_column_left: true, })); } @@ -430,7 +430,7 @@ impl ScalarExpression { column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, - ty: *ty, + ty: ty.clone(), is_column_left: false, })); } @@ -445,7 +445,7 @@ impl ScalarExpression { column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(right_expr, ScalarExpression::Empty), op: *op, - ty: *ty, + ty: ty.clone(), is_column_left: true, })); } @@ -454,7 +454,7 @@ impl ScalarExpression { column_expr: ScalarExpression::ColumnRef(col), val_expr: mem::replace(left_expr, ScalarExpression::Empty), op: *op, - ty: *ty, + ty: ty.clone(), is_column_left: false, })); } @@ -492,7 +492,7 @@ impl ScalarExpression { let value = if let Some(evaluator) = evaluator { evaluator.0.unary_eval(&value) } else { - EvaluatorFactory::unary_create(*ty, *op)? + EvaluatorFactory::unary_create(ty.clone(), *op)? 
.0 .unary_eval(&value) }; @@ -502,7 +502,7 @@ impl ScalarExpression { replaces.push(Replace::Unary(ReplaceUnary { child_expr: expr.as_ref().clone(), op: *op, - ty: *ty, + ty: ty.clone(), })); } } diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index 5d2a08c9..480fe4f4 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -317,7 +317,7 @@ impl Histogram { | LogicalType::Float | LogicalType::Double | LogicalType::Decimal(_, _) => value.clone().cast(&LogicalType::Double)?.double(), - LogicalType::Tuple => match value { + LogicalType::Tuple(_) => match value { DataValue::Tuple(Some(values)) => { let mut float = 0.0; diff --git a/src/optimizer/rule/implementation/dql/table_scan.rs b/src/optimizer/rule/implementation/dql/table_scan.rs index a90dcbfd..92802171 100644 --- a/src/optimizer/rule/implementation/dql/table_scan.rs +++ b/src/optimizer/rule/implementation/dql/table_scan.rs @@ -37,11 +37,7 @@ impl ImplementationRule for SeqScanImplementation { let cost = scan_op .index_infos .iter() - .find(|index_info| { - let column_ids = &index_info.meta.column_ids; - - column_ids.len() == 1 && column_ids[0] == scan_op.primary_key - }) + .find(|index_info| index_info.meta.column_ids == scan_op.primary_keys) .map(|index_info| loader.load(&scan_op.table_name, index_info.meta.id)) .transpose()? .flatten() diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index d55c315f..d48f6b5a 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -260,12 +260,7 @@ mod test { }, }, false, - ColumnDesc { - column_datatype: LogicalType::Integer, - is_primary: true, - is_unique: false, - default: None, - }, + ColumnDesc::new(LogicalType::Integer, true, false, None)?, false, ); let c2_col = ColumnCatalog::direct_new( @@ -278,12 +273,7 @@ mod test { }, }, false, - ColumnDesc { - column_datatype: LogicalType::Integer, - is_primary: false, - is_unique: true, - default: None, - }, + ColumnDesc::new(LogicalType::Integer, false, true, None)?, false, ); diff --git a/src/planner/operator/table_scan.rs b/src/planner/operator/table_scan.rs index 3f50f972..af2fa59c 100644 --- a/src/planner/operator/table_scan.rs +++ b/src/planner/operator/table_scan.rs @@ -12,7 +12,7 @@ use std::fmt::Formatter; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct TableScanOperator { pub(crate) table_name: TableName, - pub(crate) primary_key: ColumnId, + pub(crate) primary_keys: Vec, pub(crate) columns: Vec<(usize, ColumnRef)>, // Support push down limit. 
pub(crate) limit: Bounds, @@ -24,18 +24,16 @@ pub struct TableScanOperator { impl TableScanOperator { pub fn build(table_name: TableName, table_catalog: &TableCatalog) -> LogicalPlan { - let mut primary_key_option = None; + let primary_keys = table_catalog + .primary_keys() + .iter() + .filter_map(|(_, column)| column.id()) + .collect_vec(); // Fill all Columns in TableCatalog by default let columns = table_catalog .columns() .enumerate() - .map(|(i, column)| { - if column.desc().is_primary { - primary_key_option = column.id(); - } - - (i, column.clone()) - }) + .map(|(i, column)| (i, column.clone())) .collect_vec(); let index_infos = table_catalog .indexes @@ -50,7 +48,7 @@ impl TableScanOperator { Operator::TableScan(TableScanOperator { index_infos, table_name, - primary_key: primary_key_option.unwrap(), + primary_keys, columns, limit: (None, None), }), diff --git a/src/serdes/column.rs b/src/serdes/column.rs index 7421503a..14f15e0d 100644 --- a/src/serdes/column.rs +++ b/src/serdes/column.rs @@ -190,12 +190,7 @@ pub(crate) mod test { }, }, false, - ColumnDesc { - column_datatype: LogicalType::Integer, - is_primary: false, - is_unique: false, - default: None, - }, + ColumnDesc::new(LogicalType::Integer, false, false, None)?, false, ))); @@ -228,14 +223,14 @@ pub(crate) mod test { relation: ColumnRelation::None, }, false, - ColumnDesc { - column_datatype: LogicalType::Integer, - is_primary: false, - is_unique: false, - default: Some(ScalarExpression::Constant(Arc::new(DataValue::UInt64( + ColumnDesc::new( + LogicalType::Integer, + false, + false, + Some(ScalarExpression::Constant(Arc::new(DataValue::UInt64( Some(42), )))), - }, + )?, false, ))); not_ref_column.encode(&mut cursor, false, &mut reference_tables)?; @@ -318,14 +313,14 @@ pub(crate) mod test { fn test_column_desc_serialization() -> Result<(), DatabaseError> { let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); - let desc = ColumnDesc { - column_datatype: LogicalType::Integer, - is_primary: false, - is_unique: false, - default: Some(ScalarExpression::Constant(Arc::new(DataValue::UInt64( + let desc = ColumnDesc::new( + LogicalType::Integer, + false, + false, + Some(ScalarExpression::Constant(Arc::new(DataValue::UInt64( Some(42), )))), - }; + )?; desc.encode(&mut cursor, false, &mut reference_tables)?; cursor.seek(SeekFrom::Start(0))?; diff --git a/src/serdes/data_value.rs b/src/serdes/data_value.rs index 8bf606d2..45733dca 100644 --- a/src/serdes/data_value.rs +++ b/src/serdes/data_value.rs @@ -1,33 +1,40 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::{TableCache, Transaction}; -use crate::types::value::{DataValue, ValueRef}; +use crate::types::value::DataValue; use crate::types::LogicalType; use std::io::{Read, Write}; +use std::sync::Arc; -impl ReferenceSerialization for DataValue { - fn encode( +impl DataValue { + // FIXME: redundant code + pub(crate) fn inner_encode( &self, writer: &mut W, - is_direct: bool, - reference_tables: &mut ReferenceTables, + ty: &LogicalType, ) -> Result<(), DatabaseError> { - let logical_type = self.logical_type(); - - logical_type.encode(writer, is_direct, reference_tables)?; - self.is_null().encode(writer, is_direct, reference_tables)?; + writer.write_all(&[if self.is_null() { 0u8 } else { 1u8 }])?; if self.is_null() { return Ok(()); } if let DataValue::Tuple(values) = self { - values.encode(writer, is_direct, reference_tables)?; + match values { + None => writer.write_all(&[0u8])?, + 
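+                // a non-null tuple is written as a 1u8 presence tag, a
+                // little-endian u32 element count, then each element encoded
+                // recursively with its own logical type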
Some(values) => { + writer.write_all(&[1u8])?; + writer.write_all(&(values.len() as u32).to_le_bytes())?; + for value in values.iter() { + value.inner_encode(writer, &value.logical_type())? + } + } + } + return Ok(()); } - if logical_type.raw_len().is_none() { + if ty.raw_len().is_none() { let mut bytes = Vec::new(); - self.to_raw(&mut bytes)? - .encode(writer, is_direct, reference_tables)?; + writer.write_all(&(self.to_raw(&mut bytes)? as u32).to_le_bytes())?; writer.write_all(&bytes)?; } else { let _ = self.to_raw(writer)?; @@ -36,31 +43,72 @@ impl ReferenceSerialization for DataValue { Ok(()) } - fn decode( + pub(crate) fn inner_decode( reader: &mut R, - drive: Option<(&T, &TableCache)>, - reference_tables: &ReferenceTables, + ty: &LogicalType, ) -> Result { - let logical_type = LogicalType::decode(reader, drive, reference_tables)?; - - if bool::decode(reader, drive, reference_tables)? { - return Ok(DataValue::none(&logical_type)); + let mut bytes = [0u8; 1]; + reader.read_exact(&mut bytes)?; + if bytes[0] == 0 { + return Ok(DataValue::none(ty)); } - if matches!(logical_type, LogicalType::Tuple) { - return Ok(DataValue::Tuple(Option::>::decode( - reader, - drive, - reference_tables, - )?)); + if let LogicalType::Tuple(types) = ty { + let mut bytes = [0u8; 1]; + reader.read_exact(&mut bytes)?; + let values = match bytes[0] { + 0 => None, + 1 => { + let mut bytes = [0u8; 4]; + reader.read_exact(&mut bytes)?; + let len = u32::from_le_bytes(bytes) as usize; + let mut vec = Vec::with_capacity(len); + + for ty in types.iter() { + vec.push(Arc::new(Self::inner_decode(reader, ty)?)); + } + Some(vec) + } + _ => unreachable!(), + }; + + return Ok(DataValue::Tuple(values)); } - let value_len = match logical_type.raw_len() { - None => usize::decode(reader, drive, reference_tables)?, + let value_len = match ty.raw_len() { + None => { + let mut bytes = [0u8; 4]; + reader.read_exact(&mut bytes)?; + u32::from_le_bytes(bytes) as usize + } Some(len) => len, }; let mut buf = vec![0u8; value_len]; reader.read_exact(&mut buf)?; - Ok(DataValue::from_raw(&buf, &logical_type)) + Ok(DataValue::from_raw(&buf, ty)) + } +} + +impl ReferenceSerialization for DataValue { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + ) -> Result<(), DatabaseError> { + let ty = self.logical_type(); + ty.encode(writer, is_direct, reference_tables)?; + + self.inner_encode(writer, &ty) + } + + fn decode( + reader: &mut R, + drive: Option<(&T, &TableCache)>, + reference_tables: &ReferenceTables, + ) -> Result { + let logical_type = LogicalType::decode(reader, drive, reference_tables)?; + + Self::inner_decode(reader, &logical_type) } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index ef15e692..e73c8967 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -60,8 +60,8 @@ pub trait Transaction: Sized { .ok_or(DatabaseError::TableNotFound)?; let table_types = table.types(); if columns.is_empty() { - let (i, column) = table.primary_key()?; - columns.push((i, column.clone())); + let (i, column) = &table.primary_keys()[0]; + columns.push((*i, column.clone())); } let mut tuple_columns = Vec::with_capacity(columns.len()); let mut projections = Vec::with_capacity(columns.len()); @@ -234,7 +234,7 @@ pub trait Transaction: Sized { let mut generator = Generator::new(); let col_id = table.add_column(column.clone(), &mut generator)?; - if column.desc().is_unique { + if column.desc().is_unique() { let meta_ref = table.add_index_meta( format!("uk_{}", column.name()), vec![col_id], @@ 
-313,9 +313,10 @@ pub trait Transaction: Sized {
         if_not_exists: bool,
     ) -> Result {
         let mut table_catalog = TableCatalog::new(table_name.clone(), columns)?;
-        let (_, column) = table_catalog.primary_key()?;
-        TableCodec::check_primary_key_type(column.datatype())?;
+        for (_, column) in table_catalog.primary_keys() {
+            TableCodec::check_primary_key_type(column.datatype())?;
+        }

         let (table_key, value) =
             TableCodec::encode_root_table(&TableMeta::empty(table_name.clone()))?;
@@ -533,15 +534,15 @@ pub trait Transaction: Sized {
         let table_name = table.name.clone();
         let index_column = table
             .columns()
-            .filter(|column| column.desc().is_primary || column.desc().is_unique)
+            .filter(|column| column.desc().is_primary() || column.desc().is_unique())
             .map(|column| (column.id().unwrap(), column.clone()))
             .collect_vec();

         for (col_id, col) in index_column {
-            let is_primary = col.desc().is_primary;
+            let is_primary = col.desc().is_primary();
             let index_ty = if is_primary {
                 IndexType::PrimaryKey
-            } else if col.desc().is_unique {
+            } else if col.desc().is_unique() {
                 IndexType::Unique
             } else {
                 continue;
@@ -739,7 +740,7 @@ fn secondary_index_lookup(
     bytes: &Bytes,
     params: &IndexImplParams,
 ) -> Result {
-    let tuple_id = TableCodec::decode_index(bytes, &params.index_meta.pk_ty);
+    let tuple_id = TableCodec::decode_index(bytes, &params.index_meta.pk_ty)?;
     params
         .get_tuple_by_id(&tuple_id)?
         .ok_or_else(|| DatabaseError::NotFound("index's tuple_id", tuple_id.to_string()))
@@ -765,7 +766,7 @@ impl IndexImpl for UniqueIndexImpl {
         .ok_or_else(|| {
             DatabaseError::NotFound("secondary index", format!("index_value -> {}", value))
         })?;
-        let tuple_id = TableCodec::decode_index(&bytes, &params.index_meta.pk_ty);
+        let tuple_id = TableCodec::decode_index(&bytes, &params.index_meta.pk_ty)?;
         let tuple = params.get_tuple_by_id(&tuple_id)?.ok_or_else(|| {
             DatabaseError::NotFound("secondary index", format!("tuple_id -> {}", value))
         })?;
diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs
index 826882c9..25189376 100644
--- a/src/storage/table_codec.rs
+++ b/src/storage/table_codec.rs
@@ -37,6 +37,27 @@ enum CodecType {
 }

 impl TableCodec {
+    pub fn check_primary_key(value: &DataValue, depth: usize) -> Result<(), DatabaseError> {
+        if depth > 1 {
+            return Err(DatabaseError::PrimaryKeyTooManyLayers);
+        }
+        if value.is_null() {
+            return Err(DatabaseError::NotNull);
+        }
+
+        if let DataValue::Tuple(Some(values)) = &value {
+            for value in values {
+                Self::check_primary_key(value, depth + 1)?
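+                // every element is checked one layer deeper, so a tuple nested
+                // inside a tuple trips the `depth > 1` guard above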
+ } + + return Ok(()); + } else { + Self::check_primary_key_type(&value.logical_type())?; + } + + Ok(()) + } + pub fn check_primary_key_type(ty: &LogicalType) -> Result<(), DatabaseError> { if !matches!( ty, @@ -218,7 +239,7 @@ impl TableCodec { table_name: &str, tuple_id: &TupleId, ) -> Result, DatabaseError> { - Self::check_primary_key_type(&tuple_id.logical_type())?; + Self::check_primary_key(tuple_id, 0)?; let mut key_prefix = Self::key_prefix(CodecType::Tuple, table_name); key_prefix.push(BOUND_MIN_TAG); @@ -279,7 +300,8 @@ impl TableCodec { ) -> Result<(Bytes, Bytes), DatabaseError> { let key = TableCodec::encode_index_key(name, index, Some(tuple_id))?; let mut bytes = Vec::new(); - tuple_id.to_raw(&mut bytes)?; + + tuple_id.inner_encode(&mut bytes, &tuple_id.logical_type())?; Ok((Bytes::from(key), Bytes::from(bytes))) } @@ -327,8 +349,14 @@ impl TableCodec { Ok(key_prefix) } - pub fn decode_index(bytes: &[u8], primary_key_ty: &LogicalType) -> TupleId { - Arc::new(DataValue::from_raw(bytes, primary_key_ty)) + pub fn decode_index( + bytes: &[u8], + primary_key_ty: &LogicalType, + ) -> Result { + Ok(Arc::new(DataValue::inner_decode( + &mut Cursor::new(bytes), + primary_key_ty, + )?)) } /// Key: {TableName}{COLUMN_TAG}{BOUND_MIN_TAG}{ColumnId} @@ -578,7 +606,7 @@ mod tests { let (_, bytes) = TableCodec::encode_index(&table_catalog.name, &index, &tuple_id)?; debug_assert_eq!( - TableCodec::decode_index(&bytes, &tuple_id.logical_type()), + TableCodec::decode_index(&bytes, &tuple_id.logical_type())?, tuple_id ); @@ -672,12 +700,7 @@ mod tests { let mut col = ColumnCatalog::new( "".to_string(), false, - ColumnDesc { - column_datatype: LogicalType::Invalid, - is_primary: false, - is_unique: false, - default: None, - }, + ColumnDesc::new(LogicalType::SqlNull, false, false, None).unwrap(), ); col.summary_mut().relation = ColumnRelation::Table { diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs index 58b64962..933ac317 100644 --- a/src/types/evaluator/mod.rs +++ b/src/types/evaluator/mod.rs @@ -229,7 +229,7 @@ impl EvaluatorFactory { }, LogicalType::SqlNull => Ok(BinaryEvaluatorBox(Arc::new(NullBinaryEvaluator))), LogicalType::Invalid => Err(DatabaseError::InvalidType), - LogicalType::Tuple => match op { + LogicalType::Tuple(_) => match op { BinaryOperator::Eq => Ok(BinaryEvaluatorBox(Arc::new(TupleEqBinaryEvaluator))), BinaryOperator::NotEq => { Ok(BinaryEvaluatorBox(Arc::new(TupleNotEqBinaryEvaluator))) diff --git a/src/types/mod.rs b/src/types/mod.rs index 08c42ffa..6be2cd5f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -23,7 +23,6 @@ pub type ColumnId = Ulid; #[derive( Debug, Clone, - Copy, PartialEq, Eq, Hash, @@ -55,7 +54,7 @@ pub enum LogicalType { Time, // decimal (precision, scale) Decimal(Option, Option), - Tuple, + Tuple(Vec), } impl LogicalType { @@ -121,7 +120,7 @@ impl LogicalType { LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), LogicalType::Time => Some(4), - LogicalType::Invalid | LogicalType::Tuple => unreachable!(), + LogicalType::Invalid | LogicalType::Tuple(_) => unreachable!(), } } @@ -185,12 +184,19 @@ impl LogicalType { right: &LogicalType, ) -> Result { if left == right { - return Ok(*left); + return Ok(left.clone()); } match (left, right) { // SqlNull type can be cast to anything - (LogicalType::SqlNull, _) => return Ok(*right), - (_, LogicalType::SqlNull) => return Ok(*left), + (LogicalType::SqlNull, _) => return Ok(right.clone()), + (_, LogicalType::SqlNull) => return Ok(left.clone()), + (LogicalType::Tuple(types_0), 
LogicalType::Tuple(types_1)) => { + if types_0.len() > types_1.len() { + return Ok(left.clone()); + } else { + return Ok(right.clone()); + } + } _ => {} } if left.is_numeric() && right.is_numeric() { @@ -223,7 +229,7 @@ impl LogicalType { { return Ok(LogicalType::Varchar(None, CharLengthUnits::Characters)); } - Err(DatabaseError::Incomparable(*left, *right)) + Err(DatabaseError::Incomparable(left.clone(), right.clone())) } fn combine_numeric_types( @@ -231,7 +237,7 @@ impl LogicalType { right: &LogicalType, ) -> Result { if left == right { - return Ok(*left); + return Ok(left.clone()); } if left.is_signed_numeric() && right.is_unsigned_numeric() { // this method is symmetric @@ -241,10 +247,10 @@ impl LogicalType { } if LogicalType::can_implicit_cast(left, right) { - return Ok(*right); + return Ok(right.clone()); } if LogicalType::can_implicit_cast(right, left) { - return Ok(*left); + return Ok(left.clone()); } // we can't cast implicitly either way and types are not equal // this happens when left is signed and right is unsigned @@ -255,7 +261,7 @@ impl LogicalType { (LogicalType::Integer, _) | (_, LogicalType::UInteger) => Ok(LogicalType::Bigint), (LogicalType::Smallint, _) | (_, LogicalType::USmallint) => Ok(LogicalType::Integer), (LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => Ok(LogicalType::Smallint), - _ => Err(DatabaseError::Incomparable(*left, *right)), + _ => Err(DatabaseError::Incomparable(left.clone(), right.clone())), } } @@ -333,7 +339,7 @@ impl LogicalType { LogicalType::Time => { matches!(to, LogicalType::Varchar(..) | LogicalType::Char(..)) } - LogicalType::Decimal(_, _) | LogicalType::Tuple => false, + LogicalType::Decimal(_, _) | LogicalType::Tuple(_) => false, } } } @@ -529,7 +535,11 @@ pub(crate) mod test { &mut reference_tables, LogicalType::Decimal(None, None), )?; - fn_assert(&mut cursor, &mut reference_tables, LogicalType::Tuple)?; + fn_assert( + &mut cursor, + &mut reference_tables, + LogicalType::Tuple(vec![LogicalType::Integer]), + )?; Ok(()) } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 54b1b427..eba5eff2 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -24,7 +24,10 @@ pub type Schema = Vec; pub type SchemaRef = Arc; pub fn types(schema: &Schema) -> Vec { - schema.iter().map(|column| *column.datatype()).collect_vec() + schema + .iter() + .map(|column| column.datatype().clone()) + .collect_vec() } #[derive(Clone, Debug, PartialEq)] @@ -50,7 +53,7 @@ impl Tuple { let values_len = schema.len(); let mut tuple_values = Vec::with_capacity(values_len); let bits_len = (values_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; - let mut id_option = None; + let mut primary_keys = Vec::new(); let mut projection_i = 0; let mut pos = bits_len; @@ -62,7 +65,7 @@ impl Tuple { if is_none(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { if projections[projection_i] == i { tuple_values.push(Arc::new(DataValue::none(logic_type))); - Self::values_push(schema, &tuple_values, &mut id_option, &mut projection_i); + Self::values_push(schema, &tuple_values, &mut primary_keys, &mut projection_i); } } else if let Some(len) = logic_type.raw_len() { /// fixed length (e.g.: int) @@ -71,7 +74,7 @@ impl Tuple { &bytes[pos..pos + len], logic_type, ))); - Self::values_push(schema, &tuple_values, &mut id_option, &mut projection_i); + Self::values_push(schema, &tuple_values, &mut primary_keys, &mut projection_i); } pos += len; } else { @@ -83,14 +86,21 @@ impl Tuple { &bytes[pos..pos + len], logic_type, ))); - Self::values_push(schema, &tuple_values, &mut id_option, &mut 
projection_i); + Self::values_push(schema, &tuple_values, &mut primary_keys, &mut projection_i); } pos += len; } } + let id = (!primary_keys.is_empty()).then(|| { + if primary_keys.len() == 1 { + primary_keys.pop().unwrap() + } else { + Arc::new(DataValue::Tuple(Some(primary_keys))) + } + }); Tuple { - id: id_option, + id, values: tuple_values, } } @@ -98,11 +108,11 @@ impl Tuple { fn values_push( tuple_columns: &Schema, tuple_values: &[ValueRef], - id_option: &mut Option>, + primary_keys: &mut Vec, projection_i: &mut usize, ) { - if tuple_columns[*projection_i].desc().is_primary { - let _ = id_option.replace(tuple_values[*projection_i].clone()); + if tuple_columns[*projection_i].desc().is_primary() { + primary_keys.push(tuple_values[*projection_i].clone()) } *projection_i += 1; } @@ -124,7 +134,7 @@ impl Tuple { if value.is_null() { bytes[i / BITS_MAX_INDEX] = flip_bit(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX); } else { - let logical_type = types[i]; + let logical_type = &types[i]; let value_len = value.to_raw(&mut bytes)?; if logical_type.raw_len().is_none() { @@ -365,7 +375,7 @@ mod tests { ]; let types = columns .iter() - .map(|column| *column.datatype()) + .map(|column| column.datatype().clone()) .collect_vec(); let columns = Arc::new(columns); diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index 87bb1089..97f05165 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -28,7 +28,7 @@ impl<'a> TupleBuilder<'a> { row: impl IntoIterator, ) -> Result { let mut values = Vec::with_capacity(self.schema.len()); - let mut primary_key = None; + let mut primary_keys = Vec::new(); for (i, value) in row.into_iter().enumerate() { let data_value = Arc::new( @@ -40,18 +40,22 @@ impl<'a> TupleBuilder<'a> { .cast(self.schema[i].datatype())?, ); - if primary_key.is_none() && self.schema[i].desc().is_primary { - primary_key = Some(data_value.clone()); + if self.schema[i].desc().is_primary() { + primary_keys.push(data_value.clone()); } values.push(data_value); } if values.len() != self.schema.len() { return Err(DatabaseError::MisMatch("types", "values")); } + let id = (!primary_keys.is_empty()).then(|| { + if primary_keys.len() == 1 { + primary_keys.pop().unwrap() + } else { + Arc::new(DataValue::Tuple(Some(primary_keys))) + } + }); - Ok(Tuple { - id: primary_key, - values, - }) + Ok(Tuple { id, values }) } } diff --git a/src/types/value.rs b/src/types/value.rs index 01d727b2..b9ea9c21 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,7 +1,9 @@ +use super::LogicalType; use crate::errors::DatabaseError; use chrono::format::{DelayedFormat, StrftimeItems}; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; use integer_encoding::{FixedInt, FixedIntWriter}; +use itertools::Itertools; use lazy_static::lazy_static; use ordered_float::OrderedFloat; use rust_decimal::prelude::{FromPrimitive, ToPrimitive}; @@ -15,8 +17,6 @@ use std::str::FromStr; use std::sync::Arc; use std::{cmp, fmt, mem}; -use super::LogicalType; - lazy_static! 
{ pub static ref NULL_VALUE: ValueRef = Arc::new(DataValue::Null); static ref UNIX_DATETIME: NaiveDateTime = DateTime::from_timestamp(0, 0).unwrap().naive_utc(); @@ -443,7 +443,7 @@ impl DataValue { LogicalType::DateTime => DataValue::Date64(None), LogicalType::Time => DataValue::Time(None), LogicalType::Decimal(_, _) => DataValue::Decimal(None), - LogicalType::Tuple => DataValue::Tuple(None), + LogicalType::Tuple(_) => DataValue::Tuple(None), } } @@ -476,7 +476,14 @@ impl DataValue { LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.and_utc().timestamp())), LogicalType::Time => DataValue::Time(Some(UNIX_TIME.num_seconds_from_midnight())), LogicalType::Decimal(_, _) => DataValue::Decimal(Some(Decimal::new(0, 0))), - LogicalType::Tuple => DataValue::Tuple(Some(vec![])), + LogicalType::Tuple(types) => { + let values = types + .iter() + .map(|ty| Arc::new(DataValue::init(ty))) + .collect_vec(); + + DataValue::Tuple(Some(values)) + } } } @@ -675,7 +682,7 @@ impl DataValue { (!bytes.is_empty()) .then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap())), ), - LogicalType::Tuple => unreachable!(), + LogicalType::Tuple(_) => unreachable!(), } } @@ -707,7 +714,14 @@ impl DataValue { DataValue::Date64(_) => LogicalType::DateTime, DataValue::Time(_) => LogicalType::Time, DataValue::Decimal(_) => LogicalType::Decimal(None, None), - DataValue::Tuple(_) => LogicalType::Tuple, + DataValue::Tuple(values) => { + if let Some(values) = values { + let types = values.iter().map(|v| v.logical_type()).collect_vec(); + LogicalType::Tuple(types) + } else { + LogicalType::Tuple(vec![]) + } + } } } @@ -856,7 +870,7 @@ impl DataValue { LogicalType::DateTime => Ok(DataValue::Date64(None)), LogicalType::Time => Ok(DataValue::Time(None)), LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), - LogicalType::Tuple => Ok(DataValue::Tuple(None)), + LogicalType::Tuple(_) => Ok(DataValue::Tuple(None)), }, DataValue::Boolean(value) => match to { LogicalType::SqlNull => Ok(DataValue::Null), @@ -1364,7 +1378,16 @@ impl DataValue { _ => Err(DatabaseError::CastFail), }, DataValue::Tuple(values) => match to { - LogicalType::Tuple => Ok(DataValue::Tuple(values)), + LogicalType::Tuple(types) => Ok(if let Some(mut values) = values { + for (i, value) in values.iter_mut().enumerate() { + if types[i] != value.logical_type() { + *value = Arc::new(DataValue::clone(value).cast(&types[i])?); + } + } + DataValue::Tuple(Some(values)) + } else { + DataValue::Tuple(None) + }), _ => Err(DatabaseError::CastFail), }, }?; diff --git a/tests/slt/create.slt b/tests/slt/create.slt index 046c3031..1c532450 100644 --- a/tests/slt/create.slt +++ b/tests/slt/create.slt @@ -1,6 +1,9 @@ statement ok create table t(id int primary key, v1 int, v2 int, v3 int) +statement ok +create table t_m(ida int primary key, idb int primary key, v1 int, v2 int, v3 int) + statement error create table t(id int primary key, v1 int, v2 int, v3 int) diff --git a/tests/slt/delete.slt b/tests/slt/delete.slt index 2231b78f..69cb04ac 100644 --- a/tests/slt/delete.slt +++ b/tests/slt/delete.slt @@ -25,4 +25,4 @@ select * from t ---- statement ok -drop table t \ No newline at end of file +drop table t diff --git a/tests/slt/delete_multiple_primary_keys.slt b/tests/slt/delete_multiple_primary_keys.slt new file mode 100644 index 00000000..de549931 --- /dev/null +++ b/tests/slt/delete_multiple_primary_keys.slt @@ -0,0 +1,28 @@ +statement ok +create table t(id_0 int primary key, id_1 int primary key, v1 int, v2 int, v3 int) + +statement ok +insert into t values 
(0,0,1,10,100) + +statement ok +insert into t values (1,1,1,10,100), (2,2,2,20,200), (3,3,3,30,300), (4,4,4,40,400) + +statement ok +delete from t where v1 = 1 + +query III rowsort +select * from t; +---- +2 2 2 20 200 +3 3 3 30 300 +4 4 4 40 400 + +statement ok +delete from t + +query III rowsort +select * from t +---- + +statement ok +drop table t diff --git a/tests/slt/describe.slt b/tests/slt/describe.slt index d2cfb6a8..18bd3071 100644 --- a/tests/slt/describe.slt +++ b/tests/slt/describe.slt @@ -9,4 +9,17 @@ c2 INTEGER 4 true EMPTY 0 c3 VARCHAR null true UNIQUE null statement ok -drop table t9; \ No newline at end of file +drop table t9; + +statement ok +create table t9_m (c1 int primary key, c2 int primary key, c3 varchar unique); + +query TTTTI +describe t9_m; +---- +c1 INTEGER 4 false PRIMARY null +c2 INTEGER 4 false PRIMARY null +c3 VARCHAR null true UNIQUE null + +statement ok +drop table t9_m; diff --git a/tests/slt/insert.slt b/tests/slt/insert.slt index b38e252e..1c004e13 100644 --- a/tests/slt/insert.slt +++ b/tests/slt/insert.slt @@ -99,5 +99,3 @@ true statement ok drop table t2; - - diff --git a/tests/slt/insert_multiple_primary_keys.slt b/tests/slt/insert_multiple_primary_keys.slt new file mode 100644 index 00000000..d20662d3 --- /dev/null +++ b/tests/slt/insert_multiple_primary_keys.slt @@ -0,0 +1,101 @@ +statement ok +create table t(id_0 int primary key, id_1 int primary key, v1 bigint null, v2 varchar null, v3 decimal null) + +statement ok +insert into t values (0,0,1,10,100) + +statement ok +insert into t values (1,1,1,10,100), (2,2,2,20,200), (3,3,3,30,300), (4,4,4,40,400) + +statement ok +insert into t(id_0, id_1, v1, v2, v3) values (5,5,1,10,100) + +statement ok +insert into t(id_0, id_1, v1, v2) values (6,6,1,10) + +statement ok +insert into t(id_0, id_1, v2, v1) values (7,7,1,10) + +statement error +insert into t(id_0, id_1, v1, v2, v3) values (0, 0) + +statement error +insert into t(id_0, id_1, v1, v2, v3) values (0, 0, 0) + +statement error +insert into t(id_0, id_1, v1, v2, v3) values (0, 0, 0, 0) + +statement ok +insert into t values (8,8,NULL,NULL,NULL) + +statement ok +insert overwrite t values (1, 1, 9, 9, 9) + +query IIII rowsort +select * from t +---- +0 0 1 10 100 +1 1 9 9 9 +2 2 2 20 200 +3 3 3 30 300 +4 4 4 40 400 +5 5 1 10 100 +6 6 1 10 null +7 7 10 1 null +8 8 null null null + +statement ok +drop table t; + +statement ok +create table t1(id_0 int primary key, id_1 int primary key, v1 bigint default 233) + +statement ok +insert into t1 values (0, 0) + +statement ok +insert into t1 values (1, 1) + +statement ok +insert into t1 values (2, 2) + +statement ok +insert into t1 values (3, 3, DEFAULT) + +statement ok +insert into t1 values (4, 4, 0) + +statement ok +insert into t1 (v1, id_0, id_1) values (DEFAULT, 5, 5) + +query III rowsort +select * from t1 +---- +0 0 233 +1 1 233 +2 2 233 +3 3 233 +4 4 0 +5 5 233 + +statement ok +drop table t1; + +statement ok +create table t2(id_0 int primary key, id_1 int primary key, v1 char(10), v2 varchar); + +statement ok +insert into t2 (id_0, id_1, v1, v2) values (0, 0, 'foo', 'foo'); + +query ITT +select * from t2; +---- +0 0 foo foo + +query B +select v1 = v2 from t2; +---- +true + +statement ok +drop table t2; diff --git a/tests/slt/sql_2016/E141_03.slt b/tests/slt/sql_2016/E141_03.slt index 1b7f869d..325d90a7 100644 --- a/tests/slt/sql_2016/E141_03.slt +++ b/tests/slt/sql_2016/E141_03.slt @@ -1,18 +1,16 @@ # E141-03: PRIMARY KEY constraints -# TODO: Multiple primary keys - statement ok CREATE TABLE 
TABLE_E141_03_01_01 ( A INT NOT NULL, B INT NOT NULL, CONSTRAINT CONST_E141_03_01_01 PRIMARY KEY ( A ) ) -# statement ok -# CREATE TABLE TABLE_E141_03_01_02 ( A INT NOT NULL, B INT NOT NULL, CONSTRAINT CONST_E141_03_01_02 PRIMARY KEY ( A , B ) ) +statement ok +CREATE TABLE TABLE_E141_03_01_02 ( A INT NOT NULL, B INT NOT NULL, CONSTRAINT CONST_E141_03_01_02 PRIMARY KEY ( A , B ) ) statement ok CREATE TABLE TABLE_E141_03_01_03 ( A INT NOT NULL, B INT NOT NULL, PRIMARY KEY ( A ) ) -# statement ok -# CREATE TABLE TABLE_E141_03_01_04 ( A INT NOT NULL, B INT NOT NULL, PRIMARY KEY ( A , B ) ) +statement ok +CREATE TABLE TABLE_E141_03_01_04 ( A INT NOT NULL, B INT NOT NULL, PRIMARY KEY ( A , B ) ) statement ok CREATE TABLE TABLE_E141_03_02_01 ( A INT NOT NULL CONSTRAINT CONST_E141_03_02_01 PRIMARY KEY ) diff --git a/tests/slt/sql_2016/E141_08.slt b/tests/slt/sql_2016/E141_08.slt index 45c810c1..c7404f7d 100644 --- a/tests/slt/sql_2016/E141_08.slt +++ b/tests/slt/sql_2016/E141_08.slt @@ -21,16 +21,14 @@ CREATE TABLE TABLE_E141_08_02_02 ( ID INT PRIMARY KEY, A INT UNIQUE ) statement ok CREATE TABLE TABLE_E141_08_03_01 ( A INT, B INT, CONSTRAINT CONST_E141_08_03_01 PRIMARY KEY ( A ) ) -# TODO: Multiple primary keys -# statement ok -# CREATE TABLE TABLE_E141_08_03_02 ( A INT, B INT, CONSTRAINT CONST_E141_08_03_02 PRIMARY KEY ( A , B ) ) +statement ok +CREATE TABLE TABLE_E141_08_03_02 ( A INT, B INT, CONSTRAINT CONST_E141_08_03_02 PRIMARY KEY ( A , B ) ) statement ok CREATE TABLE TABLE_E141_08_03_03 ( A INT, B INT, PRIMARY KEY ( A ) ) -# TODO: Multiple primary keys -# statement ok -# CREATE TABLE TABLE_E141_08_03_04 ( A INT, B INT, PRIMARY KEY ( A , B ) ) +statement ok +CREATE TABLE TABLE_E141_08_03_04 ( A INT, B INT, PRIMARY KEY ( A , B ) ) statement ok CREATE TABLE TABLE_E141_08_04_01 ( A INT CONSTRAINT CONST_E141_08_04_01 PRIMARY KEY ) diff --git a/tests/slt/update.slt b/tests/slt/update.slt index efc0390e..590c223d 100644 --- a/tests/slt/update.slt +++ b/tests/slt/update.slt @@ -44,4 +44,4 @@ select * from t 4 4 9 233 statement ok -drop table t \ No newline at end of file +drop table t diff --git a/tests/slt/update_multiple_primary_keys.slt b/tests/slt/update_multiple_primary_keys.slt new file mode 100644 index 00000000..18c5f129 --- /dev/null +++ b/tests/slt/update_multiple_primary_keys.slt @@ -0,0 +1,47 @@ +statement ok +create table t(id_0 int primary key, id_1 int primary key, v1 int, v2 int, v3 int default 233) + +statement ok +insert into t values (0,0,1,10,100) + +statement ok +insert into t values (1,1,1,10,100), (2,2,2,20,200), (3,3,3,30,300), (4,4,4,40,400) + +statement ok +update t set v2 = 9 where v1 = 1 + +query IIII rowsort +select * from t; +---- +0 0 1 9 100 +1 1 1 9 100 +2 2 2 20 200 +3 3 3 30 300 +4 4 4 40 400 + +statement ok +update t set v2 = 9 + +query IIII rowsort +select * from t +---- +0 0 1 9 100 +1 1 1 9 100 +2 2 2 9 200 +3 3 3 9 300 +4 4 4 9 400 + +statement ok +update t set v3 = default + +query IIII rowsort +select * from t +---- +0 0 1 9 233 +1 1 1 9 233 +2 2 2 9 233 +3 3 3 9 233 +4 4 4 9 233 + +statement ok +drop table t