Skip to content

Commit

Permalink
feat(type): add support for Decimal type in database (#66)
Browse files Browse the repository at this point in the history
Co-authored-by: Kould <[email protected]>
  • Loading branch information
loloxwg and KKould authored Sep 25, 2023
1 parent a348db1 commit a696e2a
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ comfy-table = "7.0.1"
bytes = "*"
kip_db = "0.1.2-alpha.15"
async-recursion = "1.0.5"
rust_decimal = "1"

[dev-dependencies]
tokio-test = "0.4.2"
Expand Down
2 changes: 1 addition & 1 deletion src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<S: Storage> Binder<S> {
match &self.bind_expr(expr).await? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
value.check_length(columns[i].datatype())?;
value.check_len(columns[i].datatype())?;
let cast_value = DataValue::clone(value)
.cast(columns[i].datatype())?;
row.push(Arc::new(cast_value))
Expand Down
2 changes: 1 addition & 1 deletion src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<S: Storage> Binder<S> {
bind_table_name.as_ref()
).await? {
ScalarExpression::ColumnRef(catalog) => {
value.check_length(catalog.datatype())?;
value.check_len(catalog.datatype())?;
columns.push(catalog);
row.push(value.clone());
},
Expand Down
7 changes: 7 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ mod test {
let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?;
let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 2, 2)").await?;
let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?;
let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?;
let _ = kipsql.run("insert into t3 (a, b) values (1, 1111), (2, 2.01), (3, 3.00)").await?;
let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?;

println!("full t1:");
let tuples_full_fields_t1 = kipsql.run("select * from t1").await?;
Expand Down Expand Up @@ -305,6 +308,10 @@ mod test {
let tuples_show_tables = kipsql.run("show tables").await?;
println!("{}", create_table(&tuples_show_tables));

println!("decimal:");
let tuples_decimal = kipsql.run("select * from t3").await?;
println!("{}", create_table(&tuples_decimal));

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fmt;
use std::fmt::Formatter;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use itertools::Itertools;

Expand Down
9 changes: 8 additions & 1 deletion src/storage/table_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ mod tests {
use std::ops::Bound;
use std::sync::Arc;
use itertools::Itertools;
use rust_decimal::Decimal;
use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog};
use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec};
use crate::types::errors::TypeError;
Expand All @@ -159,7 +160,12 @@ mod tests {
"c1".into(),
false,
ColumnDesc::new(LogicalType::Integer, true)
)
),
ColumnCatalog::new(
"c2".into(),
false,
ColumnDesc::new(LogicalType::Decimal(None,None), false)
),
];
let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap();
let codec = TableCodec { table: table_catalog.clone() };
Expand All @@ -175,6 +181,7 @@ mod tests {
columns: table_catalog.all_columns(),
values: vec![
Arc::new(DataValue::Int32(Some(0))),
Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))),
]
};

Expand Down
6 changes: 6 additions & 0 deletions src/types/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,10 @@ pub enum TypeError {
#[from]
ParseError,
),
#[error("try from decimal")]
TryFromDecimal(
#[source]
#[from]
rust_decimal::Error,
),
}
18 changes: 15 additions & 3 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::atomic::Ordering::{Acquire, Release};
use serde::{Deserialize, Serialize};

use integer_encoding::FixedInt;
use sqlparser::ast::ExactNumberInfo;
use strum_macros::AsRefStr;

use crate::types::errors::TypeError;
Expand Down Expand Up @@ -57,6 +58,8 @@ pub enum LogicalType {
Varchar(Option<u32>),
Date,
DateTime,
// decimal (precision, scale)
Decimal(Option<u8>, Option<u8>),
}

impl LogicalType {
Expand All @@ -75,8 +78,9 @@ impl LogicalType {
LogicalType::UBigint => Some(8),
LogicalType::Float => Some(4),
LogicalType::Double => Some(8),
/// Note: The non-fixed length type's raw_len is None
LogicalType::Varchar(_)=>None,
/// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal
LogicalType::Varchar(_) => None,
LogicalType::Decimal(_, _) => Some(16),
LogicalType::Date => Some(4),
LogicalType::DateTime => Some(8),
}
Expand Down Expand Up @@ -269,6 +273,7 @@ impl LogicalType {
LogicalType::Varchar(_) => false,
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)),
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)),
LogicalType::Decimal(_, _) => false,
}
}
}
Expand Down Expand Up @@ -296,6 +301,13 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {
sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint),
sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean),
sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime),
sqlparser::ast::DataType::Decimal(info) => match info {
ExactNumberInfo::None => Ok(Self::Decimal(None, None)),
ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)),
ExactNumberInfo::PrecisionAndScale(p, s) => {
Ok(Self::Decimal(Some(p as u8), Some(s as u8)))
}
},
other => Err(TypeError::NotImplementedSqlparserDataType(
other.to_string(),
)),
Expand All @@ -313,7 +325,7 @@ impl std::fmt::Display for LogicalType {
mod test {
use std::sync::atomic::Ordering::Release;

use crate::types::{IdGenerator, ID_BUF, LogicalType};
use crate::types::{IdGenerator, ID_BUF};

/// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰
#[test]
Expand Down
Loading

0 comments on commit a696e2a

Please sign in to comment.