Skip to content

Commit

Permalink
refactor(optimizer, agg): reuse agg call rewriting logic between agg and over window (#16690)
Browse files Browse the repository at this point in the history

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored May 12, 2024
1 parent 21b2502 commit 7da5ea8
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 430 deletions.
12 changes: 6 additions & 6 deletions src/frontend/planner_test/tests/testdata/output/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1278,22 +1278,22 @@
logical_plan: |-
LogicalProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr2, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / count(t.v1)::Decimal)) as $expr3] }
└─LogicalAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─LogicalProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] }
└─LogicalProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─LogicalScan { table: t, columns: [t.v1, t._row_id] }
batch_plan: |-
BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] }
└─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1))::Decimal as $expr2, sum0(count(t.v1))::Decimal as $expr3] }
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] }
└─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
batch_local_plan: |-
BatchProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] }
└─BatchProject { exprs: [sum($expr1), sum(t.v1), count(t.v1), sum(t.v1)::Decimal as $expr2, count(t.v1)::Decimal as $expr3] }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1] }
└─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [stddev_samp, stddev_pop], stream_key: [], pk_columns: [], pk_conflict: NoCheck }
Expand All @@ -1302,15 +1302,15 @@
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─StreamProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1, t._row_id] }
└─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: stddev_samp with other columns
sql: |
select count(''), stddev_samp(1);
logical_plan: |-
LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)) / (count(1:Int32) - 1:Int64)::Decimal))) as $expr2] }
└─LogicalAgg { aggs: [count('':Varchar), sum($expr1), sum(1:Int32), count(1:Int32)] }
└─LogicalProject { exprs: ['':Varchar, 1:Int32, (1:Int32 * 1:Int32) as $expr1] }
└─LogicalProject { exprs: ['':Varchar, (1:Int32 * 1:Int32) as $expr1, 1:Int32] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
- name: stddev_samp with group
sql: |
Expand All @@ -1319,7 +1319,7 @@
logical_plan: |-
LogicalProject { exprs: [Case((count(t.v) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)) / (count(t.v) - 1:Int64)::Decimal))) as $expr2] }
└─LogicalAgg { group_key: [t.w], aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─LogicalProject { exprs: [t.w, t.v, (t.v * t.v) as $expr1] }
└─LogicalProject { exprs: [t.w, (t.v * t.v) as $expr1, t.v] }
└─LogicalScan { table: t, columns: [t.v, t.w, t._row_id] }
- name: force two phase aggregation should succeed with UpstreamHashShard and SomeShard (batch only).
sql: |
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/planner_test/tests/testdata/output/cse_expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─BatchProject { exprs: [t.v, (t.v * t.v) as $expr1] }
└─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] }
└─BatchScan { table: t, columns: [t.v], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [stddev_pop, stddev_samp, var_pop, var_samp], stream_key: [], pk_columns: [], pk_conflict: NoCheck }
Expand All @@ -78,7 +78,7 @@
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─StreamProject { exprs: [t.v, (t.v * t.v) as $expr1, t._row_id] }
└─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: Common sub expression shouldn't extract partial expression of `some`/`all`. See 11766
sql: |
Expand Down
14 changes: 7 additions & 7 deletions src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ use crate::utils::Condition;

#[derive(Clone, Eq, PartialEq, Hash)]
pub struct AggCall {
agg_kind: AggKind,
return_type: DataType,
args: Vec<ExprImpl>,
distinct: bool,
order_by: OrderBy,
filter: Condition,
direct_args: Vec<Literal>,
pub agg_kind: AggKind,
pub return_type: DataType,
pub args: Vec<ExprImpl>,
pub distinct: bool,
pub order_by: OrderBy,
pub filter: Condition,
pub direct_args: Vec<Literal>,
}

impl std::fmt::Debug for AggCall {
Expand Down
Loading

0 comments on commit 7da5ea8

Please sign in to comment.