From dfcf5c6a6781fc28186efb864ada00ba6a6e1314 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 25 Sep 2023 03:35:41 +0800 Subject: [PATCH] fix: Optimize Decimal's check length logic and fix interference with other types --- src/binder/insert.rs | 2 +- src/binder/update.rs | 2 +- src/types/value.rs | 91 +++++++++----------------------------------- 3 files changed, 21 insertions(+), 74 deletions(-) diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 0a5e0b94..e5e21c93 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -50,7 +50,7 @@ impl Binder { match &self.bind_expr(expr).await? { ScalarExpression::Constant(value) => { // Check if the value length is too long - value.check_length(columns[i].datatype())?; + value.check_len(columns[i].datatype())?; let cast_value = DataValue::clone(value) .cast(columns[i].datatype())?; row.push(Arc::new(cast_value)) diff --git a/src/binder/update.rs b/src/binder/update.rs index 92de8986..ea99b4cc 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -44,7 +44,7 @@ impl Binder { bind_table_name.as_ref() ).await? { ScalarExpression::ColumnRef(catalog) => { - value.check_length(catalog.datatype())?; + value.check_len(catalog.datatype())?; columns.push(catalog); row.push(value.clone()); }, diff --git a/src/types/value.rs b/src/types/value.rs index 41249b07..96ab9183 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -11,7 +11,7 @@ use lazy_static::lazy_static; use rust_decimal::Decimal; use ordered_float::OrderedFloat; -use rust_decimal::prelude::{FromPrimitive, Signed}; +use rust_decimal::prelude::FromPrimitive; use crate::types::errors::TypeError; use super::LogicalType; @@ -201,81 +201,28 @@ macro_rules! varchar_cast { }; } -macro_rules! check_decimal_length { - ($data_value:expr, $logic_type:expr) => { - if let LogicalType::Decimal(precision, scale) = $logic_type { - let data_value_str = $data_value.to_string(); - let data_value_precision = data_value_str.chars().filter(|c| *c >= '0' && *c <= '9').count(); - if data_value_precision > precision.unwrap() as usize { - return Err(TypeError::TooLong); - } - if $data_value.scale() > scale.unwrap() as u32 { - return Err(TypeError::TooLong); - } - }else{ - return Ok(()) - } - }; -} - impl DataValue { - pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> { - match self { - DataValue::Boolean(_) => return Ok(()), - DataValue::Float32(v) => { - // check literal to decimal - check_decimal_length!(Decimal::from_f32(v.unwrap()).unwrap(), logic_type) - } - DataValue::Float64(v) =>{ - // check literal to decimal - check_decimal_length!(Decimal::from_f64(v.unwrap()).unwrap(), logic_type) - }, - DataValue::Int8(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int16(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int32(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int64(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt8(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt16(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt32(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt64(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Date32(_) => return Ok(()), - DataValue::Date64(_) => return Ok(()), - DataValue::Utf8(value) => { - if let LogicalType::Varchar(len) = logic_type { - if let Some(len) = len { - if value.as_ref().map(|v| v.len() > *len as usize).unwrap_or(false) { - return Err(TypeError::TooLong); - } - } - } - } - DataValue::Decimal(value) => { - if let LogicalType::Decimal(_, scale) = logic_type { - if let Some(value) = value { - if value.scale() as u8 > scale.ok_or(TypeError::InvalidType)? { - return Err(TypeError::TooLong); - } - } + pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> { + let is_over_len = match (logic_type, self) { + (LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => { + val.len() > *len as usize + } + (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { + if let Some(len) = full_len { + val.mantissa().ilog10() + 1 > *len as u32 + } else if let Some(len) = scale_len { + val.scale() > *len as u32 + } else { + false } } - _ => { return Err(TypeError::InvalidType); } + _ => false + }; + + if is_over_len { + return Err(TypeError::TooLong) } + Ok(()) }