diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 9251060da2f1d..3cdabd03011c7 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -23,6 +23,7 @@ use std::{fmt::Debug, sync::Arc}; use arrow::{ array::{ArrayRef, AsArray}, + compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, @@ -34,7 +35,7 @@ use arrow::{ }; use arrow::{ - array::{Array, BooleanArray, Int64Array, Int64Builder, PrimitiveArray}, + array::{Array, BooleanArray, Int64Array, PrimitiveArray}, buffer::BooleanBuffer, }; use datafusion_common::{ @@ -445,35 +446,52 @@ impl GroupsAccumulator for CountGroupsAccumulator { let values = &values[0]; let state_array = match (values.logical_nulls(), opt_filter) { - (Some(nulls), None) => { - let mut builder = Int64Builder::with_capacity(values.len()); - nulls - .into_iter() - .for_each(|is_valid| builder.append_value(is_valid as i64)); - builder.finish() + (None, None) => { + // In case there is no nulls in input and no filter, returning array of 1 + Arc::new(Int64Array::from_value(1, values.len())) } - (Some(nulls), Some(filter)) => { - let mut builder = Int64Builder::with_capacity(values.len()); - nulls.into_iter().zip(filter.iter()).for_each( - |(is_valid, filter_value)| { - builder.append_value( - (is_valid && filter_value.is_some_and(|val| val)) as i64, - ) - }, - ); - builder.finish() + (Some(nulls), None) => { + // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls) + // of input array to Int64 + let nulls = BooleanArray::new(nulls.into_inner(), None); + compute::cast(&nulls, &DataType::Int64)? } (None, Some(filter)) => { - let mut builder = Int64Builder::with_capacity(values.len()); - filter.into_iter().for_each(|filter_value| { - builder.append_value(filter_value.is_some_and(|val| val) as i64) - }); - builder.finish() + // If there is only filter + // - applying filter null mask to filter values by bitand filter values and nulls buffers + // (using buffers guarantees absence of nulls in result) + // - casting result of bitand to Int64 array + let (filter_values, filter_nulls) = filter.clone().into_parts(); + + let state_buf = match filter_nulls { + Some(filter_nulls) => &filter_values & filter_nulls.inner(), + None => filter_values, + }; + + let boolean_state = BooleanArray::new(state_buf, None); + compute::cast(&boolean_state, &DataType::Int64)? + } + (Some(nulls), Some(filter)) => { + // For both input nulls and filter + // - applying filter null mask to filter values by bitand filter values and nulls buffers + // (using buffers guarantees absence of nulls in result) + // - applying values null mask to filter buffer by another bitand on filter result and + // nulls from input values + // - casting result to Int64 array + let (filter_values, filter_nulls) = filter.clone().into_parts(); + + let filter_buf = match filter_nulls { + Some(filter_nulls) => &filter_values & filter_nulls.inner(), + None => filter_values, + }; + let state_buf = &filter_buf & nulls.inner(); + + let boolean_state = BooleanArray::new(state_buf, None); + compute::cast(&boolean_state, &DataType::Int64)? } - (None, None) => Int64Array::from_value(1, values.len()), }; - Ok(vec![Arc::new(state_array)]) + Ok(vec![state_array]) } fn convert_to_state_supported(&self) -> bool {