From a696e2a5782a0dfeaf02fa4a65ea0e4af9c39e1a Mon Sep 17 00:00:00 2001 From: Xwg Date: Mon, 25 Sep 2023 10:31:17 +0800 Subject: [PATCH] feat(type): add support for Decimal type in database (#66) Co-authored-by: Kould <2435992353@qq.com> --- Cargo.toml | 1 + src/binder/insert.rs | 2 +- src/binder/update.rs | 2 +- src/db.rs | 7 ++ src/expression/mod.rs | 2 +- src/storage/table_codec.rs | 9 +- src/types/errors.rs | 6 ++ src/types/mod.rs | 18 +++- src/types/value.rs | 163 +++++++++++++++++++++++++++++++------ tests/slt/decimal | 6 ++ 10 files changed, 183 insertions(+), 33 deletions(-) create mode 100644 tests/slt/decimal diff --git a/Cargo.toml b/Cargo.toml index 35152592..a1ba0656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ comfy-table = "7.0.1" bytes = "*" kip_db = "0.1.2-alpha.15" async-recursion = "1.0.5" +rust_decimal = "1" [dev-dependencies] tokio-test = "0.4.2" diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 0a5e0b94..e5e21c93 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -50,7 +50,7 @@ impl Binder { match &self.bind_expr(expr).await? { ScalarExpression::Constant(value) => { // Check if the value length is too long - value.check_length(columns[i].datatype())?; + value.check_len(columns[i].datatype())?; let cast_value = DataValue::clone(value) .cast(columns[i].datatype())?; row.push(Arc::new(cast_value)) diff --git a/src/binder/update.rs b/src/binder/update.rs index 92de8986..ea99b4cc 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -44,7 +44,7 @@ impl Binder { bind_table_name.as_ref() ).await? { ScalarExpression::ColumnRef(catalog) => { - value.check_length(catalog.datatype())?; + value.check_len(catalog.datatype())?; columns.push(catalog); row.push(value.clone()); }, diff --git a/src/db.rs b/src/db.rs index bc768d16..a8accc6e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -186,6 +186,9 @@ mod test { let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?; let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 2, 2)").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; + let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?; + let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?; + let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?; println!("full t1:"); let tuples_full_fields_t1 = kipsql.run("select * from t1").await?; @@ -305,6 +308,10 @@ mod test { let tuples_show_tables = kipsql.run("show tables").await?; println!("{}", create_table(&tuples_show_tables)); + println!("decimal:"); + let tuples_decimal = kipsql.run("select * from t3").await?; + println!("{}", create_table(&tuples_decimal)); + Ok(()) } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 456d26a5..30678993 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,5 +1,5 @@ use std::fmt; -use std::fmt::Formatter; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use itertools::Itertools; diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 18de475a..aab27fef 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -146,6 +146,7 @@ mod tests { use std::ops::Bound; use std::sync::Arc; use itertools::Itertools; + use rust_decimal::Decimal; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec}; use crate::types::errors::TypeError; @@ -159,7 +160,12 @@ mod tests { "c1".into(), false, ColumnDesc::new(LogicalType::Integer, true) - ) + ), + ColumnCatalog::new( + "c2".into(), + false, + ColumnDesc::new(LogicalType::Decimal(None,None), false) + ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap(); let codec = TableCodec { table: table_catalog.clone() }; @@ -175,6 +181,7 @@ mod tests { columns: table_catalog.all_columns(), values: vec![ Arc::new(DataValue::Int32(Some(0))), + Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), ] }; diff --git a/src/types/errors.rs b/src/types/errors.rs index cd5aa59b..c9476f62 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -46,4 +46,10 @@ pub enum TypeError { #[from] ParseError, ), + #[error("try from decimal")] + TryFromDecimal( + #[source] + #[from] + rust_decimal::Error, + ), } diff --git a/src/types/mod.rs b/src/types/mod.rs index cb4f2fce..de67aed9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -7,6 +7,7 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use serde::{Deserialize, Serialize}; use integer_encoding::FixedInt; +use sqlparser::ast::ExactNumberInfo; use strum_macros::AsRefStr; use crate::types::errors::TypeError; @@ -57,6 +58,8 @@ pub enum LogicalType { Varchar(Option), Date, DateTime, + // decimal (precision, scale) + Decimal(Option, Option), } impl LogicalType { @@ -75,8 +78,9 @@ impl LogicalType { LogicalType::UBigint => Some(8), LogicalType::Float => Some(4), LogicalType::Double => Some(8), - /// Note: The non-fixed length type's raw_len is None - LogicalType::Varchar(_)=>None, + /// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal + LogicalType::Varchar(_) => None, + LogicalType::Decimal(_, _) => Some(16), LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), } @@ -269,6 +273,7 @@ impl LogicalType { LogicalType::Varchar(_) => false, LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)), LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)), + LogicalType::Decimal(_, _) => false, } } } @@ -296,6 +301,13 @@ impl TryFrom for LogicalType { sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime), + sqlparser::ast::DataType::Decimal(info) => match info { + ExactNumberInfo::None => Ok(Self::Decimal(None, None)), + ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), + ExactNumberInfo::PrecisionAndScale(p, s) => { + Ok(Self::Decimal(Some(p as u8), Some(s as u8))) + } + }, other => Err(TypeError::NotImplementedSqlparserDataType( other.to_string(), )), @@ -313,7 +325,7 @@ impl std::fmt::Display for LogicalType { mod test { use std::sync::atomic::Ordering::Release; - use crate::types::{IdGenerator, ID_BUF, LogicalType}; + use crate::types::{IdGenerator, ID_BUF}; /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 #[test] diff --git a/src/types/value.rs b/src/types/value.rs index 3e02ee19..94b856f3 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,5 +1,5 @@ use std::cmp::Ordering; -use std::fmt; +use std::{fmt, mem}; use std::fmt::Formatter; use std::hash::Hash; use std::str::FromStr; @@ -8,8 +8,10 @@ use chrono::{NaiveDateTime, Datelike, NaiveDate}; use chrono::format::{DelayedFormat, StrftimeItems}; use integer_encoding::FixedInt; use lazy_static::lazy_static; +use rust_decimal::Decimal; use ordered_float::OrderedFloat; +use rust_decimal::prelude::FromPrimitive; use crate::types::errors::TypeError; use super::LogicalType; @@ -44,6 +46,7 @@ pub enum DataValue { Date32(Option), /// Date stored as a signed 64bit int timestamp since UNIX epoch 1970-01-01 Date64(Option), + Decimal(Option), } impl PartialEq for DataValue { @@ -88,6 +91,8 @@ impl PartialEq for DataValue { (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), (Date64(_), _) => false, + (Decimal(v1), Decimal(v2)) => v1.eq(v2), + (Decimal(_), _) => false, } } } @@ -134,6 +139,8 @@ impl PartialOrd for DataValue { (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), (Date64(_), _) => None, + (Decimal(v1), Decimal(v2)) => v1.partial_cmp(v2), + (Decimal(_), _) => None, } } } @@ -175,6 +182,7 @@ impl Hash for DataValue { Null => 1.hash(state), Date32(v) => v.hash(state), Date64(v) => v.hash(state), + Decimal(v) => v.hash(state), } } } @@ -194,32 +202,27 @@ macro_rules! varchar_cast { } impl DataValue { - pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> { - match self { - DataValue::Boolean(_) => return Ok(()), - DataValue::Float32(_) => return Ok(()), - DataValue::Float64(_) => return Ok(()), - DataValue::Int8(_) => return Ok(()), - DataValue::Int16(_) => return Ok(()), - DataValue::Int32(_) => return Ok(()), - DataValue::Int64(_) => return Ok(()), - DataValue::UInt8(_) => return Ok(()), - DataValue::UInt16(_) => return Ok(()), - DataValue::UInt32(_) => return Ok(()), - DataValue::UInt64(_) => return Ok(()), - DataValue::Date32(_) => return Ok(()), - DataValue::Date64(_) => return Ok(()), - DataValue::Utf8(value) => { - if let LogicalType::Varchar(len) = logic_type { - if let Some(len) = len { - if value.as_ref().map(|v| v.len() > *len as usize).unwrap_or(false) { - return Err(TypeError::TooLong); - } - } + pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> { + let is_over_len = match (logic_type, self) { + (LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => { + val.len() > *len as usize + } + (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { + if let Some(len) = full_len { + val.mantissa().ilog10() + 1 > *len as u32 + } else if let Some(len) = scale_len { + val.scale() > *len as u32 + } else { + false } } - _ => { return Err(TypeError::InvalidType); } + _ => false + }; + + if is_over_len { + return Err(TypeError::TooLong) } + Ok(()) } @@ -259,6 +262,7 @@ impl DataValue { DataValue::Utf8(value) => value.is_none(), DataValue::Date32(value) => value.is_none(), DataValue::Date64(value) => value.is_none(), + DataValue::Decimal(value) => value.is_none(), } } @@ -279,7 +283,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(None), LogicalType::Varchar(_) => DataValue::Utf8(None), LogicalType::Date => DataValue::Date32(None), - LogicalType::DateTime => DataValue::Date64(None) + LogicalType::DateTime => DataValue::Date64(None), + LogicalType::Decimal(_, _) => DataValue::Decimal(None), } } @@ -300,7 +305,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(Some(0.0)), LogicalType::Varchar(_) => DataValue::Utf8(Some("".to_string())), LogicalType::Date => DataValue::Date32(Some(UNIX_DATETIME.num_days_from_ce())), - LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())) + LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())), + LogicalType::Decimal(_, _) => DataValue::Decimal(Some(Decimal::new(0, 0))), } } @@ -321,6 +327,7 @@ impl DataValue { DataValue::Utf8(v) => v.clone().map(|v| v.into_bytes()), DataValue::Date32(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Date64(v) => v.map(|v| v.encode_fixed_vec()), + DataValue::Decimal(v) => v.clone().map(|v| v.serialize().to_vec()), }.unwrap_or(vec![]) } @@ -350,6 +357,7 @@ impl DataValue { LogicalType::Varchar(_) => DataValue::Utf8((!bytes.is_empty()).then(|| String::from_utf8(bytes.to_owned()).unwrap())), LogicalType::Date => DataValue::Date32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))), LogicalType::DateTime => DataValue::Date64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))), + LogicalType::Decimal(_, _) => DataValue::Decimal((!bytes.is_empty()).then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap()))), } } @@ -370,6 +378,7 @@ impl DataValue { DataValue::Utf8(_) => LogicalType::Varchar(None), DataValue::Date32(_) => LogicalType::Date, DataValue::Date64(_) => LogicalType::DateTime, + DataValue::Decimal(_) => LogicalType::Decimal(None, None), } } @@ -408,6 +417,7 @@ impl DataValue { LogicalType::Varchar(_) => Ok(DataValue::Utf8(None)), LogicalType::Date => Ok(DataValue::Date32(None)), LogicalType::DateTime => Ok(DataValue::Date64(None)), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), } } DataValue::Boolean(value) => { @@ -434,6 +444,14 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value)), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) =>{ + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) + } _ => Err(TypeError::CastFail), } } @@ -442,6 +460,14 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Double => Ok(DataValue::Float64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => { + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) + } _ => Err(TypeError::CastFail), } } @@ -459,6 +485,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -475,6 +507,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -489,6 +527,12 @@ impl DataValue { LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -501,6 +545,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -517,6 +567,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -531,6 +587,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -542,6 +604,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -550,6 +618,12 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -588,6 +662,9 @@ impl DataValue { }).transpose()?; Ok(DataValue::Date64(option)) + }, + LogicalType::Decimal(_, _) => { + Ok(DataValue::Decimal(value.map(|v| Decimal::from_str(&v)).transpose()?)) } } } @@ -624,6 +701,31 @@ impl DataValue { _ => Err(TypeError::CastFail), } } + DataValue::Decimal(value) => { + match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), + LogicalType::Varchar(len) => varchar_cast!(value, len), + _ => Err(TypeError::CastFail), + } + } + } + } + + fn decimal_round_i(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.trunc_with_scale(*scale as u32); + let _ = mem::replace(decimal, new_decimal); + } + } + + fn decimal_round_f(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.round_dp_with_strategy( + *scale as u32, + rust_decimal::RoundingStrategy::MidpointAwayFromZero + ); + let _ = mem::replace(decimal, new_decimal); } } @@ -636,6 +738,11 @@ impl DataValue { NaiveDateTime::from_timestamp_opt(v, 0) .map(|date_time| date_time.format(DATE_TIME_FMT)) } + + fn decimal_format(v: &Decimal) -> String { + v.to_string() + + } } macro_rules! impl_scalar { @@ -724,6 +831,9 @@ impl fmt::Display for DataValue { DataValue::Date64(e) => { format_option!(f, e.and_then(|s| DataValue::date_time_format(s)))? } + DataValue::Decimal(e) => { + format_option!(f, e.as_ref().map(|s| DataValue::decimal_format(s)))? + } }; Ok(()) } @@ -748,6 +858,7 @@ impl fmt::Debug for DataValue { DataValue::Null => write!(f, "null"), DataValue::Date32(_) => write!(f, "Date32({})", self), DataValue::Date64(_) => write!(f, "Date64({})", self), + DataValue::Decimal(_) => write!(f, "Decimal({})", self), } } } diff --git a/tests/slt/decimal b/tests/slt/decimal new file mode 100644 index 00000000..4d566b0a --- /dev/null +++ b/tests/slt/decimal @@ -0,0 +1,6 @@ + +statement ok +CREATE TABLE mytable ( title varchar(256) primary key, cost decimal(4,2)); + +statement ok +INSERT INTO mytable (title, cost) VALUES ('A', 1.00); \ No newline at end of file