diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b7537dc02e9d..a024824c7a5a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2294,13 +2294,25 @@ impl Aggregate { aggr_expr: Vec, ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, &input)?, - input.schema().metadata().clone(), - )?; + let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?; + + // Even columns that cannot be null will become nullable when used in a grouping set. + if is_grouping_set { + fields = fields + .into_iter() + .map(|field| field.with_nullable(true)) + .collect::>(); + } + + fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?); + + let schema = + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?; Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -2539,7 +2551,7 @@ pub struct Unnest { mod tests { use super::*; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit, placeholder}; + use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, DFSchema, TableReference}; @@ -3006,4 +3018,36 @@ digraph { plan.replace_params_with_values(&[42i32.into()]) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } + + #[test] + fn test_nullable_schema_after_grouping_set() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate( + vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("foo")], + vec![col("bar")], + ]))], + vec![count(lit(true))], + ) + .unwrap() + .build() + .unwrap(); + + let output_schema = plan.schema(); + + assert!(output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(),); + assert!(output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable()); + } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index be76c069f0b7..ac18e596b7bd 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -322,7 +322,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -340,7 +340,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -359,7 +359,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0a495dd2b0c9..faad6feb3f33 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2672,9 +2672,10 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -Limit: skip=0, fetch=3 ---Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -----TableScan: aggregate_test_100 projection=[c2, c3] +Projection: aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=3 +----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan GlobalLimitExec: skip=0, fetch=3 --AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]