Skip to content

Commit

Permalink
fix(optimizer): decorrelate SimpleAgg with array_agg/jsonb_agg/`j…
Browse files Browse the repository at this point in the history
…sonb_object_agg` (#15590)
  • Loading branch information
xiangjinwu authored Mar 11, 2024
1 parent b0325c5 commit ca671e6
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 8 deletions.
7 changes: 7 additions & 0 deletions e2e_test/batch/subquery/subquery.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ select a, (select count(*) from t1 where t1.a <> t.b) from t1 as t order by 1;
2 2
NULL 0

query II
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
----
1 NULL
2 {2}
NULL NULL

statement ok
drop table t1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@
expected_outputs:
- optimized_logical_plan_for_batch
- logical_plan
- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735'
sql: |
create table t1(a int, b int);
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
expected_outputs:
- logical_plan
- optimized_logical_plan_for_batch
- sql: |
create table t1(x int, y int);
create table t2(x int, y int);
Expand Down Expand Up @@ -500,4 +507,4 @@
WHERE T.A > (SELECT avg(c) FROM T2 WHERE B = D);
expected_outputs:
- batch_plan
- stream_plan
- stream_plan
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,31 @@
│ └─LogicalScan { table: t1, columns: [t1.y] }
└─LogicalProject { exprs: [t2.y, 1:Int32] }
└─LogicalScan { table: t2, columns: [t2.y], predicate: IsNotNull(t2.y) }
- name: 'Like `count(*)`, SimpleAgg also need to rewrite `array_agg` for the extra null row due to outer join #14735'
sql: |
create table t1(a int, b int);
select a, (select array_agg(t1.a) filter (where t1.a is distinct from 1) from t1 where t1.a <> t.b) from t1 as t order by 1;
logical_plan: |-
LogicalProject { exprs: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
└─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true }
├─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] }
└─LogicalProject { exprs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
└─LogicalAgg { aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32))] }
└─LogicalProject { exprs: [t1.a] }
└─LogicalFilter { predicate: (t1.a <> CorrelatedInputRef { index: 1, correlated_id: 1 }) }
└─LogicalScan { table: t1, columns: [t1.a, t1.b, t1._row_id] }
optimized_logical_plan_for_batch: |-
LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.a, array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] }
├─LogicalScan { table: t1, columns: [t1.a, t1.b] }
└─LogicalAgg { group_key: [t1.b], aggs: [array_agg(t1.a) filter(IsDistinctFrom(t1.a, 1:Int32) AND IsNotNull(1:Int32))] }
└─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(t1.b, t1.b), output: [t1.b, t1.a, 1:Int32] }
├─LogicalAgg { group_key: [t1.b], aggs: [] }
│ └─LogicalScan { table: t1, columns: [t1.b] }
└─LogicalProject { exprs: [t1.b, t1.a, 1:Int32] }
└─LogicalJoin { type: Inner, on: (t1.a <> t1.b), output: all }
├─LogicalAgg { group_key: [t1.b], aggs: [] }
│ └─LogicalScan { table: t1, columns: [t1.b] }
└─LogicalScan { table: t1, columns: [t1.a] }
- sql: |
create table t1(x int, y int);
create table t2(x int, y int);
Expand Down Expand Up @@ -1028,14 +1053,14 @@
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [$expr1] }
├─BatchExchange { order: [], dist: HashShard(array_types.x) }
│ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }
└─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x), ARRAY[]:List(List(Int64))) as $expr1] }
└─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x)] }
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x] }
└─BatchProject { exprs: [array_types.x, Coalesce(array_agg(array_types.x) filter(IsNotNull(1:Int32)), ARRAY[]:List(List(Int64))) as $expr1] }
└─BatchHashAgg { group_key: [array_types.x], aggs: [array_agg(array_types.x) filter(IsNotNull(1:Int32))] }
└─BatchHashJoin { type: LeftOuter, predicate: array_types.x IS NOT DISTINCT FROM array_types.x, output: [array_types.x, array_types.x, 1:Int32] }
├─BatchHashAgg { group_key: [array_types.x], aggs: [] }
│ └─BatchExchange { order: [], dist: HashShard(array_types.x) }
│ └─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(array_types.x) }
└─BatchProject { exprs: [array_types.x, array_types.x] }
└─BatchProject { exprs: [array_types.x, array_types.x, 1:Int32] }
└─BatchHashAgg { group_key: [array_types.x], aggs: [] }
└─BatchExchange { order: [], dist: HashShard(array_types.x) }
└─BatchScan { table: array_types, columns: [array_types.x], distribution: SomeShard }
Expand Down
46 changes: 43 additions & 3 deletions src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,49 @@ impl Rule for ApplyAggTransposeRule {
// convert count(*) to count(1).
let pos_of_constant_column = node.schema().len() - 1;
agg_calls.iter_mut().for_each(|agg_call| {
if agg_call.agg_kind == AggKind::Count && agg_call.inputs.is_empty() {
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
agg_call.inputs.push(input_ref);
match agg_call.agg_kind {
AggKind::Count if agg_call.inputs.is_empty() => {
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
agg_call.inputs.push(input_ref);
}
AggKind::ArrayAgg
| AggKind::JsonbAgg
| AggKind::JsonbObjectAgg => {
let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap();
agg_call.filter.conjunctions.push(cond.into());
}
AggKind::Count
| AggKind::Sum
| AggKind::Sum0
| AggKind::Avg
| AggKind::Min
| AggKind::Max
| AggKind::BitAnd
| AggKind::BitOr
| AggKind::BitXor
| AggKind::BoolAnd
| AggKind::BoolOr
| AggKind::StringAgg
// not in PostgreSQL
| AggKind::ApproxCountDistinct
| AggKind::FirstValue
| AggKind::LastValue
| AggKind::InternalLastSeenValue
// All statistical aggregates only consider non-null inputs.
| AggKind::VarPop
| AggKind::VarSamp
| AggKind::StddevPop
| AggKind::StddevSamp
// All ordered-set aggregates ignore null values in their aggregated input.
| AggKind::PercentileCont
| AggKind::PercentileDisc
| AggKind::Mode
// `grouping` has no *aggregate* input and unreachable when `is_scalar_agg`.
| AggKind::Grouping
=> {
// no-op when `agg(0 rows) == agg(1 row of nulls)`
}
}
});
}
Expand Down

0 comments on commit ca671e6

Please sign in to comment.