Skip to content

Commit

Permalink
Refactor PrimitiveGroupValueBuilder to use MaybeNullBufferBuilder (a…
Browse files Browse the repository at this point in the history
…pache#12623)

* Refactor PrimitiveGroupValueBuilder to use BooleanBuilder

* Refactor boolean buffer builder out

* tweaks

* tweak

* simplify

* Add specializations for null / non null
  • Loading branch information
alamb authored Sep 30, 2024
1 parent 29b8af2 commit ddb4fac
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 103 deletions.
79 changes: 34 additions & 45 deletions datafusion/physical-plan/src/aggregates/group_values/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// under the License.

use crate::aggregates::group_values::group_column::{
ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder,
PrimitiveGroupValueBuilder,
};
use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
Expand Down Expand Up @@ -123,6 +124,26 @@ impl GroupValuesColumn {
}
}

/// instantiates a [`PrimitiveGroupValueBuilder`] or
/// [`NonNullPrimitiveGroupValueBuilder`] and pushes it into $v
///
/// Arguments:
/// `$v`: the vector to push the new builder into
/// `$nullable`: whether the input can contains nulls
/// `$t`: the primitive type of the builder
///
macro_rules! instantiate_primitive {
($v:expr, $nullable:expr, $t:ty) => {
if $nullable {
let b = PrimitiveGroupValueBuilder::<$t>::new();
$v.push(Box::new(b) as _)
} else {
let b = NonNullPrimitiveGroupValueBuilder::<$t>::new();
$v.push(Box::new(b) as _)
}
};
}

impl GroupValues for GroupValuesColumn {
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
let n_rows = cols[0].len();
Expand All @@ -133,54 +154,22 @@ impl GroupValues for GroupValuesColumn {
for f in self.schema.fields().iter() {
let nullable = f.is_nullable();
match f.data_type() {
&DataType::Int8 => {
let b = PrimitiveGroupValueBuilder::<Int8Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int16 => {
let b = PrimitiveGroupValueBuilder::<Int16Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int32 => {
let b = PrimitiveGroupValueBuilder::<Int32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int64 => {
let b = PrimitiveGroupValueBuilder::<Int64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt8 => {
let b = PrimitiveGroupValueBuilder::<UInt8Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt16 => {
let b = PrimitiveGroupValueBuilder::<UInt16Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt32 => {
let b = PrimitiveGroupValueBuilder::<UInt32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt64 => {
let b = PrimitiveGroupValueBuilder::<UInt64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type),
&DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type),
&DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type),
&DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type),
&DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type),
&DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type),
&DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type),
&DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type),
&DataType::Float32 => {
let b = PrimitiveGroupValueBuilder::<Float32Type>::new(nullable);
v.push(Box::new(b) as _)
instantiate_primitive!(v, nullable, Float32Type)
}
&DataType::Float64 => {
let b = PrimitiveGroupValueBuilder::<Float64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Date32 => {
let b = PrimitiveGroupValueBuilder::<Date32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Date64 => {
let b = PrimitiveGroupValueBuilder::<Date64Type>::new(nullable);
v.push(Box::new(b) as _)
instantiate_primitive!(v, nullable, Float64Type)
}
&DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type),
&DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type),
&DataType::Utf8 => {
let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
v.push(Box::new(b) as _)
Expand Down
148 changes: 90 additions & 58 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ use arrow::datatypes::GenericBinaryType;
use arrow::datatypes::GenericStringType;
use datafusion_common::utils::proxy::VecAllocExt;

use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};
use std::sync::Arc;
use std::vec;

use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};

/// Trait for storing a single column of group values in [`GroupValuesColumn`]
///
/// Implementations of this trait store an in-progress collection of group values
Expand All @@ -47,6 +47,8 @@ use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAP
pub trait GroupColumn: Send + Sync {
/// Returns equal if the row stored in this builder at `lhs_row` is equal to
/// the row in `array` at `rhs_row`
///
/// Note that this comparison returns true if both elements are NULL
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool;
/// Appends the row at `row` in `array` to this builder
fn append_val(&mut self, array: &ArrayRef, row: usize);
Expand All @@ -61,61 +63,96 @@ pub trait GroupColumn: Send + Sync {
fn take_n(&mut self, n: usize) -> ArrayRef;
}

/// An implementation of [`GroupColumn`] for primitive types.
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
/// An implementation of [`GroupColumn`] for primitive values which are known to have no nulls
#[derive(Debug)]
pub struct NonNullPrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
group_values: Vec<T::Native>,
nulls: Vec<bool>,
/// whether the array contains at least one null, for fast non-null path
has_null: bool,
/// Can the input array contain nulls?
nullable: bool,
}

impl<T> PrimitiveGroupValueBuilder<T>
impl<T> NonNullPrimitiveGroupValueBuilder<T>
where
T: ArrowPrimitiveType,
{
pub fn new(nullable: bool) -> Self {
pub fn new() -> Self {
Self {
group_values: vec![],
nulls: vec![],
has_null: false,
nullable,
}
}
}

impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
impl<T: ArrowPrimitiveType> GroupColumn for NonNullPrimitiveGroupValueBuilder<T> {
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
// non-null fast path
// both non-null
if !self.nullable {
return self.group_values[lhs_row]
== array.as_primitive::<T>().value(rhs_row);
}
// know input has no nulls
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

// lhs is non-null
if self.nulls[lhs_row] {
if array.is_null(rhs_row) {
return false;
}
fn append_val(&mut self, array: &ArrayRef, row: usize) {
// input can't possibly have nulls, so don't worry about them
self.group_values.push(array.as_primitive::<T>().value(row))
}

fn len(&self) -> usize {
self.group_values.len()
}

fn size(&self) -> usize {
self.group_values.allocated_size()
}

fn build(self: Box<Self>) -> ArrayRef {
let Self { group_values } = *self;

return self.group_values[lhs_row]
== array.as_primitive::<T>().value(rhs_row);
let nulls = None;

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(group_values),
nulls,
))
}

fn take_n(&mut self, n: usize) -> ArrayRef {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = None;

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
first_n_nulls,
))
}
}

/// An implementation of [`GroupColumn`] for primitive values which may have nulls
#[derive(Debug)]
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
group_values: Vec<T::Native>,
nulls: MaybeNullBufferBuilder,
}

