Skip to content

Commit

Permalink
Minor: Make schema of grouping set columns nullable (apache#8248)
Browse files Browse the repository at this point in the history
* Make output schema of aggregation grouping sets nullable

* Improve

* Fix tests
  • Loading branch information
markusa380 authored Nov 18, 2023
1 parent 76ced31 commit 8f48053
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
56 changes: 50 additions & 6 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2294,13 +2294,25 @@ impl Aggregate {
aggr_expr: Vec<Expr>,
) -> Result<Self> {
let group_expr = enumerate_grouping_sets(group_expr)?;

let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);

let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
let all_expr = grouping_expr.iter().chain(aggr_expr.iter());

let schema = DFSchema::new_with_metadata(
exprlist_to_fields(all_expr, &input)?,
input.schema().metadata().clone(),
)?;
let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?;

// Even columns that cannot be null will become nullable when used in a grouping set.
if is_grouping_set {
fields = fields
.into_iter()
.map(|field| field.with_nullable(true))
.collect::<Vec<_>>();
}

fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?);

let schema =
DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?;

Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
}
Expand Down Expand Up @@ -2539,7 +2551,7 @@ pub struct Unnest {
mod tests {
use super::*;
use crate::logical_plan::table_scan;
use crate::{col, exists, in_subquery, lit, placeholder};
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, DFSchema, TableReference};
Expand Down Expand Up @@ -3006,4 +3018,36 @@ digraph {
plan.replace_params_with_values(&[42i32.into()])
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
}

#[test]
fn test_nullable_schema_after_grouping_set() {
let schema = Schema::new(vec![
Field::new("foo", DataType::Int32, false),
Field::new("bar", DataType::Int32, false),
]);

let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.aggregate(
vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![
vec![col("foo")],
vec![col("bar")],
]))],
vec![count(lit(true))],
)
.unwrap()
.build()
.unwrap();

let output_schema = plan.schema();

assert!(output_schema
.field_with_name(None, "foo")
.unwrap()
.is_nullable(),);
assert!(output_schema
.field_with_name(None, "bar")
.unwrap()
.is_nullable());
}
}
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
Expand All @@ -340,7 +340,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
Expand All @@ -359,7 +359,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
Expand Down
7 changes: 4 additions & 3 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2672,9 +2672,10 @@ query TT
EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
----
logical_plan
Limit: skip=0, fetch=3
--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
----TableScan: aggregate_test_100 projection=[c2, c3]
Projection: aggregate_test_100.c2, aggregate_test_100.c3
--Limit: skip=0, fetch=3
----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
------TableScan: aggregate_test_100 projection=[c2, c3]
physical_plan
GlobalLimitExec: skip=0, fetch=3
--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]
Expand Down

0 comments on commit 8f48053

Please sign in to comment.