Skip to content

Commit

Permalink
Remove element's nullability of array_agg function (apache#11447)
Browse files Browse the repository at this point in the history
* rm null

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Jul 17, 2024
1 parent a979f3e commit d67b0fb
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 75 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
*actual[0].schema(),
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
Field::new("item", DataType::UInt32, true),
true
),])
);
Expand Down
23 changes: 6 additions & 17 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::Array;
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::utils::array_into_list_array_nullable;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
Expand All @@ -40,8 +40,6 @@ pub struct ArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
}

impl ArrayAgg {
Expand All @@ -50,13 +48,11 @@ impl ArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
nullable: bool,
) -> Self {
Self {
name: name.into(),
input_data_type: data_type,
expr,
nullable,
}
}
}
Expand All @@ -70,22 +66,21 @@ impl AggregateExpr for ArrayAgg {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(
&self.input_data_type,
self.nullable,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
)])
}
Expand Down Expand Up @@ -116,16 +111,14 @@ impl PartialEq<dyn Any> for ArrayAgg {
pub(crate) struct ArrayAggAccumulator {
values: Vec<ArrayRef>,
datatype: DataType,
nullable: bool,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given item data type
pub fn try_new(datatype: &DataType, nullable: bool) -> Result<Self> {
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: vec![],
datatype: datatype.clone(),
nullable,
})
}
}
Expand Down Expand Up @@ -169,15 +162,11 @@ impl Accumulator for ArrayAggAccumulator {
self.values.iter().map(|a| a.as_ref()).collect();

if element_arrays.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}

let concated_array = arrow::compute::concat(&element_arrays)?;
let list_array = array_into_list_array(concated_array, self.nullable);
let list_array = array_into_list_array_nullable(concated_array);

Ok(ScalarValue::List(Arc::new(list_array)))
}
Expand Down
23 changes: 5 additions & 18 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ pub struct DistinctArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
}

impl DistinctArrayAgg {
Expand All @@ -52,14 +50,12 @@ impl DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
) -> Self {
let name = name.into();
Self {
name,
input_data_type,
expr,
nullable,
}
}
}
Expand All @@ -74,22 +70,21 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&self.input_data_type,
self.nullable,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
)])
}
Expand Down Expand Up @@ -120,15 +115,13 @@ impl PartialEq<dyn Any> for DistinctArrayAgg {
struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
datatype: DataType,
nullable: bool,
}

impl DistinctArrayAggAccumulator {
pub fn try_new(datatype: &DataType, nullable: bool) -> Result<Self> {
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
datatype: datatype.clone(),
nullable,
})
}
}
Expand Down Expand Up @@ -166,13 +159,9 @@ impl Accumulator for DistinctArrayAggAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
if values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}
let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable);
let arr = ScalarValue::new_list(&values, &self.datatype, true);
Ok(ScalarValue::List(arr))
}

Expand Down Expand Up @@ -255,7 +244,6 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));
let actual = aggregate(&batch, agg)?;
compare_list_contents(expected, actual)
Expand All @@ -272,7 +260,6 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));

let mut accum1 = agg.create_accumulator()?;
Expand Down
37 changes: 9 additions & 28 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use arrow::datatypes::{DataType, Field};
use arrow_array::cast::AsArray;
use arrow_array::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::Fields;
use datafusion_common::utils::{array_into_list_array, get_row_at_idx};
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::Accumulator;
Expand All @@ -50,8 +50,6 @@ pub struct OrderSensitiveArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have `NULL`s
nullable: bool,
/// Ordering data types
order_by_data_types: Vec<DataType>,
/// Ordering requirement
Expand All @@ -66,15 +64,13 @@ impl OrderSensitiveArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
Self {
name: name.into(),
input_data_type,
expr,
nullable,
order_by_data_types,
ordering_req,
reverse: false,
Expand All @@ -90,8 +86,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
// This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg
Field::new("item", self.input_data_type.clone(), true),
true,
))
}
Expand All @@ -102,25 +98,20 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
&self.order_by_data_types,
self.ordering_req.clone(),
self.reverse,
self.nullable,
)
.map(|acc| Box::new(acc) as _)
}

fn state_fields(&self) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
format_state_name(&self.name, "array_agg_orderings"),
Field::new(
"item",
DataType::Struct(Fields::from(orderings)),
self.nullable,
),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
false,
));
Ok(fields)
Expand All @@ -147,7 +138,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
name: self.name.to_string(),
input_data_type: self.input_data_type.clone(),
expr: Arc::clone(&self.expr),
nullable: self.nullable,
order_by_data_types: self.order_by_data_types.clone(),
// Reverse requirement:
ordering_req: reverse_order_bys(&self.ordering_req),
Expand Down Expand Up @@ -186,8 +176,6 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
ordering_req: LexOrdering,
/// Whether the aggregation is running in reverse.
reverse: bool,
/// Whether the input expr is nullable
nullable: bool,
}

impl OrderSensitiveArrayAggAccumulator {
Expand All @@ -198,7 +186,6 @@ impl OrderSensitiveArrayAggAccumulator {
ordering_dtypes: &[DataType],
ordering_req: LexOrdering,
reverse: bool,
nullable: bool,
) -> Result<Self> {
let mut datatypes = vec![datatype.clone()];
datatypes.extend(ordering_dtypes.iter().cloned());
Expand All @@ -208,7 +195,6 @@ impl OrderSensitiveArrayAggAccumulator {
datatypes,
ordering_req,
reverse,
nullable,
})
}
}
Expand Down Expand Up @@ -312,7 +298,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
if self.values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatypes[0].clone(),
self.nullable,
true,
1,
));
}
Expand All @@ -322,14 +308,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
ScalarValue::new_list_from_iter(
values.into_iter().rev(),
&self.datatypes[0],
self.nullable,
true,
)
} else {
ScalarValue::new_list_from_iter(
values.into_iter(),
&self.datatypes[0],
self.nullable,
)
ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
};
Ok(ScalarValue::List(array))
}
Expand Down Expand Up @@ -385,9 +367,8 @@ impl OrderSensitiveArrayAggAccumulator {
column_wise_ordering_values,
None,
)?;
Ok(ScalarValue::List(Arc::new(array_into_list_array(
Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
Arc::new(ordering_array),
self.nullable,
))))
}
}
Expand Down
12 changes: 2 additions & 10 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,14 @@ pub fn create_aggregate_expr(
Ok(match (fun, distinct) {
(AggregateFunction::ArrayAgg, false) => {
let expr = Arc::clone(&input_phy_exprs[0]);
let nullable = expr.nullable(input_schema)?;

if ordering_req.is_empty() {
Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
nullable,
ordering_types,
ordering_req.to_vec(),
))
Expand All @@ -84,13 +82,7 @@ pub fn create_aggregate_expr(
);
}
let expr = Arc::clone(&input_phy_exprs[0]);
let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
expr,
name,
data_type,
is_expr_nullable,
))
Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Arc::clone(&input_phy_exprs[0]),
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,6 @@ mod tests {
Arc::clone(col_a),
"array_agg",
DataType::Int32,
false,
vec![],
order_by_expr.unwrap_or_default(),
)) as _
Expand Down

0 comments on commit d67b0fb

Please sign in to comment.