From 6ec0b95392ec27d096e35783650bafb7c992b403 Mon Sep 17 00:00:00 2001 From: jinser Date: Tue, 10 Oct 2023 12:02:26 +0800 Subject: [PATCH] =?UTF-8?q?fix(over=20window):=20fix=20error=20in=20using?= =?UTF-8?q?=20aggregate=20function=20result=20as=20win=E2=80=A6=20(#12551)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../batch/aggregate/with_over_window.slt.part | 87 +++++++++++++++++++ .../generated/batch/create.slt.part | 12 ++- .../over_window/generated/batch/drop.slt.part | 3 + .../over_window/generated/batch/mod.slt.part | 28 ++++++ .../generated/streaming/create.slt.part | 12 ++- .../generated/streaming/drop.slt.part | 3 + .../generated/streaming/mod.slt.part | 28 ++++++ .../over_window/templates/create.slt.part | 12 ++- e2e_test/over_window/templates/drop.slt.part | 3 + e2e_test/over_window/templates/mod.slt.part | 28 ++++++ .../tests/testdata/input/agg.yaml | 29 +++++++ .../tests/testdata/output/agg.yaml | 83 ++++++++++++++++++ .../src/optimizer/plan_node/logical_agg.rs | 28 ++++++ 13 files changed, 353 insertions(+), 3 deletions(-) create mode 100644 e2e_test/batch/aggregate/with_over_window.slt.part diff --git a/e2e_test/batch/aggregate/with_over_window.slt.part b/e2e_test/batch/aggregate/with_over_window.slt.part new file mode 100644 index 0000000000000..a450b13837791 --- /dev/null +++ b/e2e_test/batch/aggregate/with_over_window.slt.part @@ -0,0 +1,87 @@ +statement ok +create table t (a int, b int, c int, d int, e int); + +statement ok +insert into t values + (1, 23, 84, 11, 87), + (2, 34, 29, 22, 98), + (3, 45, 43, 33, 10), + (4, 56, 83, 44, 26), + (5, 68, 20, 55, 12), + (5, 68, 90, 66, 34), + (5, 68, 11, 77, 32); + +query II +select + a, + sum((sum(b))) over (partition by a order by a) +from t +group by a +order by a; +---- +1 23 +2 34 +3 45 +4 56 +5 204 + +query II +select + a, + row_number() over (partition by a order by a) +from t +group by a +order by a; +---- +1 1 +2 1 +3 1 +4 1 +5 1 + +query II +select + a, + row_number() over (partition by a order by a desc) +from t +group by a +order by a; +---- +1 1 +2 1 +3 1 +4 1 +5 1 + +query III +select + a, + b, + sum(sum(c)) over (partition by a order by b) +from t +group by a, b +order by a, b; +---- +1 23 84 +2 34 29 +3 45 43 +4 56 83 +5 68 121 + +query III +select + a, + b, + sum(sum(c)) over (partition by a, avg(d) order by max(e), b) +from t +group by a, b +order by a, b; +---- +1 23 84 +2 34 29 +3 45 43 +4 56 83 +5 68 121 + +statement ok +drop table t; diff --git a/e2e_test/over_window/generated/batch/create.slt.part b/e2e_test/over_window/generated/batch/create.slt.part index 8e489c3dde0bd..5f4b5e1152804 100644 --- a/e2e_test/over_window/generated/batch/create.slt.part +++ b/e2e_test/over_window/generated/batch/create.slt.part @@ -49,6 +49,16 @@ select , row_number() over (partition by p1 order by p2 desc, id) as out11 from t; +# over + agg +statement ok +create view v_e as +select + p1, p2 + , row_number() over (partition by p1 order by p2) as out12 + , sum(sum(v2)) over (partition by p1, avg(time) order by max(v1), p2) as out13 +from t +group by p1, p2; + statement ok create view v_a_b as select @@ -103,4 +113,4 @@ select , first_value(v1) over (partition by p1, p2 order by time, id rows 3 preceding) as out3 , lag(v1 + 2, 0 + 1) over (partition by p1 - 1 order by id) as out4 , min(v1 * 2) over (partition by p1, p2 order by time + 1, id rows between current row and unbounded following) as out5 -from t; \ No newline at end of file +from t; diff --git a/e2e_test/over_window/generated/batch/drop.slt.part b/e2e_test/over_window/generated/batch/drop.slt.part index 8eaca578e1f4e..435ffd46433e7 100644 --- a/e2e_test/over_window/generated/batch/drop.slt.part +++ b/e2e_test/over_window/generated/batch/drop.slt.part @@ -12,6 +12,9 @@ drop view v_c; statement ok drop view v_d; +statement ok +drop view v_e; + statement ok drop view v_a_b; diff --git a/e2e_test/over_window/generated/batch/mod.slt.part b/e2e_test/over_window/generated/batch/mod.slt.part index ff46877de7ddf..2c7778fd46aff 100644 --- a/e2e_test/over_window/generated/batch/mod.slt.part +++ b/e2e_test/over_window/generated/batch/mod.slt.part @@ -41,6 +41,13 @@ select * from v_d order by id; 100003 100 208 2 723 807 3 1 100004 103 200 2 702 808 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1611 +100 208 2 807 +103 200 1 808 + include ./cross_check.slt.part statement ok @@ -88,6 +95,14 @@ select * from v_d order by id; 100005 100 200 3 717 810 4 4 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1, p2; +---- +100 200 1 2421 +100 208 2 3228 +103 200 1 808 +105 204 1 828 + include ./cross_check.slt.part statement ok @@ -139,6 +154,13 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 4 100006 105 204 5 703 828 1 1 +query iiiiiii +select * from v_e order by p1; +---- +100 200 1 3228 +103 200 1 808 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- @@ -182,6 +204,12 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 2 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1615 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- diff --git a/e2e_test/over_window/generated/streaming/create.slt.part b/e2e_test/over_window/generated/streaming/create.slt.part index 23a496ff0f315..4334fb1cdd30e 100644 --- a/e2e_test/over_window/generated/streaming/create.slt.part +++ b/e2e_test/over_window/generated/streaming/create.slt.part @@ -49,6 +49,16 @@ select , row_number() over (partition by p1 order by p2 desc, id) as out11 from t; +# over + agg +statement ok +create materialized view v_e as +select + p1, p2 + , row_number() over (partition by p1 order by p2) as out12 + , sum(sum(v2)) over (partition by p1, avg(time) order by max(v1), p2) as out13 +from t +group by p1, p2; + statement ok create materialized view v_a_b as select @@ -103,4 +113,4 @@ select , first_value(v1) over (partition by p1, p2 order by time, id rows 3 preceding) as out3 , lag(v1 + 2, 0 + 1) over (partition by p1 - 1 order by id) as out4 , min(v1 * 2) over (partition by p1, p2 order by time + 1, id rows between current row and unbounded following) as out5 -from t; \ No newline at end of file +from t; diff --git a/e2e_test/over_window/generated/streaming/drop.slt.part b/e2e_test/over_window/generated/streaming/drop.slt.part index d469282f41247..e6c4fcfaad244 100644 --- a/e2e_test/over_window/generated/streaming/drop.slt.part +++ b/e2e_test/over_window/generated/streaming/drop.slt.part @@ -12,6 +12,9 @@ drop materialized view v_c; statement ok drop materialized view v_d; +statement ok +drop materialized view v_e; + statement ok drop materialized view v_a_b; diff --git a/e2e_test/over_window/generated/streaming/mod.slt.part b/e2e_test/over_window/generated/streaming/mod.slt.part index ff46877de7ddf..2c7778fd46aff 100644 --- a/e2e_test/over_window/generated/streaming/mod.slt.part +++ b/e2e_test/over_window/generated/streaming/mod.slt.part @@ -41,6 +41,13 @@ select * from v_d order by id; 100003 100 208 2 723 807 3 1 100004 103 200 2 702 808 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1611 +100 208 2 807 +103 200 1 808 + include ./cross_check.slt.part statement ok @@ -88,6 +95,14 @@ select * from v_d order by id; 100005 100 200 3 717 810 4 4 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1, p2; +---- +100 200 1 2421 +100 208 2 3228 +103 200 1 808 +105 204 1 828 + include ./cross_check.slt.part statement ok @@ -139,6 +154,13 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 4 100006 105 204 5 703 828 1 1 +query iiiiiii +select * from v_e order by p1; +---- +100 200 1 3228 +103 200 1 808 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- @@ -182,6 +204,12 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 2 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1615 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- diff --git a/e2e_test/over_window/templates/create.slt.part b/e2e_test/over_window/templates/create.slt.part index 0d16b52fcdc86..7ac749e459b02 100644 --- a/e2e_test/over_window/templates/create.slt.part +++ b/e2e_test/over_window/templates/create.slt.part @@ -47,6 +47,16 @@ select , row_number() over (partition by p1 order by p2 desc, id) as out11 from t; +# over + agg +statement ok +create $view_type v_e as +select + p1, p2 + , row_number() over (partition by p1 order by p2) as out12 + , sum(sum(v2)) over (partition by p1, avg(time) order by max(v1), p2) as out13 +from t +group by p1, p2; + statement ok create $view_type v_a_b as select @@ -101,4 +111,4 @@ select , first_value(v1) over (partition by p1, p2 order by time, id rows 3 preceding) as out3 , lag(v1 + 2, 0 + 1) over (partition by p1 - 1 order by id) as out4 , min(v1 * 2) over (partition by p1, p2 order by time + 1, id rows between current row and unbounded following) as out5 -from t; \ No newline at end of file +from t; diff --git a/e2e_test/over_window/templates/drop.slt.part b/e2e_test/over_window/templates/drop.slt.part index 926305ee42699..def8e92379878 100644 --- a/e2e_test/over_window/templates/drop.slt.part +++ b/e2e_test/over_window/templates/drop.slt.part @@ -10,6 +10,9 @@ drop $view_type v_c; statement ok drop $view_type v_d; +statement ok +drop $view_type v_e; + statement ok drop $view_type v_a_b; diff --git a/e2e_test/over_window/templates/mod.slt.part b/e2e_test/over_window/templates/mod.slt.part index 3e48a52358701..1b1b86a0d40d3 100644 --- a/e2e_test/over_window/templates/mod.slt.part +++ b/e2e_test/over_window/templates/mod.slt.part @@ -39,6 +39,13 @@ select * from v_d order by id; 100003 100 208 2 723 807 3 1 100004 103 200 2 702 808 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1611 +100 208 2 807 +103 200 1 808 + include ./cross_check.slt.part statement ok @@ -86,6 +93,14 @@ select * from v_d order by id; 100005 100 200 3 717 810 4 4 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1, p2; +---- +100 200 1 2421 +100 208 2 3228 +103 200 1 808 +105 204 1 828 + include ./cross_check.slt.part statement ok @@ -137,6 +152,13 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 4 100006 105 204 5 703 828 1 1 +query iiiiiii +select * from v_e order by p1; +---- +100 200 1 3228 +103 200 1 808 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- @@ -180,6 +202,12 @@ select * from v_d order by id; 100005 100 200 1 717 810 2 2 100006 105 204 5 703 828 1 1 +query iiii +select * from v_e order by p1; +---- +100 200 1 1615 +105 204 1 828 + query iiiiiiiiii select * from v_expr order by id; ---- diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index b5b8e182703f8..99aa94ff773b9 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -929,3 +929,32 @@ expected_outputs: - batch_plan - stream_plan + +- sql: | + CREATE TABLE t (a int, b int); + SELECT a, sum((sum(b))) OVER (PARTITION BY a ORDER BY a) FROM t GROUP BY a; + expected_outputs: + - batch_plan + - stream_plan +- sql: | + CREATE TABLE t (a int, b int); + SELECT a, row_number() OVER (PARTITION BY a ORDER BY a DESC) FROM t GROUP BY a; + expected_outputs: + - batch_plan + - stream_plan +- sql: | + CREATE TABLE t (a int, b int, c int); + SELECT a, b, sum(sum(c)) OVER (PARTITION BY a ORDER BY b) + FROM t + GROUP BY a, b; + expected_outputs: + - batch_plan + - stream_plan +- sql: | + CREATE TABLE t (a int, b int, c int, d int, e int); + SELECT a, b, sum(sum(c)) OVER (PARTITION BY a, avg(d) ORDER BY max(e), b) + FROM t + GROUP BY a, b; + expected_outputs: + - batch_plan + - stream_plan diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index d62ce89d0ed3b..aefb4df98ef4e 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1649,3 +1649,86 @@ └─StreamProject { exprs: [count, t.id] } └─StreamHashAgg { group_key: [t.id], aggs: [count] } └─StreamTableScan { table: t, columns: [t.id], pk: [t.id], dist: UpstreamHashShard(t.id) } +- sql: | + CREATE TABLE t (a int, b int); + SELECT a, sum((sum(b))) OVER (PARTITION BY a ORDER BY a) FROM t GROUP BY a; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [t.a, sum] } + └─BatchOverWindow { window_functions: [sum(sum(t.b)) OVER(PARTITION BY t.a ORDER BY t.a ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─BatchSort { order: [t.a ASC, t.a ASC] } + └─BatchHashAgg { group_key: [t.a], aggs: [sum(t.b)] } + └─BatchExchange { order: [], dist: HashShard(t.a) } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [a, sum], stream_key: [a], pk_columns: [a], pk_conflict: NoCheck } + └─StreamProject { exprs: [t.a, sum] } + └─StreamOverWindow { window_functions: [sum(sum(t.b)) OVER(PARTITION BY t.a ORDER BY t.a ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─StreamProject { exprs: [t.a, sum(t.b)] } + └─StreamHashAgg { group_key: [t.a], aggs: [sum(t.b), count] } + └─StreamExchange { dist: HashShard(t.a) } + └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + CREATE TABLE t (a int, b int); + SELECT a, row_number() OVER (PARTITION BY a ORDER BY a DESC) FROM t GROUP BY a; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.a DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─BatchSort { order: [t.a ASC, t.a DESC] } + └─BatchHashAgg { group_key: [t.a], aggs: [] } + └─BatchExchange { order: [], dist: HashShard(t.a) } + └─BatchScan { table: t, columns: [t.a], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [a, row_number], stream_key: [a], pk_columns: [a], pk_conflict: NoCheck } + └─StreamOverWindow { window_functions: [row_number() OVER(PARTITION BY t.a ORDER BY t.a DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─StreamProject { exprs: [t.a] } + └─StreamHashAgg { group_key: [t.a], aggs: [count] } + └─StreamExchange { dist: HashShard(t.a) } + └─StreamTableScan { table: t, columns: [t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + CREATE TABLE t (a int, b int, c int); + SELECT a, b, sum(sum(c)) OVER (PARTITION BY a ORDER BY b) + FROM t + GROUP BY a, b; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [t.a, t.b, sum] } + └─BatchOverWindow { window_functions: [sum(sum(t.c)) OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─BatchExchange { order: [t.a ASC, t.b ASC], dist: HashShard(t.a) } + └─BatchSort { order: [t.a ASC, t.b ASC] } + └─BatchHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c)] } + └─BatchExchange { order: [], dist: HashShard(t.a, t.b) } + └─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [a, b, sum], stream_key: [a, b], pk_columns: [a, b], pk_conflict: NoCheck } + └─StreamProject { exprs: [t.a, t.b, sum] } + └─StreamOverWindow { window_functions: [sum(sum(t.c)) OVER(PARTITION BY t.a ORDER BY t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─StreamExchange { dist: HashShard(t.a) } + └─StreamProject { exprs: [t.a, t.b, sum(t.c)] } + └─StreamHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c), count] } + └─StreamExchange { dist: HashShard(t.a, t.b) } + └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + CREATE TABLE t (a int, b int, c int, d int, e int); + SELECT a, b, sum(sum(c)) OVER (PARTITION BY a, avg(d) ORDER BY max(e), b) + FROM t + GROUP BY a, b; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [t.a, t.b, sum] } + └─BatchOverWindow { window_functions: [sum(sum(t.c)) OVER(PARTITION BY t.a, $expr1 ORDER BY max(t.e) ASC, t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─BatchExchange { order: [t.a ASC, $expr1 ASC, max(t.e) ASC, t.b ASC], dist: HashShard(t.a, $expr1) } + └─BatchSort { order: [t.a ASC, $expr1 ASC, max(t.e) ASC, t.b ASC] } + └─BatchProject { exprs: [t.a, t.b, sum(t.c), max(t.e), (sum(t.d)::Decimal / count(t.d)::Decimal) as $expr1] } + └─BatchHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c), sum(t.d), count(t.d), max(t.e)] } + └─BatchExchange { order: [], dist: HashShard(t.a, t.b) } + └─BatchScan { table: t, columns: [t.a, t.b, t.c, t.d, t.e], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [a, b, sum, $expr1(hidden)], stream_key: [a, b, $expr1], pk_columns: [a, b, $expr1], pk_conflict: NoCheck } + └─StreamProject { exprs: [t.a, t.b, sum, $expr1] } + └─StreamOverWindow { window_functions: [sum(sum(t.c)) OVER(PARTITION BY t.a, $expr1 ORDER BY max(t.e) ASC, t.b ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] } + └─StreamExchange { dist: HashShard(t.a, $expr1) } + └─StreamProject { exprs: [t.a, t.b, sum(t.c), max(t.e), (sum(t.d)::Decimal / count(t.d)::Decimal) as $expr1] } + └─StreamHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c), sum(t.d), count(t.d), max(t.e), count] } + └─StreamExchange { dist: HashShard(t.a, t.b) } + └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t.d, t.e, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index d3bf8d896f9ff..7f86551acb397 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -28,6 +28,7 @@ use super::{ }; use crate::expr::{ AggCall, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, InputRef, Literal, OrderBy, + WindowFunction, }; use crate::optimizer::plan_node::generic::GenericPlanNode; use crate::optimizer::plan_node::stream::StreamPlanRef; @@ -731,6 +732,33 @@ impl ExprRewriter for LogicalAggBuilder { } } + /// When there is an `WindowFunction` (outside of agg call), it must refers to a group column. + /// Or all `InputRef`s appears in it must refer to a group column. + fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl { + let WindowFunction { + args, + partition_by, + order_by, + .. + } = window_func; + let args = args + .into_iter() + .map(|expr| self.rewrite_expr(expr)) + .collect(); + let partition_by = partition_by + .into_iter() + .map(|expr| self.rewrite_expr(expr)) + .collect(); + let order_by = order_by.rewrite_expr(self); + WindowFunction { + args, + partition_by, + order_by, + ..window_func + } + .into() + } + /// When there is an `InputRef` (outside of agg call), it must refers to a group column. fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl { let expr = input_ref.into();