From 262eb168fa85409ad45b1d3019e835d298eb80af Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 28 Nov 2022 16:35:54 +0800 Subject: [PATCH] perf(memcomparable): optimize ser/de for decimal (#6586) * optimize ser/de for decimal in memcomparable Signed-off-by: Runji Wang * remove date/time from memcomparable crate because they are trivial Signed-off-by: Runji Wang * fix decimal test Signed-off-by: Runji Wang * Update src/utils/memcomparable/src/decimal.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix clippy and refactor read_bytes_len Signed-off-by: Runji Wang Signed-off-by: Runji Wang Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- Cargo.lock | 3 +- src/common/Cargo.toml | 2 +- src/common/src/types/decimal.rs | 48 ++-- src/common/src/types/mod.rs | 40 ++- src/common/src/util/ordered/serde.rs | 12 +- src/utils/memcomparable/Cargo.toml | 16 +- src/utils/memcomparable/benches/serde.rs | 65 +++++ src/utils/memcomparable/src/de.rs | 328 +++++------------------ src/utils/memcomparable/src/decimal.rs | 52 ++++ src/utils/memcomparable/src/error.rs | 2 + src/utils/memcomparable/src/lib.rs | 6 +- src/utils/memcomparable/src/ser.rs | 275 +++++++------------ src/workspace-hack/Cargo.toml | 4 +- 13 files changed, 360 insertions(+), 493 deletions(-) create mode 100644 src/utils/memcomparable/benches/serde.rs create mode 100644 src/utils/memcomparable/src/decimal.rs diff --git a/Cargo.lock b/Cargo.lock index e70acb716a330..e92b392161a7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3316,11 +3316,11 @@ name = "memcomparable" version = "0.2.0-alpha" dependencies = [ "bytes", + "criterion", "rand 0.8.5", "rust_decimal", "serde", "thiserror", - "workspace-hack", ] [[package]] @@ -6077,6 +6077,7 @@ dependencies = [ "bytes", "num-traits", "postgres", + "rand 0.8.5", "serde", "tokio-postgres", ] diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 59daf02904d61..d778ff9b00603 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -26,7 +26,7 @@ futures-async-stream = "0.2" humantime = "2.1" itertools = "0.10" lru = { git = "https://github.com/risingwavelabs/lru-rs.git", branch = "evict_by_timestamp" } -memcomparable = { path = "../utils/memcomparable" } +memcomparable = { path = "../utils/memcomparable", features = ["decimal"] } more-asserts = "0.3" num-traits = "0.2" parking_lot = "0.12" diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 7740e3f09893f..fac87f632573f 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -519,34 +519,14 @@ impl Decimal { } } - /// TODO: 1. test whether the decimal in rust, any crate, has the same behavior as PG. - /// 2. support memcomparable encoding for dynamic decimal. - pub fn mantissa_scale_for_serialization(&self) -> (i128, u8) { - // Since the largest scale supported by `rust_decimal` is 28, - // and we first compare scale, we use 29 and 30 to denote +Inf and NaN. - match self { - Self::NegativeInf => (0, 29), - Self::Normalized(d) => { - // We remark that we do not dynamic numeric, i.e. the scale of all the numeric in - // the system is fixed. So we don't need to do any rescale, just use - // the `scale` of `rust_decimal`. However, it is possible that scale - // may overflow during calculation as `rust_decimal`'s max scale is - // 28. - (d.mantissa(), d.scale() as u8) - } - Self::PositiveInf => (0, 30), - Self::NaN => (0, 31), - } - } - pub fn unordered_serialize(&self) -> [u8; 16] { // according to https://docs.rs/rust_decimal/1.18.0/src/rust_decimal/decimal.rs.html#665-684 // the lower 15 bits is not used, so we can use first byte to distinguish nan and inf match self { Self::Normalized(d) => d.serialize(), - Self::NaN => [vec![1u8], vec![0u8; 15]].concat().try_into().unwrap(), - Self::PositiveInf => [vec![2u8], vec![0u8; 15]].concat().try_into().unwrap(), - Self::NegativeInf => [vec![3u8], vec![0u8; 15]].concat().try_into().unwrap(), + Self::NaN => [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + Self::PositiveInf => [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + Self::NegativeInf => [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], } } @@ -576,6 +556,28 @@ impl Decimal { } } +impl From for memcomparable::Decimal { + fn from(d: Decimal) -> Self { + match d { + Decimal::Normalized(d) => Self::Normalized(d), + Decimal::PositiveInf => Self::Inf, + Decimal::NegativeInf => Self::NegInf, + Decimal::NaN => Self::NaN, + } + } +} + +impl From for Decimal { + fn from(d: memcomparable::Decimal) -> Self { + match d { + memcomparable::Decimal::Normalized(d) => Self::Normalized(d), + memcomparable::Decimal::Inf => Self::PositiveInf, + memcomparable::Decimal::NegInf => Self::NegativeInf, + memcomparable::Decimal::NaN => Self::NaN, + } + } +} + impl Default for Decimal { fn default() -> Self { Self::Normalized(RustDecimal::default()) diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 9ebafe7729f50..9fc898599167d 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -817,17 +817,16 @@ impl ScalarRefImpl<'_> { Self::Float64(v) => v.serialize(ser)?, Self::Utf8(v) => v.serialize(ser)?, Self::Bool(v) => v.serialize(ser)?, - Self::Decimal(v) => { - let (mantissa, scale) = v.mantissa_scale_for_serialization(); - ser.serialize_decimal(mantissa, scale)?; - } + Self::Decimal(v) => ser.serialize_decimal((*v).into())?, Self::Interval(v) => v.serialize(ser)?, - Self::NaiveDate(v) => ser.serialize_naivedate(v.0.num_days_from_ce())?, + Self::NaiveDate(v) => v.0.num_days_from_ce().serialize(ser)?, Self::NaiveDateTime(v) => { - ser.serialize_naivedatetime(v.0.timestamp(), v.0.timestamp_subsec_nanos())? + v.0.timestamp().serialize(&mut *ser)?; + v.0.timestamp_subsec_nanos().serialize(ser)?; } Self::NaiveTime(v) => { - ser.serialize_naivetime(v.0.num_seconds_from_midnight(), v.0.nanosecond())? + v.0.num_seconds_from_midnight().serialize(&mut *ser)?; + v.0.nanosecond().serialize(ser)?; } Self::Struct(v) => v.serialize(ser)?, Self::List(v) => v.serialize(ser)?, @@ -859,27 +858,21 @@ impl ScalarImpl { Ty::Float64 => Self::Float64(f64::deserialize(de)?.into()), Ty::Varchar => Self::Utf8(String::deserialize(de)?), Ty::Boolean => Self::Bool(bool::deserialize(de)?), - Ty::Decimal => Self::Decimal({ - let (mantissa, scale) = de.deserialize_decimal()?; - match scale { - 29 => Decimal::NegativeInf, - 30 => Decimal::PositiveInf, - 31 => Decimal::NaN, - _ => Decimal::from_i128_with_scale(mantissa, scale as u32), - } - }), + Ty::Decimal => Self::Decimal(de.deserialize_decimal()?.into()), Ty::Interval => Self::Interval(IntervalUnit::deserialize(de)?), Ty::Time => Self::NaiveTime({ - let (secs, nano) = de.deserialize_naivetime()?; + let secs = u32::deserialize(&mut *de)?; + let nano = u32::deserialize(de)?; NaiveTimeWrapper::with_secs_nano(secs, nano)? }), Ty::Timestamp => Self::NaiveDateTime({ - let (secs, nsecs) = de.deserialize_naivedatetime()?; + let secs = i64::deserialize(&mut *de)?; + let nsecs = u32::deserialize(de)?; NaiveDateTimeWrapper::with_secs_nsecs(secs, nsecs)? }), Ty::Timestampz => Self::Int64(i64::deserialize(de)?), Ty::Date => Self::NaiveDate({ - let days = de.deserialize_naivedate()?; + let days = i32::deserialize(de)?; NaiveDateWrapper::with_days(days)? }), Ty::Struct(t) => StructValue::deserialize(&t.fields, de)?.to_scalar_value(), @@ -915,16 +908,19 @@ impl ScalarImpl { DataType::Boolean => size_of::(), // IntervalUnit is serialized as (i32, i32, i64) DataType::Interval => size_of::<(i32, i32, i64)>(), - DataType::Decimal => deserializer.read_decimal_len()?, + DataType::Decimal => { + deserializer.deserialize_decimal()?; + 0 // the len is not used since decimal is not a fixed length type + } // these two types is var-length and should only be determine at runtime. // TODO: need some test for this case (e.g. e2e test) - DataType::List { .. } => deserializer.read_bytes_len()?, + DataType::List { .. } => deserializer.skip_bytes()?, DataType::Struct(t) => t .fields .iter() .map(|field| Self::encoding_data_size(field, deserializer)) .try_fold(0, |a, b| b.map(|b| a + b))?, - DataType::Varchar => deserializer.read_bytes_len()?, + DataType::Varchar => deserializer.skip_bytes()?, }; // consume offset of fixed_type diff --git a/src/common/src/util/ordered/serde.rs b/src/common/src/util/ordered/serde.rs index b81798dcd25dd..e658c83bcca56 100644 --- a/src/common/src/util/ordered/serde.rs +++ b/src/common/src/util/ordered/serde.rs @@ -349,8 +349,8 @@ mod tests { let encoding_data_size = ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) .unwrap(); - // [nulltag, flag, decimal_chunk, 0] - assert_eq!(18, encoding_data_size); + // [nulltag, flag, decimal_chunk] + assert_eq!(17, encoding_data_size); } { @@ -362,8 +362,8 @@ mod tests { let encoding_data_size = ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) .unwrap(); - // [nulltag, flag, decimal_chunk, 0] - assert_eq!(4, encoding_data_size); + // [nulltag, flag, decimal_chunk] + assert_eq!(3, encoding_data_size); } { @@ -376,7 +376,7 @@ mod tests { ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) .unwrap(); - assert_eq!(3, encoding_data_size); // [1, 35, 0] + assert_eq!(2, encoding_data_size); // [1, 35] } { @@ -388,7 +388,7 @@ mod tests { let encoding_data_size = ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) .unwrap(); - assert_eq!(3, encoding_data_size); // [1, 6, 0] + assert_eq!(2, encoding_data_size); // [1, 6] } { diff --git a/src/utils/memcomparable/Cargo.toml b/src/utils/memcomparable/Cargo.toml index 6beec269ec211..1f629ffe5508e 100644 --- a/src/utils/memcomparable/Cargo.toml +++ b/src/utils/memcomparable/Cargo.toml @@ -9,15 +9,21 @@ repository = { workspace = true } description = "Memcomparable format." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +decimal = ["rust_decimal"] + [dependencies] bytes = "1" -serde = { version = "1", features = ["derive"] } +rust_decimal = { version = "1", optional = true } +serde = "1" thiserror = "1" -[target.'cfg(not(madsim))'.dependencies] -workspace-hack = { path = "../../workspace-hack" } - [dev-dependencies] +criterion = "0.4" rand = "0.8" -rust_decimal = "1" +rust_decimal = { version = "1", features = ["rand"] } serde = { version = "1", features = ["derive"] } + +[[bench]] +name = "serde" +harness = false diff --git a/src/utils/memcomparable/benches/serde.rs b/src/utils/memcomparable/benches/serde.rs new file mode 100644 index 0000000000000..995a1c1c34d83 --- /dev/null +++ b/src/utils/memcomparable/benches/serde.rs @@ -0,0 +1,65 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use criterion::{criterion_group, criterion_main, Criterion}; + +criterion_group!(benches, decimal); +criterion_main!(benches); + +#[cfg(not(feature = "decimal"))] +fn decimal(_c: &mut Criterion) {} + +#[cfg(feature = "decimal")] +fn decimal(c: &mut Criterion) { + use memcomparable::{Decimal, Deserializer, Serializer}; + + // generate decimals + let mut decimals = vec![]; + for _ in 0..10 { + decimals.push(Decimal::Normalized(rand::random())); + } + + c.bench_function("serialize_decimal", |b| { + let mut i = 0; + b.iter(|| { + let mut ser = Serializer::new(vec![]); + ser.serialize_decimal(decimals[i]).unwrap(); + i += 1; + if i == decimals.len() { + i = 0; + } + }) + }); + + c.bench_function("deserialize_decimal", |b| { + let encodings = decimals + .iter() + .map(|d| { + let mut ser = Serializer::new(vec![]); + ser.serialize_decimal(*d).unwrap(); + ser.into_inner() + }) + .collect::>(); + let mut i = 0; + b.iter(|| { + Deserializer::new(encodings[i].as_slice()) + .deserialize_decimal() + .unwrap(); + i += 1; + if i == decimals.len() { + i = 0; + } + }) + }); +} diff --git a/src/utils/memcomparable/src/de.rs b/src/utils/memcomparable/src/de.rs index 20583d0057fdf..07215b475665b 100644 --- a/src/utils/memcomparable/src/de.rs +++ b/src/utils/memcomparable/src/de.rs @@ -17,10 +17,10 @@ use serde::de::{ self, DeserializeSeed, EnumAccess, IntoDeserializer, SeqAccess, VariantAccess, Visitor, }; +#[cfg(feature = "decimal")] +use crate::decimal::Decimal; use crate::error::{Error, Result}; -const DECIMAL_FLAG_LOW_BOUND: u8 = 0x6; -const DECIMAL_FLAG_UP_BOUND: u8 = 0x23; const BYTES_CHUNK_SIZE: usize = 8; const BYTES_CHUNK_UNIT_SIZE: usize = BYTES_CHUNK_SIZE + 1; @@ -107,10 +107,6 @@ impl MaybeFlip { def_method!(get_u64, u64); - def_method!(get_i32, i32); - - def_method!(get_i64, i64); - fn copy_to_slice(&mut self, dst: &mut [u8]) { self.input.copy_to_slice(dst); if self.flip { @@ -145,92 +141,23 @@ impl Deserializer { } } - fn read_decimal(&mut self) -> Result> { - let flag = self.input.get_u8(); - if !(DECIMAL_FLAG_LOW_BOUND..=DECIMAL_FLAG_UP_BOUND).contains(&flag) { - return Err(Error::InvalidBytesEncoding(flag)); - } - let mut byte_array = vec![flag]; - loop { - let byte = self.input.get_u8(); - if byte == 0 { - break; - } - byte_array.push(byte); - } - Ok(byte_array) - } - - /// Read bytes_len without copy, it will consume offset - pub fn read_bytes_len(&mut self) -> Result { - use core::cmp; - let mut result: usize = 0; - + /// Skip the next bytes. Return the length of bytes. + pub fn skip_bytes(&mut self) -> Result { match self.input.get_u8() { 0 => return Ok(0), // empty slice 1 => {} // non-empty slice v => return Err(Error::InvalidBytesEncoding(v)), } - + let mut total_len = 0; loop { - { - // calc advance - let mut offset = 0; - while offset < BYTES_CHUNK_SIZE { - let src = self.input.input.chunk(); - let cnt = cmp::min(src.len(), BYTES_CHUNK_SIZE - offset); - offset += cnt; - self.advance(cnt); - } - } - - let chunk_len = if self.input.flip { - !self.input.input.chunk()[0] - } else { - self.input.input.chunk()[0] - }; - self.advance(1); - - match chunk_len { - len @ 1..=8 => { - result += len as usize; - // self.advance(len as usize); - return Ok(result); - } - 9 => { - result += 8; - } + self.advance(BYTES_CHUNK_SIZE); + match self.input.get_u8() { + len @ 1..=8 => return Ok(total_len + len as usize), + 9 => total_len += 8, v => return Err(Error::InvalidBytesEncoding(v)), } } } - - /// Read decimal_len without copy, it will consume offset - pub fn read_decimal_len(&mut self) -> Result { - let mut len: usize = 0; - - let flag = self.input.get_u8(); - if !(DECIMAL_FLAG_LOW_BOUND..=DECIMAL_FLAG_UP_BOUND).contains(&flag) { - return Err(Error::InvalidBytesEncoding(flag)); - } - loop { - let byte = self.input.get_u8(); - if byte == 0 { - break; - } - - len += 1; - } - - Ok(len) - } - - /// Read struct_and_list without copy, it will consume offset - pub fn read_struct_and_list_len(&mut self) -> Result { - let len = self.input.get_u32() as usize; - self.advance(len); - Ok(len) - } } // Format Reference: @@ -585,69 +512,43 @@ impl<'de, 'a, B: Buf + 'de> VariantAccess<'de> for &'a mut Deserializer { } impl Deserializer { - /// Deserialize a decimal value. Returns `(mantissa, scale)`. - pub fn deserialize_decimal(&mut self) -> Result<(i128, u8)> { - let mut byte_array = self.read_decimal()?; - - // indicate the beginning position of mantissa in `byte_array`. - let mut begin: usize = 2; - // whether the decimal is negative or not. - let mut neg: bool = false; - let exponent = match byte_array[0] { - DECIMAL_FLAG_LOW_BOUND => { - // NaN - return Ok((0, 31)); - } - 0x07 => { - // Negative INF - return Ok((0, 29)); - } - 0x08 => { - neg = true; - !byte_array[1] as i8 - } - 0x09..=0x13 => { - begin -= 1; - neg = true; - (0x13 - byte_array[0]) as i8 - } - 0x14 => { - neg = true; - -(byte_array[1] as i8) - } - 0x15 => { - return Ok((0, 0)); - } - 0x16 => -!(byte_array[1] as i8), - 0x17..=0x21 => { - begin -= 1; - (byte_array[0] - 0x17) as i8 - } - 0x22 => byte_array[1] as i8, - DECIMAL_FLAG_UP_BOUND => { - // Positive INF - return Ok((0, 30)); - } - invalid_byte => { - return Err(Error::InvalidBytesEncoding(invalid_byte)); - } + /// Deserialize a decimal value. + #[cfg(feature = "decimal")] + pub fn deserialize_decimal(&mut self) -> Result { + // decode exponent + let flag = self.input.get_u8(); + let exponent = match flag { + 0x06 => return Ok(Decimal::NaN), + 0x07 => return Ok(Decimal::NegInf), + 0x08 => !self.input.get_u8() as i8, + 0x09..=0x13 => (0x13 - flag) as i8, + 0x14 => -(self.input.get_u8() as i8), + 0x15 => return Ok(Decimal::ZERO), + 0x16 => -!(self.input.get_u8() as i8), + 0x17..=0x21 => (flag - 0x17) as i8, + 0x22 => self.input.get_u8() as i8, + 0x23 => return Ok(Decimal::Inf), + b => return Err(Error::InvalidDecimalEncoding(b)), }; - if neg { - byte_array = byte_array.into_iter().map(|item| !item).collect(); - } - - // decode mantissa. + // decode mantissa + let neg = (0x07..0x15).contains(&flag); let mut mantissa: i128 = 0; - let bytes_len = byte_array.len() - begin; - let mut exp = bytes_len; - for item in byte_array.iter().skip(begin) { - exp -= 1; - mantissa += ((item - 1) / 2) as i128 * 100i128.pow(exp as u32); + let mut mlen = 0i8; + loop { + let mut b = self.input.get_u8(); + if neg { + b = !b; + } + let x = b / 2; + mantissa = mantissa * 100 + x as i128; + mlen += 1; + if b & 1 == 0 { + break; + } } - mantissa += 1; // get scale - let mut scale = (bytes_len as i8 - exponent) * 2; + let mut scale = (mlen - exponent) * 2; if scale <= 0 { // e.g. 1(mantissa) + 2(exponent) (which is 100). for _i in 0..-scale { @@ -664,44 +565,12 @@ impl Deserializer { if neg { mantissa = -mantissa; } - Ok((mantissa, scale as u8)) - } - - /// Deserialize a NaiveDateWrapper value. Returns `days`. - pub fn deserialize_naivedate(&mut self) -> Result { - let days = self.input.get_i32() ^ (1 << 31); - Ok(days) - } - - /// Deserialize a NaiveTimeWrapper value. Returns `(secs, nano)`. - pub fn deserialize_naivetime(&mut self) -> Result<(u32, u32)> { - let secs = self.input.get_u32(); - let nano = self.input.get_u32(); - Ok((secs, nano)) - } - - /// Deserialize a NaiveDateTimeWrapper value. Returns `(secs, nsecs)`. - pub fn deserialize_naivedatetime(&mut self) -> Result<(i64, u32)> { - let secs = self.input.get_i64() ^ (1 << 63); - let nsecs = self.input.get_u32(); - Ok((secs, nsecs)) - } - - /// Deserialize struct and list value. Returns `bytes`. - pub fn deserialize_struct_or_list(&mut self) -> Result> { - let len = self.input.get_u32(); - let mut bytes = vec![0; len as usize]; - self.input.copy_to_slice(&mut bytes); - Ok(bytes) + Ok(rust_decimal::Decimal::from_i128_with_scale(mantissa, scale as u32).into()) } } #[cfg(test)] mod tests { - use std::iter::zip; - use std::str::FromStr; - - use rust_decimal::Decimal; use serde::Deserialize; use super::*; @@ -844,102 +713,47 @@ mod tests { } #[test] + #[cfg(feature = "decimal")] fn test_decimal() { // Notice: decimals like 100.00 will be decoding as 100. - // Test: -1234_5678_9012_3456_7890_1234, -12_3456_7890.1234, -0.001, 0.001, 100, 0.01111, - // 12345, 1234_5678_9012_3456_7890_1234, -233.3, 50 - let mantissas: Vec = vec![ - -1234_5678_9012_3456_7890_1234, - -12_3456_7890_1234, - -1, - 1, - 100, - 1111, - 12345, - 1234_5678_9012_3456_7890_1234, - -2333, - 50, + let decimals = [ + "nan", + "-inf", + "-123456789012345678901234", + "-1234567890.1234", + "-233.3", + "-0.001", + "0", + "0.001", + "0.01111", + "50", + "100", + "12345", + "41721.900909090909090909090909", + "123456789012345678901234", + "inf", ]; - let scales: Vec = vec![0, 4, 3, 3, 0, 5, 0, 0, 1, 0]; - for (mantissa, scale) in zip(mantissas, scales) { - assert_eq!( - (mantissa, scale), - deserialize_decimal(&serialize_decimal(mantissa, scale)) - ); + let mut last_encoding = vec![]; + for s in decimals { + let decimal: Decimal = s.parse().unwrap(); + let encoding = serialize_decimal(decimal); + assert_eq!(deserialize_decimal(&encoding), decimal); + assert!(encoding > last_encoding); + last_encoding = encoding; } } - #[test] - fn test_decimal_2() { - let d = Decimal::from_str("41721.900909090909090909090909").unwrap(); - let (mantissa, scale) = (d.mantissa(), d.scale() as u8); - let (mantissa0, scale0) = deserialize_decimal(&serialize_decimal(mantissa, scale)); - assert_eq!((mantissa, scale), (mantissa0, scale0)); - } - - fn serialize_decimal(mantissa: i128, scale: u8) -> Vec { + #[cfg(feature = "decimal")] + fn serialize_decimal(decimal: impl Into) -> Vec { let mut serializer = crate::Serializer::new(vec![]); - serializer.serialize_decimal(mantissa, scale).unwrap(); + serializer.serialize_decimal(decimal.into()).unwrap(); serializer.into_inner() } - fn deserialize_decimal(bytes: &[u8]) -> (i128, u8) { + #[cfg(feature = "decimal")] + fn deserialize_decimal(bytes: &[u8]) -> Decimal { let mut deserializer = Deserializer::new(bytes); deserializer.deserialize_decimal().unwrap() } - - #[test] - fn test_naivedate() { - let days = 12_3456; - let days0 = deserialize_naivedate(&serialize_naivedate(days)); - assert_eq!(days, days0); - } - - fn serialize_naivedate(days: i32) -> Vec { - let mut serializer = crate::Serializer::new(vec![]); - serializer.serialize_naivedate(days).unwrap(); - serializer.into_inner() - } - - fn deserialize_naivedate(bytes: &[u8]) -> i32 { - let mut deserializer = Deserializer::new(bytes); - deserializer.deserialize_naivedate().unwrap() - } - - #[test] - fn test_naivetime() { - let (secs, nano) = (23 * 3600 + 59 * 60 + 59, 1234_5678); - let (secs0, nano0) = deserialize_naivetime(&serialize_naivetime(secs, nano)); - assert_eq!((secs, nano), (secs0, nano0)); - } - - fn serialize_naivetime(secs: u32, nano: u32) -> Vec { - let mut serializer = crate::Serializer::new(vec![]); - serializer.serialize_naivetime(secs, nano).unwrap(); - serializer.into_inner() - } - - fn deserialize_naivetime(bytes: &[u8]) -> (u32, u32) { - let mut deserializer = Deserializer::new(bytes); - deserializer.deserialize_naivetime().unwrap() - } - - #[test] - fn test_naivedatetime() { - let (secs, nsecs) = (12_3456_7890_1234, 1234_5678); - let (secs0, nsecs0) = deserialize_naivedatetime(&serialize_naivedatetime(secs, nsecs)); - assert_eq!((secs, nsecs), (secs0, nsecs0)); - } - - fn serialize_naivedatetime(secs: i64, nsecs: u32) -> Vec { - let mut serializer = crate::Serializer::new(vec![]); - serializer.serialize_naivedatetime(secs, nsecs).unwrap(); - serializer.into_inner() - } - - fn deserialize_naivedatetime(bytes: &[u8]) -> (i64, u32) { - let mut deserializer = Deserializer::new(bytes); - deserializer.deserialize_naivedatetime().unwrap() - } } diff --git a/src/utils/memcomparable/src/decimal.rs b/src/utils/memcomparable/src/decimal.rs new file mode 100644 index 0000000000000..3e1fe257370c9 --- /dev/null +++ b/src/utils/memcomparable/src/decimal.rs @@ -0,0 +1,52 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::FromStr; + +/// An extended decimal number with `NaN`, `-Inf` and `Inf`. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Decimal { + /// Not a Number. + NaN, + /// Negative infinity. + NegInf, + /// Normalized value. + Normalized(rust_decimal::Decimal), + /// Infinity. + Inf, +} + +impl Decimal { + /// A constant representing 0. + pub const ZERO: Self = Decimal::Normalized(rust_decimal::Decimal::ZERO); +} + +impl From for Decimal { + fn from(decimal: rust_decimal::Decimal) -> Self { + Decimal::Normalized(decimal) + } +} + +impl FromStr for Decimal { + type Err = rust_decimal::Error; + + fn from_str(s: &str) -> Result { + match s { + "nan" => Ok(Decimal::NaN), + "-inf" => Ok(Decimal::NegInf), + "inf" => Ok(Decimal::Inf), + _ => Ok(Decimal::Normalized(s.parse()?)), + } + } +} diff --git a/src/utils/memcomparable/src/error.rs b/src/utils/memcomparable/src/error.rs index ca662309984bf..f3585e3194bc4 100644 --- a/src/utils/memcomparable/src/error.rs +++ b/src/utils/memcomparable/src/error.rs @@ -42,6 +42,8 @@ pub enum Error { InvalidUtf8(#[from] std::string::FromUtf8Error), #[error("invalid bytes encoding: {0}")] InvalidBytesEncoding(u8), + #[error("invalid decimal encoding: {0}")] + InvalidDecimalEncoding(u8), #[error("trailing characters")] TrailingCharacters, #[error("invalid NaiveDate scope: days: {0}")] diff --git a/src/utils/memcomparable/src/lib.rs b/src/utils/memcomparable/src/lib.rs index 789e16e773de4..39a47cea711d0 100644 --- a/src/utils/memcomparable/src/lib.rs +++ b/src/utils/memcomparable/src/lib.rs @@ -16,13 +16,15 @@ //! compared with memcmp. #![deny(missing_docs)] -#![feature(lint_reasons)] -#![expect(clippy::doc_markdown, reason = "FIXME: later")] mod de; +#[cfg(feature = "decimal")] +mod decimal; mod error; mod ser; pub use de::{from_slice, Deserializer}; +#[cfg(feature = "decimal")] +pub use decimal::Decimal; pub use error::{Error, Result}; pub use ser::{to_vec, Serializer}; diff --git a/src/utils/memcomparable/src/ser.rs b/src/utils/memcomparable/src/ser.rs index 3dd681494b7b0..73449faba1354 100644 --- a/src/utils/memcomparable/src/ser.rs +++ b/src/utils/memcomparable/src/ser.rs @@ -15,6 +15,8 @@ use bytes::BufMut; use serde::{ser, Serialize}; +#[cfg(feature = "decimal")] +use crate::decimal::Decimal; use crate::error::{Error, Result}; /// A structure for serializing Rust values into a memcomparable bytes. @@ -74,10 +76,6 @@ impl MaybeFlip { def_method!(put_u64, u64); - def_method!(put_i32, i32); - - def_method!(put_i64, i64); - fn put_slice(&mut self, src: &[u8]) { for &val in src { let val = if self.flip { !val } else { val }; @@ -439,101 +437,114 @@ impl<'a, B: BufMut> ser::SerializeStructVariant for &'a mut Serializer { impl Serializer { /// Serialize a decimal value. /// - /// - `mantissa`: From `rust_decimal::Decimal::mantissa()`. A 96-bits signed integer. - /// - `scale`: From `rust_decimal::Decimal::scale()`. A power of 10 ranging from 0 to 28. - /// - /// The decimal will be encoded to 13 bytes. - pub fn serialize_decimal(&mut self, mantissa: i128, scale: u8) -> Result<()> { - // https://github.com/pingcap/tidb/blob/fec2938c1379270bf9939822c1abfe3d7244c174/types/mydecimal.go#L1133 - // https://sqlite.org/src4/doc/trunk/www/key_encoding.wiki - let (exponent, significand) = Serializer::::decimal_e_m(mantissa, scale); - let mut encoded_decimal = vec![]; - match mantissa { - 1.. => { - match exponent { - 11.. => { - encoded_decimal.push(0x22); - encoded_decimal.push(exponent as u8); - } - 0..=10 => { - encoded_decimal.push(0x17 + exponent as u8); - } - _ => { - encoded_decimal.push(0x16); - encoded_decimal.push(!(-exponent) as u8); - } - } - encoded_decimal.extend(significand.iter()); + /// The encoding format follows `SQLite`: + #[cfg(feature = "decimal")] + pub fn serialize_decimal(&mut self, decimal: Decimal) -> Result<()> { + let decimal = match decimal { + Decimal::NaN => { + self.output.put_u8(0x06); + return Ok(()); + } + Decimal::NegInf => { + self.output.put_u8(0x07); + return Ok(()); + } + Decimal::Inf => { + self.output.put_u8(0x23); + return Ok(()); + } + Decimal::Normalized(d) if d.is_zero() => { + self.output.put_u8(0x15); + return Ok(()); } - 0 => { - match scale { - 29 => { - // Negative INF - encoded_decimal.push(0x07); - } - 30 => { - // Positive INF - encoded_decimal.push(0x23); - } - 31 => { - // NaN - encoded_decimal.push(0x06); - } - _ => { - // 0 - // Maybe need to change. - encoded_decimal.push(0x15); - } + Decimal::Normalized(d) => d, + }; + let (exponent, significand) = Self::decimal_e_m(decimal); + if decimal.is_sign_positive() { + match exponent { + 11.. => { + self.output.put_u8(0x22); + self.output.put_u8(exponent as u8); + } + 0..=10 => { + self.output.put_u8(0x17 + exponent as u8); + } + _ => { + self.output.put_u8(0x16); + self.output.put_u8(!(-exponent) as u8); } } - _ => { - match exponent { - 11.. => { - encoded_decimal.push(0x8); - encoded_decimal.push(!exponent as u8); - } - 0..=10 => { - encoded_decimal.push(0x13 - exponent as u8); - } - _ => { - encoded_decimal.push(0x14); - encoded_decimal.push(-exponent as u8); - } + self.output.put_slice(&significand); + } else { + match exponent { + 11.. => { + self.output.put_u8(0x8); + self.output.put_u8(!exponent as u8); } - encoded_decimal.extend(significand.into_iter().map(|m| !m)); + 0..=10 => { + self.output.put_u8(0x13 - exponent as u8); + } + _ => { + self.output.put_u8(0x14); + self.output.put_u8(-exponent as u8); + } + } + for b in significand { + self.output.put_u8(!b); } } - // use 0x00 as the end marker. - encoded_decimal.push(0); - self.output.put_slice(&encoded_decimal); Ok(()) } - /// Get the exponent and byte_array form of mantissa. - pub fn decimal_e_m(mantissa: i128, scale: u8) -> (i8, Vec) { - if mantissa == 0 { + /// Get the exponent and significand mantissa from a decimal. + #[cfg(feature = "decimal")] + fn decimal_e_m(decimal: rust_decimal::Decimal) -> (i8, Vec) { + if decimal.is_zero() { return (0, vec![]); } - let prec = { - let mut abs_man = mantissa.abs(); - let mut cnt = 0; - while abs_man > 0 { - cnt += 1; - abs_man /= 10; - } - cnt - }; - let scale = scale as i32; + const POW10: [u128; 30] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + ]; + let mut mantissa = decimal.mantissa().unsigned_abs(); + let prec = POW10.as_slice().partition_point(|&p| p <= mantissa); - let e10 = prec - scale; + let e10 = prec as i32 - decimal.scale() as i32; let e100 = if e10 >= 0 { (e10 + 1) / 2 } else { e10 / 2 }; // Maybe need to add a zero at the beginning. // e.g. 111.11 -> 2(exponent which is 100 based) + 0.011111(mantissa). // So, the `digit_num` of 111.11 will be 6. let mut digit_num = if e10 == 2 * e100 { prec } else { prec + 1 }; - let mut byte_array: Vec = vec![]; - let mut mantissa = mantissa.abs(); + let mut byte_array = Vec::with_capacity(16); // Remove trailing zero. while mantissa % 10 == 0 && mantissa != 0 { mantissa /= 10; @@ -545,6 +556,13 @@ impl Serializer { mantissa *= 10; // digit_num += 1; } + while mantissa >> 64 != 0 { + let byte = (mantissa % 100) as u8 * 2 + 1; + byte_array.push(byte); + mantissa /= 100; + } + // optimize for division + let mut mantissa = mantissa as u64; while mantissa != 0 { let byte = (mantissa % 100) as u8 * 2 + 1; byte_array.push(byte); @@ -555,41 +573,6 @@ impl Serializer { (e100 as i8, byte_array) } - - /// Serialize a NaiveDateWrapper value. - /// - /// - `days`: From `chrono::Datelike::num_days_from_ce()`. - pub fn serialize_naivedate(&mut self, days: i32) -> Result<()> { - self.output.put_i32(days ^ (1 << 31)); - Ok(()) - } - - /// Serialize a NaiveTimeWrapper value. - /// - /// - `secs`: From `chrono::Timelike::num_seconds_from_midnight()`. - /// - `nano`: From `chrono::Timelike::nanosecond()`. - pub fn serialize_naivetime(&mut self, secs: u32, nano: u32) -> Result<()> { - self.output.put_u32(secs); - self.output.put_u32(nano); - Ok(()) - } - - /// Serialize a NaiveDateTimeWrapper value. - /// - /// - `secs`: From `chrono::naive::NaiveDateTime::timestamp()`. - /// - `nsecs`: From `chrono::naive::NaiveDateTime::timestamp_subsec_nanos()`. - pub fn serialize_naivedatetime(&mut self, secs: i64, nsecs: u32) -> Result<()> { - self.output.put_i64(secs ^ (1 << 63)); - self.output.put_u32(nsecs); - Ok(()) - } - - /// Serialize bytes of ListValue or StructValue. - pub fn serialize_struct_or_list(&mut self, bytes: Vec) -> Result<()> { - self.output.put_u32(bytes.len() as u32); - self.output.put_slice(bytes.as_slice()); - Ok(()) - } } #[cfg(test)] @@ -774,20 +757,7 @@ mod tests { } #[test] - fn test_decimal() { - let a = serialize_decimal(12_3456_7890_1234, 4); - let b = serialize_decimal(0, 4); - let c = serialize_decimal(-12_3456_7890_1234, 4); - assert!(a > b && b > c); - } - - fn serialize_decimal(mantissa: i128, scale: u8) -> Vec { - let mut serializer = Serializer::new(vec![]); - serializer.serialize_decimal(mantissa, scale).unwrap(); - serializer.into_inner() - } - - #[test] + #[cfg(feature = "decimal")] fn test_decimal_e_m() { // from: https://sqlite.org/src4/doc/trunk/www/key_encoding.wiki let cases = vec![ @@ -827,7 +797,7 @@ mod tests { for (decimal, exponents, significand) in cases { let d = decimal.parse::().unwrap(); - let (exp, sig) = Serializer::>::decimal_e_m(d.mantissa(), d.scale() as u8); + let (exp, sig) = Serializer::>::decimal_e_m(d); assert_eq!(exp, exponents, "wrong exponents for decimal: {decimal}"); assert_eq!( sig.iter() @@ -840,49 +810,6 @@ mod tests { } } - #[test] - fn test_naivedate() { - let a = serialize_naivedate(12_3456); - let b = serialize_naivedate(0); - let c = serialize_naivedate(-12_3456); - assert!(a > b && b > c); - } - - fn serialize_naivedate(days: i32) -> Vec { - let mut serializer = Serializer::new(vec![]); - serializer.serialize_naivedate(days).unwrap(); - serializer.into_inner() - } - - #[test] - fn test_naivetime() { - let a = serialize_naivetime(23 * 3600 + 59 * 60 + 59, 1234_5678); - let b = serialize_naivetime(12 * 3600, 1); - let c = serialize_naivetime(12 * 3600, 0); - let d = serialize_naivetime(0, 0); - assert!(a > b && b > c && c > d); - } - - fn serialize_naivetime(secs: u32, nano: u32) -> Vec { - let mut serializer = Serializer::new(vec![]); - serializer.serialize_naivetime(secs, nano).unwrap(); - serializer.into_inner() - } - - #[test] - fn test_naivedatetime() { - let a = serialize_naivedatetime(12_3456_7890_1234, 1234_5678); - let b = serialize_naivedatetime(0, 0); - let c = serialize_naivedatetime(-12_3456_7890_1234, 1234_5678); - assert!(a > b && b > c); - } - - fn serialize_naivedatetime(secs: i64, nsecs: u32) -> Vec { - let mut serializer = Serializer::new(vec![]); - serializer.serialize_naivedatetime(secs, nsecs).unwrap(); - serializer.into_inner() - } - #[test] fn test_reverse_order() { // Order: (ASC, DESC) diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index f49b14c535d7d..50fd4a57551c8 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -79,7 +79,7 @@ regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cac regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } reqwest = { version = "0.11", features = ["__tls", "default-tls", "hyper-tls", "json", "native-tls-crate", "serde_json", "tokio-native-tls"] } ring = { version = "0.16", features = ["alloc", "dev_urandom_fallback", "once_cell", "std"] } -rust_decimal = { version = "1", features = ["byteorder", "bytes", "db-tokio-postgres", "postgres", "serde", "std", "tokio-postgres"] } +rust_decimal = { version = "1", features = ["byteorder", "bytes", "db-tokio-postgres", "postgres", "rand", "serde", "std", "tokio-postgres"] } scopeguard = { version = "1", features = ["use_std"] } serde = { version = "1", features = ["alloc", "derive", "rc", "serde_derive", "std"] } smallvec = { version = "1", default-features = false, features = ["serde", "union", "write"] } @@ -166,7 +166,7 @@ regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cac regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } reqwest = { version = "0.11", features = ["__tls", "default-tls", "hyper-tls", "json", "native-tls-crate", "serde_json", "tokio-native-tls"] } ring = { version = "0.16", features = ["alloc", "dev_urandom_fallback", "once_cell", "std"] } -rust_decimal = { version = "1", features = ["byteorder", "bytes", "db-tokio-postgres", "postgres", "serde", "std", "tokio-postgres"] } +rust_decimal = { version = "1", features = ["byteorder", "bytes", "db-tokio-postgres", "postgres", "rand", "serde", "std", "tokio-postgres"] } scopeguard = { version = "1", features = ["use_std"] } serde = { version = "1", features = ["alloc", "derive", "rc", "serde_derive", "std"] } smallvec = { version = "1", default-features = false, features = ["serde", "union", "write"] }