diff --git a/Cargo.lock b/Cargo.lock index cf2c39426316..2b55b6db8d05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -857,6 +857,19 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bigdecimal" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06619be423ea5bb86c95f087d5707942791a08a85530df0db2209a3ecfb8bc9" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bincode" version = "1.3.3" @@ -2739,6 +2752,20 @@ dependencies = [ "uuid", ] +[[package]] +name = "decimal" +version = "0.4.2" +dependencies = [ + "arrow", + "bigdecimal 0.4.2", + "common-error", + "common-macro", + "rust_decimal", + "serde", + "serde_json", + "snafu", +] + [[package]] name = "der" version = "0.5.1" @@ -4974,7 +5001,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57349d5a326b437989b6ee4dc8f2f34b0cc131202748414712a8e7d98952fc8c" dependencies = [ "base64 0.21.5", - "bigdecimal", + "bigdecimal 0.3.1", "bindgen", "bitflags 2.4.1", "bitvec", diff --git a/Cargo.toml b/Cargo.toml index 9d4b42901e5d..489b4f5de1a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "src/common/telemetry", "src/common/test-util", "src/common/time", + "src/common/decimal", "src/common/version", "src/datanode", "src/datatypes", @@ -68,6 +69,7 @@ arrow-flight = "47.0" arrow-schema = { version = "47.0", features = ["serde"] } async-stream = "0.3" async-trait = "0.1" +bigdecimal = "0.4.2" chrono = { version = "0.4", features = ["serde"] } datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "26e43acac3a96cec8dd4c8365f22dfb1a84306e9" } datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", rev = "26e43acac3a96cec8dd4c8365f22dfb1a84306e9" } @@ -104,6 +106,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "rustls-tls-native-roots", "stream", ] } +rust_decimal = "1.32.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" smallvec = "1" diff --git a/src/common/decimal/Cargo.toml b/src/common/decimal/Cargo.toml new file mode 100644 index 000000000000..0c162e11f0c4 --- /dev/null +++ b/src/common/decimal/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "decimal" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +arrow.workspace = true +bigdecimal = { workspace = true } +common-error = { workspace = true } +common-macro = { workspace = true } +rust_decimal = { workspace = true } +serde.workspace = true +serde_json = "1.0" +snafu.workspace = true diff --git a/src/common/decimal/src/decimal128.rs b/src/common/decimal/src/decimal128.rs new file mode 100644 index 000000000000..98becbe9c8a7 --- /dev/null +++ b/src/common/decimal/src/decimal128.rs @@ -0,0 +1,394 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; +use std::hash::Hash; +use std::str::FromStr; + +use bigdecimal::{BigDecimal, ToPrimitive}; +use rust_decimal::Decimal as RustDecimal; +use serde::{Deserialize, Serialize}; +use snafu::ResultExt; + +use crate::error::{ + self, BigDecimalOutOfRangeSnafu, Error, InvalidPrecisionOrScaleSnafu, ParseBigDecimalStrSnafu, + ParseRustDecimalStrSnafu, +}; + +/// The maximum precision for [Decimal128] values +pub const DECIMAL128_MAX_PRECISION: u8 = 38; + +/// The maximum scale for [Decimal128] values +pub const DECIMAL128_MAX_SCALE: i8 = 38; + +/// The default scale for [Decimal128] values +pub const DECIMAL128_DEFAULT_SCALE: i8 = 10; + +/// The maximum bytes length that an accurate RustDecimal can represent +const BYTES_TO_OVERFLOW_RUST_DECIMAL: usize = 28; + +/// 128bit decimal, using the i128 to represent the decimal. +/// +/// **precision**: the total number of digits in the number, it's range is \[1, 38\]. +/// +/// **scale**: the number of digits to the right of the decimal point, it's range is \[0, precision\]. +#[derive(Debug, Default, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct Decimal128 { + value: i128, + precision: u8, + scale: i8, +} + +impl Decimal128 { + /// Create a new Decimal128 from i128, precision and scale. + pub fn new_unchecked(value: i128, precision: u8, scale: i8) -> Self { + Self { + value, + precision, + scale, + } + } + + pub fn try_new(value: i128, precision: u8, scale: i8) -> error::Result { + // make sure the precision and scale is valid. + valid_precision_and_scale(precision, scale)?; + Ok(Self { + value, + precision, + scale, + }) + } + + pub fn val(&self) -> i128 { + self.value + } + + /// Returns the precision of this decimal. + pub fn precision(&self) -> u8 { + self.precision + } + + /// Returns the scale of this decimal. + pub fn scale(&self) -> i8 { + self.scale + } + + /// Convert to ScalarValue + pub fn to_scalar_value(&self) -> (Option, u8, i8) { + (Some(self.value), self.precision, self.scale) + } +} + +impl PartialEq for Decimal128 { + fn eq(&self, other: &Self) -> bool { + self.precision.eq(&other.precision) + && self.scale.eq(&other.scale) + && self.value.eq(&other.value) + } +} + +// Two decimal values can be compared if they have the same precision and scale. +impl PartialOrd for Decimal128 { + fn partial_cmp(&self, other: &Self) -> Option { + if self.precision == other.precision && self.scale == other.scale { + return self.value.partial_cmp(&other.value); + } + None + } +} + +/// Convert from string to Decimal128 +/// If the string length is less than 28, the result of rust_decimal will underflow, +/// In this case, use BigDecimal to get accurate result. +impl FromStr for Decimal128 { + type Err = Error; + + fn from_str(s: &str) -> Result { + let len = s.as_bytes().len(); + if len <= BYTES_TO_OVERFLOW_RUST_DECIMAL { + let rd = RustDecimal::from_str_exact(s).context(ParseRustDecimalStrSnafu { raw: s })?; + Ok(Self::from(rd)) + } else { + let bd = BigDecimal::from_str(s).context(ParseBigDecimalStrSnafu { raw: s })?; + Self::try_from(bd) + } + } +} + +impl Display for Decimal128 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + format_decimal_str(&self.value.to_string(), self.precision as usize, self.scale) + ) + } +} + +impl Hash for Decimal128 { + fn hash(&self, state: &mut H) { + state.write_i128(self.value); + state.write_u8(self.precision); + state.write_i8(self.scale); + } +} + +impl From for serde_json::Value { + fn from(decimal: Decimal128) -> Self { + serde_json::Value::String(decimal.to_string()) + } +} + +impl From for i128 { + fn from(decimal: Decimal128) -> Self { + decimal.val() + } +} + +impl From for Decimal128 { + fn from(value: i128) -> Self { + Self { + value, + precision: DECIMAL128_MAX_PRECISION, + scale: DECIMAL128_DEFAULT_SCALE, + } + } +} + +/// Convert from RustDecimal to Decimal128 +/// RustDecimal can represent the range is smaller than Decimal128, +/// it is safe to convert RustDecimal to Decimal128 +impl From for Decimal128 { + fn from(rd: RustDecimal) -> Self { + let s = rd.to_string(); + let precision = (s.len() - s.matches(&['.', '-'][..]).count()) as u8; + Self { + value: rd.mantissa(), + precision, + scale: rd.scale() as i8, + } + } +} + +/// Try from BigDecimal to Decimal128 +/// The range that BigDecimal can represent is larger than Decimal128, +/// so it is not safe to convert BigDecimal to Decimal128, +/// If the BigDecimal is out of range, return error. +impl TryFrom for Decimal128 { + type Error = Error; + + fn try_from(value: BigDecimal) -> Result { + let precision = value.digits(); + let (big_int, scale) = value.as_bigint_and_exponent(); + // convert big_int to i128, if convert failed, return error + big_int + .to_i128() + .map(|val| Self::try_new(val, precision as u8, scale as i8)) + .unwrap_or_else(|| BigDecimalOutOfRangeSnafu { value }.fail()) + } +} + +/// Port from arrow-rs, +/// see https://github.com/Apache/arrow-rs/blob/master/arrow-array/src/types.rs#L1323-L1344 +fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { + let (sign, rest) = match value_str.strip_prefix('-') { + Some(stripped) => ("-", stripped), + None => ("", value_str), + }; + + let bound = precision.min(rest.len()) + sign.len(); + let value_str = &value_str[0..bound]; + + if scale == 0 { + value_str.to_string() + } else if scale < 0 { + let padding = value_str.len() + scale.unsigned_abs() as usize; + format!("{value_str:0 scale as usize { + // Decimal separator is in the middle of the string + let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); + format!("{whole}.{decimal}") + } else { + // String has to be padded + format!("{}0.{:0>width$}", sign, rest, width = scale as usize) + } +} + +/// check whether precision and scale is valid +fn valid_precision_and_scale(precision: u8, scale: i8) -> error::Result<()> { + if precision == 0 { + return InvalidPrecisionOrScaleSnafu { + reason: format!( + "precision cannot be 0, has to be between [1, {}]", + DECIMAL128_MAX_PRECISION + ), + } + .fail(); + } + if precision > DECIMAL128_MAX_PRECISION { + return InvalidPrecisionOrScaleSnafu { + reason: format!( + "precision {} is greater than max {}", + precision, DECIMAL128_MAX_PRECISION + ), + } + .fail(); + } + if scale > DECIMAL128_MAX_SCALE { + return InvalidPrecisionOrScaleSnafu { + reason: format!( + "scale {} is greater than max {}", + scale, DECIMAL128_MAX_SCALE + ), + } + .fail(); + } + if scale > 0 && scale > precision as i8 { + return InvalidPrecisionOrScaleSnafu { + reason: format!("scale {} is greater than precision {}", scale, precision), + } + .fail(); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_common_decimal128() { + let decimal = Decimal128::new_unchecked(123456789, 9, 3); + assert_eq!(decimal.to_string(), "123456.789"); + + let decimal = Decimal128::try_new(123456789, 9, 0); + assert_eq!(decimal.unwrap().to_string(), "123456789"); + + let decimal = Decimal128::try_new(123456789, 9, 2); + assert_eq!(decimal.unwrap().to_string(), "1234567.89"); + + let decimal = Decimal128::try_new(123, 3, -2); + assert_eq!(decimal.unwrap().to_string(), "12300"); + + // invalid precision or scale + + // precision is 0 + let decimal = Decimal128::try_new(123, 0, 0); + assert!(decimal.is_err()); + + // precision is greater than 38 + let decimal = Decimal128::try_new(123, 39, 0); + assert!(decimal.is_err()); + + // scale is greater than 38 + let decimal = Decimal128::try_new(123, 38, 39); + assert!(decimal.is_err()); + + // scale is greater than precision + let decimal = Decimal128::try_new(123, 3, 4); + assert!(decimal.is_err()); + } + + #[test] + fn test_decimal128_from_str() { + // 0 < precision <= 28 + let decimal = Decimal128::from_str("1234567890.123456789").unwrap(); + assert_eq!(decimal.to_string(), "1234567890.123456789"); + assert_eq!(decimal.precision(), 19); + assert_eq!(decimal.scale(), 9); + + let decimal = Decimal128::from_str("1234567890.123456789012345678").unwrap(); + assert_eq!(decimal.to_string(), "1234567890.123456789012345678"); + assert_eq!(decimal.precision(), 28); + assert_eq!(decimal.scale(), 18); + + // 28 < precision <= 38 + let decimal = Decimal128::from_str("1234567890.1234567890123456789012").unwrap(); + assert_eq!(decimal.to_string(), "1234567890.1234567890123456789012"); + assert_eq!(decimal.precision(), 32); + assert_eq!(decimal.scale(), 22); + + let decimal = Decimal128::from_str("1234567890.1234567890123456789012345678").unwrap(); + assert_eq!( + decimal.to_string(), + "1234567890.1234567890123456789012345678" + ); + assert_eq!(decimal.precision(), 38); + assert_eq!(decimal.scale(), 28); + + // precision > 38 + let decimal = Decimal128::from_str("1234567890.12345678901234567890123456789"); + assert!(decimal.is_err()); + } + + #[test] + #[ignore] + fn test_parse_decimal128_speed() { + // RustDecimal::from_str: 1.124855167s + for _ in 0..1500000 { + let _ = RustDecimal::from_str("1234567890.123456789012345678999").unwrap(); + } + + // BigDecimal::try_from: 6.799290042s + for _ in 0..1500000 { + let _ = BigDecimal::from_str("1234567890.123456789012345678999").unwrap(); + } + } + + #[test] + fn test_decimal128_precision_and_scale() { + // precision and scale from Deicmal(1,1) to Decimal(38,38) + for precision in 1..=38 { + for scale in 1..=precision { + let decimal_str = format!("0.{}", "1".repeat(scale as usize)); + let decimal = Decimal128::from_str(&decimal_str).unwrap(); + assert_eq!(decimal_str, decimal.to_string()); + } + } + } + + #[test] + fn test_decimal128_compare() { + // the same precision and scale + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + assert!(decimal1 == decimal2); + + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678998").unwrap(); + assert!(decimal1 > decimal2); + + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678998").unwrap(); + assert!(decimal2 < decimal1); + + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678998").unwrap(); + assert!(decimal1 >= decimal2); + + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678998").unwrap(); + assert!(decimal2 <= decimal1); + + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123456789012345678998").unwrap(); + assert!(decimal1 != decimal2); + + // different precision and scale cmp is None + let decimal1 = Decimal128::from_str("1234567890.123456789012345678999").unwrap(); + let decimal2 = Decimal128::from_str("1234567890.123").unwrap(); + assert_eq!(decimal1.partial_cmp(&decimal2), None); + } +} diff --git a/src/common/decimal/src/error.rs b/src/common/decimal/src/error.rs new file mode 100644 index 000000000000..8bfa3e9fe8f6 --- /dev/null +++ b/src/common/decimal/src/error.rs @@ -0,0 +1,72 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bigdecimal::BigDecimal; +use common_error::ext::ErrorExt; +use common_error::status_code::StatusCode; +use common_macro::stack_trace_debug; +use snafu::{Location, Snafu}; + +#[derive(Snafu)] +#[snafu(visibility(pub))] +#[stack_trace_debug] +pub enum Error { + #[snafu(display("Decimal out of range, decimal value: {}", value))] + BigDecimalOutOfRange { + value: BigDecimal, + location: Location, + }, + + #[snafu(display("Failed to parse string to rust decimal, raw: {}", raw))] + ParseRustDecimalStr { + raw: String, + #[snafu(source)] + error: rust_decimal::Error, + }, + + #[snafu(display("Failed to parse string to big decimal, raw: {}", raw))] + ParseBigDecimalStr { + raw: String, + #[snafu(source)] + error: bigdecimal::ParseBigDecimalError, + }, + + #[snafu(display("Invalid precision or scale, resion: {}", reason))] + InvalidPrecisionOrScale { reason: String, location: Location }, +} + +impl ErrorExt for Error { + fn status_code(&self) -> StatusCode { + match self { + Error::BigDecimalOutOfRange { .. } => StatusCode::Internal, + Error::ParseRustDecimalStr { .. } + | Error::InvalidPrecisionOrScale { .. } + | Error::ParseBigDecimalStr { .. } => StatusCode::InvalidArguments, + } + } + + fn location_opt(&self) -> Option { + match self { + Error::BigDecimalOutOfRange { location, .. } => Some(*location), + Error::InvalidPrecisionOrScale { location, .. } => Some(*location), + Error::ParseRustDecimalStr { .. } | Error::ParseBigDecimalStr { .. } => None, + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +pub type Result = std::result::Result; diff --git a/src/common/decimal/src/lib.rs b/src/common/decimal/src/lib.rs new file mode 100644 index 000000000000..815c79fa0fad --- /dev/null +++ b/src/common/decimal/src/lib.rs @@ -0,0 +1,16 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod decimal128; +pub mod error;