diff --git a/dozer-sql/expression/src/aggregate.rs b/dozer-sql/expression/src/aggregate.rs
index d136358979..2eef718c7e 100644
--- a/dozer-sql/expression/src/aggregate.rs
+++ b/dozer-sql/expression/src/aggregate.rs
@@ -1,6 +1,9 @@
 use std::fmt::{Display, Formatter};
 
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+use dozer_types::serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
+#[serde(crate = "dozer_types::serde")]
 pub enum AggregateFunctionType {
     Avg,
     Count,
diff --git a/dozer-sql/src/aggregation/aggregator.rs b/dozer-sql/src/aggregation/aggregator.rs
index 14fef33f84..31125dd01f 100644
--- a/dozer-sql/src/aggregation/aggregator.rs
+++ b/dozer-sql/src/aggregation/aggregator.rs
@@ -5,7 +5,11 @@ use crate::aggregation::count::CountAggregator;
 use crate::aggregation::max::MaxAggregator;
 use crate::aggregation::min::MinAggregator;
 use crate::aggregation::sum::SumAggregator;
+use crate::calculate_err;
 use crate::errors::PipelineError;
+use dozer_types::chrono::{DateTime, FixedOffset, NaiveDate};
+use dozer_types::ordered_float::OrderedFloat;
+use dozer_types::rust_decimal::Decimal;
 use dozer_types::serde::de::DeserializeOwned;
 use dozer_types::serde::{Deserialize, Serialize};
 use enum_dispatch::enum_dispatch;
@@ -18,7 +22,7 @@ use crate::aggregation::max_value::MaxValueAggregator;
 use crate::aggregation::min_value::MinValueAggregator;
 use crate::errors::PipelineError::{InvalidFunctionArgument, InvalidValue};
 use dozer_sql_expression::aggregate::AggregateFunctionType::MaxValue;
-use dozer_types::types::{Field, FieldType, Schema};
+use dozer_types::types::{DozerDuration, Field, FieldType, Schema};
 use std::fmt::{Debug, Display, Formatter};
 
 #[enum_dispatch]
@@ -53,6 +57,185 @@ pub enum AggregatorType {
     Sum,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(crate = "dozer_types::serde")]
+pub(crate) struct OrderedAggregatorState {
+    function_type: AggregateFunctionType,
+    inner: OrderedAggregatorStateInner,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(crate = "dozer_types::serde")]
+enum OrderedAggregatorStateInner {
+    UInt(BTreeMap<u64, u64>),
+    U128(BTreeMap<u128, u64>),
+    Int(BTreeMap<i64, u64>),
+    I128(BTreeMap<i128, u64>),
+    Float(BTreeMap<OrderedFloat<f64>, u64>),
+    Decimal(BTreeMap<Decimal, u64>),
+    Timestamp(BTreeMap<DateTime<FixedOffset>, u64>),
+    Date(BTreeMap<NaiveDate, u64>),
+    Duration(BTreeMap<DozerDuration, u64>),
+}
+
+impl OrderedAggregatorState {
+    fn max_in_map<T: Clone + Ord>(map: &BTreeMap<T, u64>) -> Option<T> {
+        let (value, _count) = map.last_key_value()?;
+        Some(value.clone())
+    }
+    fn min_in_map<T: Clone + Ord>(map: &BTreeMap<T, u64>) -> Option<T> {
+        Some(map.first_key_value()?.0.clone())
+    }
+
+    fn update_for_map<T: Clone + Ord>(
+        map: &mut BTreeMap<T, u64>,
+        for_value: T,
+        incr: bool,
+    ) {
+        let amount = map.entry(for_value.clone()).or_insert(0);
+        if incr {
+            *amount += 1;
+        } else {
+            *amount -= 1;
+        }
+        if *amount == 0 {
+            map.remove(&for_value);
+        }
+    }
+
+    fn update(&mut self, for_value: &Field, incr: bool) -> Result<(), PipelineError> {
+        if for_value == &Field::Null {
+            return Ok(());
+        }
+        match &mut self.inner {
+            OrderedAggregatorStateInner::UInt(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_uint(), self.function_type),
+                incr,
+            ),
+            OrderedAggregatorStateInner::U128(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_u128(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Int(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_int(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::I128(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_i128(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Float(map) => Self::update_for_map(
+                map,
+                OrderedFloat(calculate_err!(for_value.as_float(), self.function_type)),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Decimal(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_decimal(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Timestamp(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_timestamp(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Date(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_date(), self.function_type),
+                incr,
+            ),
+
+            OrderedAggregatorStateInner::Duration(map) => Self::update_for_map(
+                map,
+                calculate_err!(for_value.as_duration(), self.function_type),
+                incr,
+            ),
+        }
+        Ok(())
+    }
+
+    #[inline]
+    pub(crate) fn incr(&mut self, value: &Field) -> Result<(), PipelineError> {
+        self.update(value, true)?;
+        Ok(())
+    }
+
+    #[inline]
+    pub(crate) fn decr(&mut self, value: &Field) -> Result<(), PipelineError> {
+        self.update(value, false)?;
+        Ok(())
+    }
+
+    pub(crate) fn new(function_type: AggregateFunctionType, field_type: FieldType) -> Option<Self> {
+        let inner = match field_type {
+            FieldType::UInt => OrderedAggregatorStateInner::UInt(Default::default()),
+            FieldType::U128 => OrderedAggregatorStateInner::U128(Default::default()),
+            FieldType::Int => OrderedAggregatorStateInner::Int(Default::default()),
+            FieldType::I128 => OrderedAggregatorStateInner::I128(Default::default()),
+            FieldType::Float => OrderedAggregatorStateInner::Float(Default::default()),
+            FieldType::Decimal => OrderedAggregatorStateInner::Decimal(Default::default()),
+            FieldType::Timestamp => OrderedAggregatorStateInner::Timestamp(Default::default()),
+            FieldType::Date => OrderedAggregatorStateInner::Date(Default::default()),
+            FieldType::Duration => OrderedAggregatorStateInner::Duration(Default::default()),
+            _ => return None,
+        };
+        Some(Self {
+            function_type,
+            inner,
+        })
+    }
+
+    fn get_min_opt(&self) -> Option<Field> {
+        let field = match &self.inner {
+            OrderedAggregatorStateInner::UInt(map) => Field::UInt(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::U128(map) => Field::U128(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Int(map) => Field::Int(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::I128(map) => Field::I128(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Float(map) => Field::Float(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Decimal(map) => Field::Decimal(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Timestamp(map) => Field::Timestamp(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Date(map) => Field::Date(Self::min_in_map(map)?),
+            OrderedAggregatorStateInner::Duration(map) => Field::Duration(Self::min_in_map(map)?),
+        };
+        Some(field)
+    }
+
+    #[inline]
+    pub(crate) fn get_min(&self) -> Field {
+        self.get_min_opt().unwrap_or(Field::Null)
+    }
+
+    fn get_max_opt(&self) -> Option<Field> {
+        let field = match &self.inner {
+            OrderedAggregatorStateInner::UInt(map) => Field::UInt(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::U128(map) => Field::U128(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Int(map) => Field::Int(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::I128(map) => Field::I128(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Float(map) => Field::Float(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Decimal(map) => Field::Decimal(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Timestamp(map) => Field::Timestamp(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Date(map) => Field::Date(Self::max_in_map(map)?),
+            OrderedAggregatorStateInner::Duration(map) => Field::Duration(Self::max_in_map(map)?),
+        };
+        Some(field)
+    }
+
+    #[inline]
+    pub(crate) fn get_max(&self) -> Field {
+        self.get_max_opt().unwrap_or(Field::Null)
+    }
+}
+
 impl Display for AggregatorType {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -192,40 +375,6 @@ pub fn get_aggregator_type_from_aggregation_expression(
     }
 }
 
-pub fn update_map(
-    fields: &[Field],
-    val_delta: u64,
-    decr: bool,
-    field_map: &mut BTreeMap<Field, u64>,
-) {
-    for field in fields {
-        if field == &Field::Null {
-            continue;
-        }
-
-        let get_prev_count = field_map.get(field);
-        let prev_count = match get_prev_count {
-            Some(v) => *v,
-            None => 0_u64,
-        };
-        let mut new_count = prev_count;
-        if decr {
-            new_count = new_count.wrapping_sub(val_delta);
-        } else {
-            new_count = new_count.wrapping_add(val_delta);
-        }
-        if new_count < 1 {
-            field_map.remove(field);
-        } else if field_map.contains_key(field) {
-            if let Some(val) = field_map.get_mut(field) {
-                *val = new_count;
-            }
-        } else {
-            field_map.insert(field.clone(), new_count);
-        }
-    }
-}
-
 pub fn update_val_map(
     fields: &[Field],
     val_delta: u64,
diff --git a/dozer-sql/src/aggregation/max.rs b/dozer-sql/src/aggregation/max.rs
index a3aade5dd4..d0d951ee4c 100644
--- a/dozer-sql/src/aggregation/max.rs
+++ b/dozer-sql/src/aggregation/max.rs
@@ -1,23 +1,22 @@
-use crate::aggregation::aggregator::{update_map, Aggregator};
+use crate::aggregation::aggregator::Aggregator;
 use crate::errors::PipelineError;
-use crate::{calculate_err, calculate_err_field};
 use dozer_sql_expression::aggregate::AggregateFunctionType::Max;
-use dozer_types::ordered_float::OrderedFloat;
 use dozer_types::serde::{Deserialize, Serialize};
 use dozer_types::types::{Field, FieldType};
-use std::collections::BTreeMap;
+
+use super::aggregator::OrderedAggregatorState;
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(crate = "dozer_types::serde")]
 pub struct MaxAggregator {
-    current_state: BTreeMap<Field, u64>,
+    current_state: Option<OrderedAggregatorState>,
     return_type: Option<FieldType>,
 }
 
 impl MaxAggregator {
     pub fn new() -> Self {
         Self {
-            current_state: BTreeMap::new(),
+            current_state: None,
             return_type: None,
         }
     }
@@ -26,6 +25,7 @@ impl MaxAggregator {
 impl Aggregator for MaxAggregator {
     fn init(&mut self, return_type: FieldType) {
         self.return_type = Some(return_type);
+        self.current_state = OrderedAggregatorState::new(Max, return_type);
     }
 
     fn update(&mut self, old: &[Field], new: &[Field]) -> Result<Field, PipelineError> {
@@ -34,63 +34,39 @@
     }
 
     fn delete(&mut self, old: &[Field]) -> Result<Field, PipelineError> {
-        update_map(old, 1_u64, true, &mut self.current_state);
-        get_max(&self.current_state, self.return_type)
+        let state = self.get_state()?;
+        for field in old {
+            state.decr(field)?;
+        }
+        Ok(state.get_max())
     }
 
     fn insert(&mut self, new: &[Field]) -> Result<Field, PipelineError> {
-        update_map(new, 1_u64, false, &mut self.current_state);
-        get_max(&self.current_state, self.return_type)
+        let state = self.get_state()?;
+        for field in new {
+            state.incr(field)?;
+        }
+        Ok(state.get_max())
     }
 }
 
-fn get_max(
-    field_map: &BTreeMap<Field, u64>,
-    return_type: Option<FieldType>,
-) -> Result<Field, PipelineError> {
-    if field_map.is_empty() {
-        Ok(Field::Null)
-    } else {
-        let val = calculate_err!(field_map.keys().max(), Max).clone();
-        match return_type {
-            Some(typ) => match typ {
-                FieldType::UInt => Ok(Field::UInt(calculate_err_field!(val.to_uint(), Max, val))),
-                FieldType::U128 => Ok(Field::U128(calculate_err_field!(val.to_u128(), Max, val))),
-                FieldType::Int => Ok(Field::Int(calculate_err_field!(val.to_int(), Max, val))),
-                FieldType::I128 => Ok(Field::I128(calculate_err_field!(val.to_i128(), Max, val))),
-                FieldType::Float => Ok(Field::Float(OrderedFloat::from(calculate_err_field!(
-                    val.to_float(),
-                    Max,
-                    val
-                )))),
-                FieldType::Decimal => Ok(Field::Decimal(calculate_err_field!(
-                    val.to_decimal(),
-                    Max,
-                    val
-                ))),
-                FieldType::Timestamp => Ok(Field::Timestamp(calculate_err_field!(
-                    val.to_timestamp(),
-                    Max,
-                    val
-                ))),
-                FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date(), Max, val))),
-                FieldType::Duration => Ok(Field::Duration(calculate_err_field!(
-                    val.to_duration(),
-                    Max,
-                    val
-                ))),
-                FieldType::Boolean
+impl MaxAggregator {
+    fn get_state(&mut self) -> Result<&mut OrderedAggregatorState, PipelineError> {
+        self.current_state.as_mut().ok_or_else(|| {
+            match self
+                .return_type
+                .expect("MaxAggregator processor not initialized")
+            {
+                typ @ (FieldType::Boolean
                 | FieldType::String
                 | FieldType::Text
                 | FieldType::Binary
                 | FieldType::Json
-                | FieldType::Point => Err(PipelineError::InvalidReturnType(format!(
+                | FieldType::Point) => PipelineError::InvalidReturnType(format!(
                     "Not supported return type {typ} for {Max}"
-                ))),
-            },
-            None => Err(PipelineError::InvalidReturnType(format!(
-                "Not supported None return type for {Max}"
-            ))),
-        }
+                )),
+                _ => panic!("MaxAggregator processor not correctly initialized"),
+            }
+        })
     }
 }
diff --git a/dozer-sql/src/aggregation/min.rs b/dozer-sql/src/aggregation/min.rs
index 294486614e..7a535c73bf 100644
--- a/dozer-sql/src/aggregation/min.rs
+++ b/dozer-sql/src/aggregation/min.rs
@@ -1,23 +1,22 @@
-use crate::aggregation::aggregator::{update_map, Aggregator};
+use crate::aggregation::aggregator::Aggregator;
 use crate::errors::PipelineError;
-use crate::{calculate_err, calculate_err_field};
-use dozer_sql_expression::aggregate::AggregateFunctionType::Min;
-use dozer_types::ordered_float::OrderedFloat;
+use dozer_sql_expression::aggregate::AggregateFunctionType::{self, Min};
 use dozer_types::serde::{Deserialize, Serialize};
 use dozer_types::types::{Field, FieldType};
-use std::collections::BTreeMap;
+
+use super::aggregator::OrderedAggregatorState;
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(crate = "dozer_types::serde")]
 pub struct MinAggregator {
-    current_state: BTreeMap<Field, u64>,
+    current_state: Option<OrderedAggregatorState>,
     return_type: Option<FieldType>,
 }
 
 impl MinAggregator {
     pub fn new() -> Self {
         Self {
-            current_state: BTreeMap::new(),
+            current_state: None,
             return_type: None,
         }
     }
@@ -25,6 +24,7 @@ impl MinAggregator {
 
 impl Aggregator for MinAggregator {
     fn init(&mut self, return_type: FieldType) {
+        self.current_state = OrderedAggregatorState::new(AggregateFunctionType::Min, return_type);
         self.return_type = Some(return_type);
     }
 
@@ -34,63 +34,39 @@
     }
 
     fn delete(&mut self, old: &[Field]) -> Result<Field, PipelineError> {
-        update_map(old, 1_u64, true, &mut self.current_state);
-        get_min(&self.current_state, self.return_type)
+        let state = self.get_state()?;
+        for field in old {
+            state.decr(field)?;
+        }
+        Ok(state.get_min())
    }
 
     fn insert(&mut self, new: &[Field]) -> Result<Field, PipelineError> {
-        update_map(new, 1_u64, false, &mut self.current_state);
-        get_min(&self.current_state, self.return_type)
+        let state = self.get_state()?;
+        for field in new {
+            state.incr(field)?;
+        }
+        Ok(state.get_min())
     }
 }
 
-fn get_min(
-    field_map: &BTreeMap<Field, u64>,
-    return_type: Option<FieldType>,
-) -> Result<Field, PipelineError> {
-    if field_map.is_empty() {
-        Ok(Field::Null)
-    } else {
-        let val = calculate_err!(field_map.keys().min(), Min).clone();
-        match return_type {
-            Some(typ) => match typ {
-                FieldType::UInt => Ok(Field::UInt(calculate_err_field!(val.to_uint(), Min, val))),
-                FieldType::U128 => Ok(Field::U128(calculate_err_field!(val.to_u128(), Min, val))),
-                FieldType::Int => Ok(Field::Int(calculate_err_field!(val.to_int(), Min, val))),
-                FieldType::I128 => Ok(Field::I128(calculate_err_field!(val.to_i128(), Min, val))),
-                FieldType::Float => Ok(Field::Float(OrderedFloat::from(calculate_err_field!(
-                    val.to_float(),
-                    Min,
-                    val
-                )))),
-                FieldType::Decimal => Ok(Field::Decimal(calculate_err_field!(
-                    val.to_decimal(),
-                    Min,
-                    val
-                ))),
-                FieldType::Timestamp => Ok(Field::Timestamp(calculate_err_field!(
-                    val.to_timestamp(),
-                    Min,
-                    val
-                ))),
-                FieldType::Date => Ok(Field::Date(calculate_err_field!(val.to_date(), Min, val))),
-                FieldType::Duration => Ok(Field::Duration(calculate_err_field!(
-                    val.to_duration(),
-                    Min,
-                    val
-                ))),
-                FieldType::Boolean
+impl MinAggregator {
+    fn get_state(&mut self) -> Result<&mut OrderedAggregatorState, PipelineError> {
+        self.current_state.as_mut().ok_or_else(|| {
+            match self
+                .return_type
+                .expect("MinAggregator processor not initialized")
+            {
+                typ @ (FieldType::Boolean
                 | FieldType::String
                 | FieldType::Text
                 | FieldType::Binary
                 | FieldType::Json
-                | FieldType::Point => Err(PipelineError::InvalidReturnType(format!(
+                | FieldType::Point) => PipelineError::InvalidReturnType(format!(
                     "Not supported return type {typ} for {Min}"
-                ))),
-            },
-            None => Err(PipelineError::InvalidReturnType(format!(
-                "Not supported None return type for {Min}"
-            ))),
-        }
+                )),
+                _ => panic!("MinAggregator processor not correctly initialized"),
+            }
+        })
     }
 }
diff --git a/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs b/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs
index 25121ae07f..d1ff5d2a16 100644
--- a/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs
+++ b/dozer-sql/src/aggregation/tests/aggregation_min_tests.rs
@@ -197,7 +197,7 @@ fn test_min_aggregation_int() {
         -------------
         MIN = 50.0
     */
-    inp = update_field(ITALY, ITALY, FIELD_100_INT, FIELD_50_INT);
+    inp = update_field(ITALY, ITALY, FIELD_100_INT, FIELD_200_INT);
     out = output!(processor, inp);
     exp = vec![update_exp(ITALY, ITALY, FIELD_50_INT, FIELD_50_INT)];
     assert_eq!(out, exp);
@@ -222,7 +222,7 @@
     */
     inp = delete_field(ITALY, FIELD_50_INT);
     out = output!(processor, inp);
-    exp = vec![update_exp(ITALY, ITALY, FIELD_50_INT, FIELD_50_INT)];
+    exp = vec![update_exp(ITALY, ITALY, FIELD_50_INT, FIELD_100_INT)];
     assert_eq!(out, exp);
 
     // Delete last record
@@ -232,7 +232,7 @@
     */
     inp = delete_field(ITALY, FIELD_100_INT);
     out = output!(processor, inp);
-    exp = vec![delete_exp(ITALY, FIELD_50_INT)];
+    exp = vec![delete_exp(ITALY, FIELD_100_INT)];
     assert_eq!(out, exp);
 }
 
@@ -310,7 +310,7 @@ fn test_min_aggregation_uint() {
         -------------
         MIN = 50.0
     */
-    inp = update_field(ITALY, ITALY, FIELD_100_UINT, FIELD_50_UINT);
+    inp = update_field(ITALY, ITALY, FIELD_100_UINT, FIELD_200_UINT);
     out = output!(processor, inp);
     exp = vec![update_exp(ITALY, ITALY, FIELD_50_UINT, FIELD_50_UINT)];
     assert_eq!(out, exp);
@@ -335,7 +335,7 @@
     */
     inp = delete_field(ITALY, FIELD_50_UINT);
     out = output!(processor, inp);
-    exp = vec![update_exp(ITALY, ITALY, FIELD_50_UINT, FIELD_50_UINT)];
+    exp = vec![update_exp(ITALY, ITALY, FIELD_50_UINT, FIELD_100_UINT)];
     assert_eq!(out, exp);
 
     // Delete last record
@@ -345,7 +345,7 @@
     */
     inp = delete_field(ITALY, FIELD_100_UINT);
     out = output!(processor, inp);
-    exp = vec![delete_exp(ITALY, FIELD_50_UINT)];
+    exp = vec![delete_exp(ITALY, FIELD_100_UINT)];
     assert_eq!(out, exp);
 }
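
Note on the approach (an illustration, not part of the patch): the new MIN/MAX state keeps a count per distinct value in an ordered map, so the current extreme is simply the first or last key, and retracting a row only decrements one counter instead of rescanning all values. A minimal standalone sketch of that idea, using illustrative names (OrderedCounts, plain i64 keys) rather than the actual Dozer types:

use std::collections::BTreeMap;

/// Counted ordered multiset: insert, retract, min and max in O(log n).
struct OrderedCounts {
    counts: BTreeMap<i64, u64>,
}

impl OrderedCounts {
    fn new() -> Self {
        Self { counts: BTreeMap::new() }
    }

    /// Record one occurrence of `v`.
    fn insert(&mut self, v: i64) {
        *self.counts.entry(v).or_insert(0) += 1;
    }

    /// Retract one occurrence of `v`, dropping the key once its count reaches zero.
    fn retract(&mut self, v: i64) {
        if let Some(count) = self.counts.get_mut(&v) {
            *count -= 1;
            if *count == 0 {
                self.counts.remove(&v);
            }
        }
    }

    fn min(&self) -> Option<i64> {
        self.counts.first_key_value().map(|(v, _)| *v)
    }

    fn max(&self) -> Option<i64> {
        self.counts.last_key_value().map(|(v, _)| *v)
    }
}

fn main() {
    let mut state = OrderedCounts::new();
    state.insert(100);
    state.insert(50);
    state.insert(50);
    assert_eq!(state.min(), Some(50));
    state.retract(50);
    state.retract(50); // both 50s retracted, 100 becomes the minimum again
    assert_eq!(state.min(), Some(100));
    state.retract(100);
    assert_eq!(state.min(), None); // an empty state corresponds to Field::Null in the patch
}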