Skip to content

Commit

Permalink
quick & dirty
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Dec 19, 2023
1 parent fc6cc48 commit 602e097
Showing 1 changed file with 141 additions and 3 deletions.
144 changes: 141 additions & 3 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,26 @@
// under the License.

use arrow::datatypes::{DataType, Field};
use arrow_array::{ArrowNumericType, UInt64Array, PrimitiveArray, Int64Array, StructArray, ArrayAccessor, ListArray};
use arrow_array::types::{ArrowPrimitiveType, Int64Type, Int32Type, UInt64Type};
use arrow_buffer::ArrowNativeType;
use hashbrown::HashMap;
use itertools::Itertools;

use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use arrow::array::{Array, ArrayRef, AsArray};
use std::collections::HashSet;

use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use crate::{AggregateExpr, PhysicalExpr, GroupsAccumulator};
use datafusion_common::{DataFusionError, Result, not_impl_err};
use datafusion_common::ScalarValue;
use datafusion_common::hash_utils::create_hashes;
use datafusion_expr::Accumulator;

/// Alias for [`ScalarValue`].
// NOTE(review): the uses of this alias are outside the visible section —
// presumably it names the element type tracked per distinct value; confirm.
type DistinctScalarValues = ScalarValue;
Expand Down Expand Up @@ -71,6 +77,13 @@ impl AggregateExpr for DistinctCount {
}

fn state_fields(&self) -> Result<Vec<Field>> {
if self.groups_accumulator_supported() {
return Ok(vec![Field::new_list(
format_state_name(&self.name, "count distinct"),
Field::new("item", DataType::UInt64, true),
false,
)])
}
Ok(vec![Field::new_list(
format_state_name(&self.name, "count distinct"),
Field::new("item", self.state_data_type.clone(), true),
Expand All @@ -82,13 +95,44 @@ impl AggregateExpr for DistinctCount {
vec![self.expr.clone()]
}

fn groups_accumulator_supported(&self) -> bool {
use DataType::*;
matches!(
self.state_data_type,
Int8 | Int16
| Int32
| Int64
// | UInt8
// | UInt16
// | UInt32
// | UInt64
// | Float32
// | Float64
// | Decimal128(_, _)
// | Decimal256(_, _)
// | Date32
// | Date64
// | Time32(_)
// | Time64(_)
// | Timestamp(_, _)
)
}

/// Creates the row-oriented accumulator: an empty set of seen values plus a
/// copy of this expression's `state_data_type`.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    let state_data_type = self.state_data_type.clone();
    let accumulator = DistinctCountAccumulator {
        values: HashSet::default(),
        state_data_type,
    };
    Ok(Box::new(accumulator))
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(DictinctCountGroupsAccumulator::<Int64Type>{
unique_values: HashMap::new(),
unique_counts: vec![],
dummy: Int64Type{},
}))
}

/// Returns this expression's name, borrowed from `self.name`.
fn name(&self) -> &str {
&self.name
}
Expand Down Expand Up @@ -192,6 +236,100 @@ impl Accumulator for DistinctCountAccumulator {
}
}

/// Pairs a group index with a value of type `T`; two keys are equal when
/// both fields are equal (derived `PartialEq`).
// NOTE(review): nothing in the visible section references this type —
// confirm whether it is still needed.
#[derive(PartialEq)]
struct AccumulatorKey<T> {
/// Index of a group (presumably the same indices passed to the groups
/// accumulator — confirm).
group_idx: usize,
/// The value associated with `group_idx`.
value: T
}

/// Per-group state for counting distinct primitive values.
///
/// `T` is the Arrow primitive type used to downcast incoming value arrays.
// NOTE(review): the name looks like a typo for "Distinct"; renaming needs a
// coordinated change at every use site.
struct DictinctCountGroupsAccumulator<T> {
/// Group index -> set of the values seen for that group, each value stored
/// after conversion to `usize`. Groups that never received a non-null value
/// have no entry.
unique_values: HashMap<usize, HashSet<usize>>,
/// `unique_counts[g]` is the number of distinct values recorded for group `g`.
unique_counts: Vec<i64>,
/// Instance of `T` kept only so the struct carries the type parameter; it is
/// not read by the accumulator methods in this file section.
dummy: T,
}

impl<T> GroupsAccumulator for DictinctCountGroupsAccumulator<T>
where T: Send + ArrowNumericType {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
let values = values[0].as_primitive::<T>();

self.unique_counts.resize(total_num_groups, 0);

group_indices.iter().zip(values.iter()).for_each(|(group_idx, value)| {
if let Some(value) = value {
let inserted = if let Some(set) = self.unique_values.get_mut(group_idx) {
set.insert(value.as_usize())
} else {
let mut set = HashSet::new();
let inserted = set.insert(value.as_usize());
self.unique_values.insert(*group_idx, set);
inserted
};
if inserted {
self.unique_counts[*group_idx] = self.unique_counts[*group_idx] + 1;
};
}
});

Ok(())
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
//assert_eq!(values.len(), 1, "one argument to merge_batch");
let values = values[0].as_list::<i32>();

self.unique_counts.resize(total_num_groups, 0);

group_indices.iter().zip(values.iter()).for_each(|(group_idx, maybe_set)| {
let incoming_set = HashSet::from_iter(maybe_set.unwrap().as_primitive::<UInt64Type>().iter().map(|entry| entry.unwrap() as usize));
if let Some(current_set) = self.unique_values.get_mut(group_idx) {
current_set.extend(incoming_set.iter());
self.unique_counts[*group_idx] = current_set.len() as i64;
} else {
self.unique_counts[*group_idx] = incoming_set.len() as i64;
self.unique_values.insert(*group_idx, incoming_set);
}
});

Ok(())
}

fn evaluate(&mut self, emit_to: crate::EmitTo) -> Result<ArrayRef> {
return Ok(Arc::new(Int64Array::from(self.unique_counts.clone())))
}

fn state(&mut self, emit_to: crate::EmitTo) -> Result<Vec<ArrayRef>> {
let state_array = ListArray::from_iter_primitive::<UInt64Type, _, _>(
(0..self.unique_counts.len()).map(|group_idx| {
if let Some(val) = self.unique_values.get(&group_idx) {
Some(val.iter().map(|entry| Some(*entry as u64)).collect_vec())
} else {
None
}
}));

Ok(vec![
Arc::new(state_array) as ArrayRef,
])
}

fn size(&self) -> usize {
return 0
}
}

#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
Expand Down

0 comments on commit 602e097

Please sign in to comment.