diff --git a/Cargo.toml b/Cargo.toml index 1182c3da92bf..93af308ff041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ arrow-ipc = { version = "53.3.0", default-features = false, features = [ arrow-ord = { version = "53.3.0", default-features = false } arrow-schema = { version = "53.3.0", default-features = false } async-trait = "0.1.73" -bigdecimal = "=0.4.1" +bigdecimal = "0.4.6" bytes = "1.4" chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 909539b3131b..8d2fd1e6d0f2 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -101,5 +101,70 @@ pub(crate) fn decimal_to_str(value: Decimal) -> String { } pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { - value.round(12).normalized().to_string() + // Round the value to limit the number of decimal places + let value = value.round(12).normalized(); + // Format the value to a string + format_big_decimal(value) +} + +fn format_big_decimal(value: BigDecimal) -> String { + let (integer, scale) = value.into_bigint_and_exponent(); + let mut str = integer.to_str_radix(10); + if scale <= 0 { + // Append zeros to the right of the integer part + str.extend(std::iter::repeat('0').take(scale.unsigned_abs() as usize)); + str + } else { + let (sign, unsigned_len, unsigned_str) = if integer.is_negative() { + ("-", str.len() - 1, &str[1..]) + } else { + ("", str.len(), &str[..]) + }; + let scale = scale as usize; + if unsigned_len <= scale { + format!("{}0.{:0>scale$}", sign, unsigned_str) + } else { + str.insert(str.len() - scale, '.'); + str + } + } +} + +#[cfg(test)] +mod tests { + use super::big_decimal_to_str; + use bigdecimal::{num_bigint::BigInt, BigDecimal}; + + macro_rules! assert_decimal_str_eq { + ($integer:expr, $scale:expr, $expected:expr) => { + assert_eq!( + big_decimal_to_str(BigDecimal::from_bigint( + BigInt::from($integer), + $scale + )), + $expected + ); + }; + } + + #[test] + fn test_big_decimal_to_str() { + assert_decimal_str_eq!(11, 3, "0.011"); + assert_decimal_str_eq!(11, 2, "0.11"); + assert_decimal_str_eq!(11, 1, "1.1"); + assert_decimal_str_eq!(11, 0, "11"); + assert_decimal_str_eq!(11, -1, "110"); + assert_decimal_str_eq!(0, 0, "0"); + + // Negative cases + assert_decimal_str_eq!(-11, 3, "-0.011"); + assert_decimal_str_eq!(-11, 2, "-0.11"); + assert_decimal_str_eq!(-11, 1, "-1.1"); + assert_decimal_str_eq!(-11, 0, "-11"); + assert_decimal_str_eq!(-11, -1, "-110"); + + // Round to 12 decimal places + // 1.0000000000011 -> 1.000000000001 + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001"); + } }