impl<T> PrimitiveGroupValueBuilder<T>
where
T: ArrowPrimitiveType,
{
pub fn new() -> Self {
Self {
group_values: vec![],
nulls: MaybeNullBufferBuilder::new(),
}
}
}

array.is_null(rhs_row)
impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
if self.nullable && array.is_null(row) {
if array.is_null(row) {
self.nulls.append(true);
self.group_values.push(T::default_value());
self.nulls.push(false);
self.has_null = true;
} else {
let elem = array.as_primitive::<T>().value(row);
self.group_values.push(elem);
self.nulls.push(true);
self.nulls.append(false);
self.group_values.push(array.as_primitive::<T>().value(row));
}
}

Expand All @@ -128,32 +165,27 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
}

fn build(self: Box<Self>) -> ArrayRef {
if self.has_null {
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(self.group_values),
Some(NullBuffer::from(self.nulls)),
))
} else {
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(self.group_values),
None,
))
}
let Self {
group_values,
nulls,
} = *self;

let nulls = nulls.build();

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(group_values),
nulls,
))
}

fn take_n(&mut self, n: usize) -> ArrayRef {
if self.has_null {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = self.nulls.drain(0..n).collect::<Vec<_>>();
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
Some(NullBuffer::from(first_n_nulls)),
))
} else {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
self.nulls.truncate(self.nulls.len() - n);
Arc::new(PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), None))
}
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = self.nulls.take_n(n);

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
first_n_nulls,
))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use bytes::GroupValuesByes;
use datafusion_physical_expr::binary_map::OutputType;

mod group_column;
mod null_builder;

/// Stores the group values during hash aggregation.
///
Expand Down
Loading

0 comments on commit ddb4fac

Please sign in to comment.