diff --git a/Cargo.toml b/Cargo.toml index 35152592..a1ba0656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ comfy-table = "7.0.1" bytes = "*" kip_db = "0.1.2-alpha.15" async-recursion = "1.0.5" +rust_decimal = "1" [dev-dependencies] tokio-test = "0.4.2" diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 31f80df3..4b82755c 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -50,7 +50,7 @@ impl Binder { match &self.bind_expr(expr).await? { ScalarExpression::Constant(value) => { // Check if the value length is too long - value.check_length(columns[i].datatype())?; + value.check_len(columns[i].datatype())?; let cast_value = DataValue::clone(value) .cast(columns[i].datatype())?; row.push(Arc::new(cast_value)) diff --git a/src/binder/update.rs b/src/binder/update.rs index 92de8986..ea99b4cc 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -44,7 +44,7 @@ impl Binder { bind_table_name.as_ref() ).await? { ScalarExpression::ColumnRef(catalog) => { - value.check_length(catalog.datatype())?; + value.check_len(catalog.datatype())?; columns.push(catalog); row.push(value.clone()); }, diff --git a/src/db.rs b/src/db.rs index d49e528e..060144c5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -200,6 +200,9 @@ mod test { let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?; let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 3, 2)").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; + let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?; + let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?; + let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?; println!("show tables:"); let tuples_show_tables = kipsql.run("show tables").await?; @@ -321,6 +324,10 @@ mod test { println!("drop t1:"); let _ = kipsql.run("drop table t1").await?; + println!("decimal:"); + let tuples_decimal = kipsql.run("select * from t3").await?; + println!("{}", create_table(&tuples_decimal)); + Ok(()) } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 28125c7b..35f55e5f 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,5 +1,5 @@ use std::fmt; -use std::fmt::Formatter; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use itertools::Itertools; diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 58a5a9c5..584c5376 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -234,6 +234,7 @@ mod tests { use std::ops::Bound; use std::sync::Arc; use itertools::Itertools; + use rust_decimal::Decimal; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec}; use crate::types::errors::TypeError; @@ -249,6 +250,13 @@ mod tests { false, ColumnDesc::new(LogicalType::Integer, true, false) ) + ColumnDesc::new(LogicalType::Integer, true, false) + ), + ColumnCatalog::new( + "c2".into(), + false, + ColumnDesc::new(LogicalType::Decimal(None,None), false) + ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns, vec![]).unwrap(); let codec = TableCodec { table: table_catalog.clone() }; @@ -264,6 +272,7 @@ mod tests { columns: table_catalog.all_columns(), values: vec![ Arc::new(DataValue::Int32(Some(0))), + Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), ] }; diff --git a/src/types/errors.rs b/src/types/errors.rs index 2317a14c..dff88e9d 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -52,4 +52,10 @@ pub enum TypeError { #[from] Box ) + #[error("try from decimal")] + TryFromDecimal( + #[source] + #[from] + rust_decimal::Error, + ), } diff --git a/src/types/mod.rs b/src/types/mod.rs index b0443b5e..c6fbac8d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,6 +4,9 @@ pub mod tuple; pub mod index; use serde::{Deserialize, Serialize}; + +use integer_encoding::FixedInt; +use sqlparser::ast::ExactNumberInfo; use strum_macros::AsRefStr; use crate::types::errors::TypeError; @@ -30,6 +33,8 @@ pub enum LogicalType { Varchar(Option), Date, DateTime, + // decimal (precision, scale) + Decimal(Option, Option), } impl LogicalType { @@ -48,8 +53,9 @@ impl LogicalType { LogicalType::UBigint => Some(8), LogicalType::Float => Some(4), LogicalType::Double => Some(8), - /// Note: The non-fixed length type's raw_len is None - LogicalType::Varchar(_)=>None, + /// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal + LogicalType::Varchar(_) => None, + LogicalType::Decimal(_, _) => Some(16), LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), } @@ -242,6 +248,7 @@ impl LogicalType { LogicalType::Varchar(_) => false, LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)), LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)), + LogicalType::Decimal(_, _) => false, } } } @@ -269,6 +276,13 @@ impl TryFrom for LogicalType { sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime), + sqlparser::ast::DataType::Decimal(info) => match info { + ExactNumberInfo::None => Ok(Self::Decimal(None, None)), + ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), + ExactNumberInfo::PrecisionAndScale(p, s) => { + Ok(Self::Decimal(Some(p as u8), Some(s as u8))) + } + }, other => Err(TypeError::NotImplementedSqlparserDataType( other.to_string(), )), diff --git a/src/types/value.rs b/src/types/value.rs index 9ff5f8e2..3229b97e 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -8,9 +8,11 @@ use chrono::{NaiveDateTime, Datelike, NaiveDate}; use chrono::format::{DelayedFormat, StrftimeItems}; use integer_encoding::FixedInt; use lazy_static::lazy_static; +use rust_decimal::Decimal; use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; +use rust_decimal::prelude::FromPrimitive; use crate::types::errors::TypeError; use super::LogicalType; @@ -45,6 +47,7 @@ pub enum DataValue { Date32(Option), /// Date stored as a signed 64bit int timestamp since UNIX epoch 1970-01-01 Date64(Option), + Decimal(Option), } impl PartialEq for DataValue { @@ -89,6 +92,8 @@ impl PartialEq for DataValue { (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), (Date64(_), _) => false, + (Decimal(v1), Decimal(v2)) => v1.eq(v2), + (Decimal(_), _) => false, } } } @@ -135,6 +140,8 @@ impl PartialOrd for DataValue { (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), (Date64(_), _) => None, + (Decimal(v1), Decimal(v2)) => v1.partial_cmp(v2), + (Decimal(_), _) => None, } } } @@ -176,6 +183,7 @@ impl Hash for DataValue { Null => 1.hash(state), Date32(v) => v.hash(state), Date64(v) => v.hash(state), + Decimal(v) => v.hash(state), } } } @@ -195,32 +203,27 @@ macro_rules! varchar_cast { } impl DataValue { - pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> { - match self { - DataValue::Boolean(_) => return Ok(()), - DataValue::Float32(_) => return Ok(()), - DataValue::Float64(_) => return Ok(()), - DataValue::Int8(_) => return Ok(()), - DataValue::Int16(_) => return Ok(()), - DataValue::Int32(_) => return Ok(()), - DataValue::Int64(_) => return Ok(()), - DataValue::UInt8(_) => return Ok(()), - DataValue::UInt16(_) => return Ok(()), - DataValue::UInt32(_) => return Ok(()), - DataValue::UInt64(_) => return Ok(()), - DataValue::Date32(_) => return Ok(()), - DataValue::Date64(_) => return Ok(()), - DataValue::Utf8(value) => { - if let LogicalType::Varchar(len) = logic_type { - if let Some(len) = len { - if value.as_ref().map(|v| v.len() > *len as usize).unwrap_or(false) { - return Err(TypeError::TooLong); - } - } + pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> { + let is_over_len = match (logic_type, self) { + (LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => { + val.len() > *len as usize + } + (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { + if let Some(len) = full_len { + val.mantissa().ilog10() + 1 > *len as u32 + } else if let Some(len) = scale_len { + val.scale() > *len as u32 + } else { + false } } - _ => { return Err(TypeError::InvalidType); } + _ => false + }; + + if is_over_len { + return Err(TypeError::TooLong) } + Ok(()) } @@ -260,6 +263,7 @@ impl DataValue { DataValue::Utf8(value) => value.is_none(), DataValue::Date32(value) => value.is_none(), DataValue::Date64(value) => value.is_none(), + DataValue::Decimal(value) => value.is_none(), } } @@ -280,7 +284,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(None), LogicalType::Varchar(_) => DataValue::Utf8(None), LogicalType::Date => DataValue::Date32(None), - LogicalType::DateTime => DataValue::Date64(None) + LogicalType::DateTime => DataValue::Date64(None), + LogicalType::Decimal(_, _) => DataValue::Decimal(None), } } @@ -301,7 +306,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(Some(0.0)), LogicalType::Varchar(_) => DataValue::Utf8(Some("".to_string())), LogicalType::Date => DataValue::Date32(Some(UNIX_DATETIME.num_days_from_ce())), - LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())) + LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())), + LogicalType::Decimal(_, _) => DataValue::Decimal(Some(Decimal::new(0, 0))), } } @@ -322,6 +328,7 @@ impl DataValue { DataValue::Utf8(v) => v.clone().map(|v| v.into_bytes()), DataValue::Date32(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Date64(v) => v.map(|v| v.encode_fixed_vec()), + DataValue::Decimal(v) => v.clone().map(|v| v.serialize().to_vec()), }.unwrap_or(vec![]) } @@ -351,6 +358,7 @@ impl DataValue { LogicalType::Varchar(_) => DataValue::Utf8((!bytes.is_empty()).then(|| String::from_utf8(bytes.to_owned()).unwrap())), LogicalType::Date => DataValue::Date32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))), LogicalType::DateTime => DataValue::Date64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))), + LogicalType::Decimal(_, _) => DataValue::Decimal((!bytes.is_empty()).then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap()))), } } @@ -371,6 +379,7 @@ impl DataValue { DataValue::Utf8(_) => LogicalType::Varchar(None), DataValue::Date32(_) => LogicalType::Date, DataValue::Date64(_) => LogicalType::DateTime, + DataValue::Decimal(_) => LogicalType::Decimal(None, None), } } @@ -429,6 +438,7 @@ impl DataValue { LogicalType::Varchar(_) => Ok(DataValue::Utf8(None)), LogicalType::Date => Ok(DataValue::Date32(None)), LogicalType::DateTime => Ok(DataValue::Date64(None)), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), } } DataValue::Boolean(value) => { @@ -455,6 +465,14 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value)), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) =>{ + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) + } _ => Err(TypeError::CastFail), } } @@ -463,6 +481,14 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Double => Ok(DataValue::Float64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => { + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) + } _ => Err(TypeError::CastFail), } } @@ -480,6 +506,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -496,6 +528,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -510,6 +548,12 @@ impl DataValue { LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -522,6 +566,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -538,6 +588,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -552,6 +608,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -563,6 +625,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -571,6 +639,12 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -609,6 +683,9 @@ impl DataValue { }).transpose()?; Ok(DataValue::Date64(option)) + }, + LogicalType::Decimal(_, _) => { + Ok(DataValue::Decimal(value.map(|v| Decimal::from_str(&v)).transpose()?)) } } } @@ -645,6 +722,31 @@ impl DataValue { _ => Err(TypeError::CastFail), } } + DataValue::Decimal(value) => { + match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), + LogicalType::Varchar(len) => varchar_cast!(value, len), + _ => Err(TypeError::CastFail), + } + } + } + } + + fn decimal_round_i(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.trunc_with_scale(*scale as u32); + let _ = mem::replace(decimal, new_decimal); + } + } + + fn decimal_round_f(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.round_dp_with_strategy( + *scale as u32, + rust_decimal::RoundingStrategy::MidpointAwayFromZero + ); + let _ = mem::replace(decimal, new_decimal); } } @@ -657,6 +759,11 @@ impl DataValue { NaiveDateTime::from_timestamp_opt(v, 0) .map(|date_time| date_time.format(DATE_TIME_FMT)) } + + fn decimal_format(v: &Decimal) -> String { + v.to_string() + + } } macro_rules! impl_scalar { @@ -745,6 +852,9 @@ impl fmt::Display for DataValue { DataValue::Date64(e) => { format_option!(f, e.and_then(|s| DataValue::date_time_format(s)))? } + DataValue::Decimal(e) => { + format_option!(f, e.as_ref().map(|s| DataValue::decimal_format(s)))? + } }; Ok(()) } @@ -769,6 +879,7 @@ impl fmt::Debug for DataValue { DataValue::Null => write!(f, "null"), DataValue::Date32(_) => write!(f, "Date32({})", self), DataValue::Date64(_) => write!(f, "Date64({})", self), + DataValue::Decimal(_) => write!(f, "Decimal({})", self), } } } diff --git a/tests/slt/decimal b/tests/slt/decimal new file mode 100644 index 00000000..4d566b0a --- /dev/null +++ b/tests/slt/decimal @@ -0,0 +1,6 @@ + +statement ok +CREATE TABLE mytable ( title varchar(256) primary key, cost decimal(4,2)); + +statement ok +INSERT INTO mytable (title, cost) VALUES ('A', 1.00); \ No newline at end of file