diff --git a/e2e_test/batch/subquery/subquery.slt.part b/e2e_test/batch/subquery/subquery.slt.part index 59a832c41126a..2d08d4c329096 100644 --- a/e2e_test/batch/subquery/subquery.slt.part +++ b/e2e_test/batch/subquery/subquery.slt.part @@ -132,6 +132,13 @@ select a, (select count(*) from t1 where t1.a <> t.b) from t1 as t order by 1; 2 2 NULL 0 +query II +select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1; +---- +1 NULL +2 {2} +NULL NULL + statement ok drop table t1; diff --git a/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml b/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml index ca35df39ff7a6..5ba9930be99b9 100644 --- a/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml +++ b/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml @@ -46,6 +46,13 @@ expected_outputs: - optimized_logical_plan_for_batch - logical_plan +- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735' + sql: | + create table t1(a int, b int); + select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch - sql: | create table t1(x int, y int); create table t2(x int, y int); @@ -500,4 +507,4 @@ WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D); expected_outputs: - batch_plan - - stream_plan \ No newline at end of file + - stream_plan diff --git a/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml b/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml index 883e786eac7f6..f15fc5c6fe433 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml @@ -146,6 +146,31 @@ │ └─LogicalScan { table: t1, columns: [t1.y] } └─LogicalProject { exprs: [t2.y, 1:Int32] } └─LogicalScan { table: t2, columns: [t2.y], predicate: IsNotNull(t2.y) } +- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735' + sql: | + create table t1(a int, b int); + select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1; + logical_plan: |- + LogicalProject { exprs: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] } + └─LogicalProject { exprs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] } + └─LogicalAgg { aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] } + └─LogicalProject { exprs: [t1.a] } + └─LogicalFilter { predicate: (t1.a <> CorrelatedInputRef { index: 1, correlated_id: 1 }) } + └─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] } + optimized_logical_plan_for_batch: |- + LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] } + ├─LogicalScan { table: t1, columns: [t1.a, t1.b] } + └─LogicalAgg { group_key: [t1.b], aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] } + └─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.b, t1.a, 1:Int32] } + ├─LogicalAgg { group_key: [t1.b], aggs: [] } + │ └─LogicalScan { table: t1, columns: [t1.b] } + └─LogicalProject { exprs: [t1.b, t1.a, 1:Int32] } + └─LogicalJoin { type: Inner, on: (t1.a <> t1.b), output: all } + ├─LogicalAgg { group_key: [t1.b], aggs: [] } + │ └─LogicalScan { table: t1, columns: [t1.b] } + └─LogicalScan { table: t1, columns: [t1.a] } - sql: | create table t1(x int, y int); create table t2(x int, y int); @@ -1028,14 +1053,14 @@ └─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [$expr1] } ├─BatchExchange { order: [], dist: HashShard(array_types.x) } │ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard } - └─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x), ARRAY[]:List(List(Int64))) as $expr1] } - └─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x)] } - └─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x] } + └─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x) filter(IsNotNull(1:Int32)), ARRAY[]:List(List(Int64))) as $expr1] } + └─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x) filter(IsNotNull(1:Int32))] } + └─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x, 1:Int32] } ├─BatchHashAgg { group_key: [array_types.x], aggs: [] } │ └─BatchExchange { order: [], dist: HashShard(array_types.x) } │ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard } └─BatchExchange { order: [], dist: HashShard(array_types.x) } - └─BatchProject { exprs: [array_types.x, array_types.x] } + └─BatchProject { exprs: [array_types.x, array_types.x, 1:Int32] } └─BatchHashAgg { group_key: [array_types.x], aggs: [] } └─BatchExchange { order: [], dist: HashShard(array_types.x) } └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard } diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs index 0e58634bda7a8..c53357234e350 100644 --- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs @@ -140,9 +140,49 @@ impl Rule for ApplyAggTransposeRule { // convert count(*) to count(1). let pos_of_constant_column = node.schema().len() - 1; agg_calls.iter_mut().for_each(|agg_call| { - if agg_call.agg_kind == AggKind::Count && agg_call.inputs.is_empty() { - let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); - agg_call.inputs.push(input_ref); + match agg_call.agg_kind { + AggKind::Count if agg_call.inputs.is_empty() => { + let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); + agg_call.inputs.push(input_ref); + } + AggKind::ArrayAgg + | AggKind::JsonbAgg + | AggKind::JsonbObjectAgg => { + let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32); + let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap(); + agg_call.filter.conjunctions.push(cond.into()); + } + AggKind::Count + | AggKind::Sum + | AggKind::Sum0 + | AggKind::Avg + | AggKind::Min + | AggKind::Max + | AggKind::BitAnd + | AggKind::BitOr + | AggKind::BitXor + | AggKind::BoolAnd + | AggKind::BoolOr + | AggKind::StringAgg + // not in PostgreSQL + | AggKind::ApproxCountDistinct + | AggKind::FirstValue + | AggKind::LastValue + | AggKind::InternalLastSeenValue + // All statistical aggregates only consider non-null inputs. + | AggKind::VarPop + | AggKind::VarSamp + | AggKind::StddevPop + | AggKind::StddevSamp + // All ordered-set aggregates ignore null values in their aggregated input. + | AggKind::PercentileCont + | AggKind::PercentileDisc + | AggKind::Mode + // `grouping` has no *aggregate* input and unreachable when `is_scalar_agg`. + | AggKind::Grouping + => { + // no-op when `agg(0 rows) == agg(1 row of nulls)` + } } }); }