Skip to content

Commit

Permalink
fix: DataValue::cast fixes Decimal conversion problem
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Sep 24, 2023
1 parent dfcf5c6 commit f63c8da
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ impl LogicalType {
LogicalType::Float => Some(4),
LogicalType::Double => Some(8),
/// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal
LogicalType::Varchar(_)=>None,
LogicalType::Decimal(_, _) =>None,
LogicalType::Varchar(_) => None,
LogicalType::Decimal(_, _) => Some(16),
LogicalType::Date => Some(4),
LogicalType::DateTime => Some(8),
}
Expand Down
94 changes: 80 additions & 14 deletions src/types/value.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cmp::Ordering;
use std::fmt;
use std::{fmt, mem};
use std::fmt::Formatter;
use std::hash::Hash;
use std::str::FromStr;
Expand Down Expand Up @@ -241,7 +241,6 @@ impl DataValue {
pub fn is_variable(&self) -> bool {
match self {
DataValue::Utf8(_) => true,
DataValue::Decimal(_) => true,
_ => false
}
}
Expand Down Expand Up @@ -445,8 +444,13 @@ impl DataValue {
LogicalType::Float => Ok(DataValue::Float32(value)),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) =>{
Ok(DataValue::Decimal(value.map(|v| Decimal::from_f32(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero))))
LogicalType::Decimal(_, option) =>{
Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?;
Self::decimal_round_f(option, &mut decimal);

Ok::<Decimal, TypeError>(decimal)
}).transpose()?))
}
_ => Err(TypeError::CastFail),
}
Expand All @@ -456,8 +460,13 @@ impl DataValue {
LogicalType::SqlNull => Ok(DataValue::Null),
LogicalType::Double => Ok(DataValue::Float64(value)),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => {
Ok(DataValue::Decimal(value.map(|v| Decimal::from_f64(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero))))
LogicalType::Decimal(_, option) => {
Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?;
Self::decimal_round_f(option, &mut decimal);

Ok::<Decimal, TypeError>(decimal)
}).transpose()?))
}
_ => Err(TypeError::CastFail),
}
Expand All @@ -476,7 +485,12 @@ impl DataValue {
LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -493,7 +507,12 @@ impl DataValue {
LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -508,7 +527,12 @@ impl DataValue {
LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -521,7 +545,12 @@ impl DataValue {
LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)),
LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -538,7 +567,12 @@ impl DataValue {
LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -553,7 +587,12 @@ impl DataValue {
LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -565,7 +604,12 @@ impl DataValue {
LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))),
LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand All @@ -574,7 +618,12 @@ impl DataValue {
LogicalType::SqlNull => Ok(DataValue::Null),
LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))),
LogicalType::Varchar(len) => varchar_cast!(value, len),
LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))),
LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| {
let mut decimal = Decimal::from(v);
Self::decimal_round_i(option, &mut decimal);

decimal
}))),
_ => Err(TypeError::CastFail),
}
}
Expand Down Expand Up @@ -663,6 +712,23 @@ impl DataValue {
}
}

fn decimal_round_i(option: &Option<u8>, decimal: &mut Decimal) {
if let Some(scale) = option {
let new_decimal = decimal.trunc_with_scale(*scale as u32);
let _ = mem::replace(decimal, new_decimal);
}
}

fn decimal_round_f(option: &Option<u8>, decimal: &mut Decimal) {
if let Some(scale) = option {
let new_decimal = decimal.round_dp_with_strategy(
*scale as u32,
rust_decimal::RoundingStrategy::MidpointAwayFromZero
);
let _ = mem::replace(decimal, new_decimal);
}
}

fn date_format<'a>(v: i32) -> Option<DelayedFormat<StrftimeItems<'a>>> {
NaiveDate::from_num_days_from_ce_opt(v)
.map(|date| date.format(DATE_FMT))
Expand Down

0 comments on commit f63c8da

Please sign in to comment.