diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index ce86fa3301686..5a54fa48597e8 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -22,6 +22,7 @@ use std::{fmt::Debug, sync::Arc}; use arrow::{ array::{ArrayRef, AsArray}, + compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, @@ -441,12 +442,17 @@ impl GroupsAccumulator for CountGroupsAccumulator { let values = &values[0]; let state_array = match (values.logical_nulls(), opt_filter) { + (None, None) => Arc::new(Int64Array::from_value(1, values.len())), (Some(nulls), None) => { + let nulls = BooleanArray::new(nulls.into_inner(), None); + compute::cast(&nulls, &DataType::Int64)? + } + (None, Some(filter)) => { let mut builder = Int64Builder::with_capacity(values.len()); - nulls - .into_iter() - .for_each(|is_valid| builder.append_value(is_valid as i64)); - builder.finish() + filter.into_iter().for_each(|filter_value| { + builder.append_value(filter_value.is_some_and(|val| val) as i64) + }); + Arc::new(builder.finish()) } (Some(nulls), Some(filter)) => { let mut builder = Int64Builder::with_capacity(values.len()); @@ -457,19 +463,11 @@ impl GroupsAccumulator for CountGroupsAccumulator { ) }, ); - builder.finish() - } - (None, Some(filter)) => { - let mut builder = Int64Builder::with_capacity(values.len()); - filter.into_iter().for_each(|filter_value| { - builder.append_value(filter_value.is_some_and(|val| val) as i64) - }); - builder.finish() + Arc::new(builder.finish()) } - (None, None) => Int64Array::from_value(1, values.len()), }; - Ok(vec![Arc::new(state_array)]) + Ok(vec![state_array]) } fn convert_to_state_supported(&self) -> bool {