Skip to content

Commit

Permalink
fix: Optimize Decimal's check length logic and fix interference with …
Browse files Browse the repository at this point in the history
…other types
  • Loading branch information
KKould committed Sep 24, 2023
1 parent 045c4ee commit dfcf5c6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 74 deletions.
2 changes: 1 addition & 1 deletion src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<S: Storage> Binder<S> {
match &self.bind_expr(expr).await? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
value.check_length(columns[i].datatype())?;
value.check_len(columns[i].datatype())?;
let cast_value = DataValue::clone(value)
.cast(columns[i].datatype())?;
row.push(Arc::new(cast_value))
Expand Down
2 changes: 1 addition & 1 deletion src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<S: Storage> Binder<S> {
bind_table_name.as_ref()
).await? {
ScalarExpression::ColumnRef(catalog) => {
value.check_length(catalog.datatype())?;
value.check_len(catalog.datatype())?;
columns.push(catalog);
row.push(value.clone());
},
Expand Down
91 changes: 19 additions & 72 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use lazy_static::lazy_static;
use rust_decimal::Decimal;

use ordered_float::OrderedFloat;
use rust_decimal::prelude::{FromPrimitive, Signed};
use rust_decimal::prelude::FromPrimitive;
use crate::types::errors::TypeError;

use super::LogicalType;
Expand Down Expand Up @@ -201,81 +201,28 @@ macro_rules! varchar_cast {
};
}

macro_rules! check_decimal_length {
($data_value:expr, $logic_type:expr) => {
if let LogicalType::Decimal(precision, scale) = $logic_type {
let data_value_str = $data_value.to_string();
let data_value_precision = data_value_str.chars().filter(|c| *c >= '0' && *c <= '9').count();
if data_value_precision > precision.unwrap() as usize {
return Err(TypeError::TooLong);
}
if $data_value.scale() > scale.unwrap() as u32 {
return Err(TypeError::TooLong);
}
}else{
return Ok(())
}
};
}

impl DataValue {
pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> {
match self {
DataValue::Boolean(_) => return Ok(()),
DataValue::Float32(v) => {
// check literal to decimal
check_decimal_length!(Decimal::from_f32(v.unwrap()).unwrap(), logic_type)
}
DataValue::Float64(v) =>{
// check literal to decimal
check_decimal_length!(Decimal::from_f64(v.unwrap()).unwrap(), logic_type)
},
DataValue::Int8(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::Int16(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::Int32(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::Int64(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::UInt8(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::UInt16(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::UInt32(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::UInt64(v) => {
check_decimal_length!(Decimal::from(v.unwrap()), logic_type)
}
DataValue::Date32(_) => return Ok(()),
DataValue::Date64(_) => return Ok(()),
DataValue::Utf8(value) => {
if let LogicalType::Varchar(len) = logic_type {
if let Some(len) = len {
if value.as_ref().map(|v| v.len() > *len as usize).unwrap_or(false) {
return Err(TypeError::TooLong);
}
}
}
}
DataValue::Decimal(value) => {
if let LogicalType::Decimal(_, scale) = logic_type {
if let Some(value) = value {
if value.scale() as u8 > scale.ok_or(TypeError::InvalidType)? {
return Err(TypeError::TooLong);
}
}
pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> {
let is_over_len = match (logic_type, self) {
(LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => {
val.len() > *len as usize
}
(LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => {
if let Some(len) = full_len {
val.mantissa().ilog10() + 1 > *len as u32
} else if let Some(len) = scale_len {
val.scale() > *len as u32
} else {
false
}
}
_ => { return Err(TypeError::InvalidType); }
_ => false
};

if is_over_len {
return Err(TypeError::TooLong)
}

Ok(())
}

Expand Down

0 comments on commit dfcf5c6

Please sign in to comment.