From f63c8da1926ad15164a00260fd52c58ccafcbddf Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 25 Sep 2023 04:03:25 +0800 Subject: [PATCH] fix: `DataValue::cast` fixes Decimal conversion problem --- src/types/mod.rs | 4 +- src/types/value.rs | 94 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/src/types/mod.rs b/src/types/mod.rs index c0ba1470..de67aed9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -79,8 +79,8 @@ impl LogicalType { LogicalType::Float => Some(4), LogicalType::Double => Some(8), /// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal - LogicalType::Varchar(_)=>None, - LogicalType::Decimal(_, _) =>None, + LogicalType::Varchar(_) => None, + LogicalType::Decimal(_, _) => Some(16), LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), } diff --git a/src/types/value.rs b/src/types/value.rs index 96ab9183..94b856f3 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,5 +1,5 @@ use std::cmp::Ordering; -use std::fmt; +use std::{fmt, mem}; use std::fmt::Formatter; use std::hash::Hash; use std::str::FromStr; @@ -241,7 +241,6 @@ impl DataValue { pub fn is_variable(&self) -> bool { match self { DataValue::Utf8(_) => true, - DataValue::Decimal(_) => true, _ => false } } @@ -445,8 +444,13 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value)), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) =>{ - Ok(DataValue::Decimal(value.map(|v| Decimal::from_f32(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + LogicalType::Decimal(_, option) =>{ + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) } _ => Err(TypeError::CastFail), } @@ -456,8 +460,13 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Double => Ok(DataValue::Float64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => { - Ok(DataValue::Decimal(value.map(|v| Decimal::from_f64(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + LogicalType::Decimal(_, option) => { + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) } _ => Err(TypeError::CastFail), } @@ -476,7 +485,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -493,7 +507,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -508,7 +527,12 @@ impl DataValue { LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -521,7 +545,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -538,7 +567,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -553,7 +587,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -565,7 +604,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -574,7 +618,12 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -663,6 +712,23 @@ impl DataValue { } } + fn decimal_round_i(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.trunc_with_scale(*scale as u32); + let _ = mem::replace(decimal, new_decimal); + } + } + + fn decimal_round_f(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.round_dp_with_strategy( + *scale as u32, + rust_decimal::RoundingStrategy::MidpointAwayFromZero + ); + let _ = mem::replace(decimal, new_decimal); + } + } + fn date_format<'a>(v: i32) -> Option>> { NaiveDate::from_num_days_from_ce_opt(v) .map(|date| date.format(DATE_FMT))