fix(over window): fix over window predicate pushdown #13662

Merged
13 commits merged on Nov 29, 2023
1 change: 1 addition & 0 deletions e2e_test/over_window/generated/batch/main.slt.part
@@ -8,3 +8,4 @@ include ./rank_func/mod.slt.part
include ./expr_in_win_func/mod.slt.part
include ./agg_in_win_func/mod.slt.part
include ./opt_agg_then_join/mod.slt.part
include ./with_filter/mod.slt.part
27 changes: 27 additions & 0 deletions e2e_test/over_window/generated/batch/with_filter/mod.slt.part
@@ -0,0 +1,27 @@
# This file is generated by `gen.py`. Do not edit it manually!

# Test window functions together with filters.
# https://github.com/risingwavelabs/risingwave/issues/13653

statement ok
create table t (id int, cat varchar, rule varchar, at timestamptz);

statement ok
insert into t values
(1, 'foo', 'A', '2023-11-23T12:00:42Z')
, (2, 'foo', 'B', '2023-11-23T12:01:15Z');

query TT
select rule, lag(rule) over (partition by cat order by at) from t where rule = 'B';
----
B NULL

query TT
select * from (select rule, lag(rule) over (partition by cat order by at) as prev_rule from t) where rule = 'B';
----
B A

query TT
select * from (select rule, at, row_number() over (partition by cat order by at) as rank from t) where at = '2023-11-23T12:01:15Z'::timestamptz;
----
B 2023-11-23 12:01:15+00:00 2
1 change: 1 addition & 0 deletions e2e_test/over_window/generated/streaming/main.slt.part
@@ -8,3 +8,4 @@ include ./rank_func/mod.slt.part
include ./expr_in_win_func/mod.slt.part
include ./agg_in_win_func/mod.slt.part
include ./opt_agg_then_join/mod.slt.part
include ./with_filter/mod.slt.part
27 changes: 27 additions & 0 deletions e2e_test/over_window/generated/streaming/with_filter/mod.slt.part
@@ -0,0 +1,27 @@
# This file is generated by `gen.py`. Do not edit it manually!

# Test window functions together with filters.
# https://github.com/risingwavelabs/risingwave/issues/13653

statement ok
create table t (id int, cat varchar, rule varchar, at timestamptz);

statement ok
insert into t values
(1, 'foo', 'A', '2023-11-23T12:00:42Z')
, (2, 'foo', 'B', '2023-11-23T12:01:15Z');

query TT
select rule, lag(rule) over (partition by cat order by at) from t where rule = 'B';
----
B NULL

query TT
select * from (select rule, lag(rule) over (partition by cat order by at) as prev_rule from t) where rule = 'B';
----
B A

query TT
select * from (select rule, at, row_number() over (partition by cat order by at) as rank from t) where at = '2023-11-23T12:01:15Z'::timestamptz;
----
B 2023-11-23 12:01:15+00:00 2
1 change: 1 addition & 0 deletions e2e_test/over_window/templates/main.slt.part
@@ -6,3 +6,4 @@ include ./rank_func/mod.slt.part
include ./expr_in_win_func/mod.slt.part
include ./agg_in_win_func/mod.slt.part
include ./opt_agg_then_join/mod.slt.part
include ./with_filter/mod.slt.part
25 changes: 25 additions & 0 deletions e2e_test/over_window/templates/with_filter/mod.slt.part
@@ -0,0 +1,25 @@
# Test window functions together with filters.
# https://github.com/risingwavelabs/risingwave/issues/13653

statement ok
create table t (id int, cat varchar, rule varchar, at timestamptz);

statement ok
insert into t values
(1, 'foo', 'A', '2023-11-23T12:00:42Z')
, (2, 'foo', 'B', '2023-11-23T12:01:15Z');

query TT
select rule, lag(rule) over (partition by cat order by at) from t where rule = 'B';
----
B NULL

query TT
select * from (select rule, lag(rule) over (partition by cat order by at) as prev_rule from t) where rule = 'B';
----
B A

query TT
select * from (select rule, at, row_number() over (partition by cat order by at) as rank from t) where at = '2023-11-23T12:01:15Z'::timestamptz;
----
B 2023-11-23 12:01:15+00:00 2
@@ -492,3 +492,19 @@
- stream_plan
- optimized_logical_plan_for_batch
- batch_plan

# With filter
- sql: |
create table t (id int, cat varchar, rule varchar, at timestamptz);
select * from (select cat, rule, at, lag(rule) over (partition by cat order by at) as prev_rule from t) as with_prev
where rule = 'B' and cat is not null and at = '2023-11-23T12:00:42Z'::timestamptz;
expected_outputs:
- optimized_logical_plan_for_stream
- optimized_logical_plan_for_batch
- sql: |
create table t (id int, cat varchar, rule varchar, at timestamptz);
select cat, rule, at, lag(rule) over (partition by cat order by at) as prev_rule from t
where rule = 'B' and cat is not null and at = '2023-11-23T12:00:42Z'::timestamptz;
expected_outputs:
- optimized_logical_plan_for_stream
- optimized_logical_plan_for_batch
@@ -151,8 +151,9 @@
optimized_logical_plan_for_batch: |-
LogicalProject { exprs: [row_number] }
└─LogicalOverWindow { window_functions: [row_number() OVER(PARTITION BY t.x ORDER BY t.y ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
└─LogicalTopN { order: [t.y ASC], limit: 3, offset: 0, group_key: [t.x] }
└─LogicalScan { table: t, output_columns: [t.x, t.y], required_columns: [t.x, t.y, t.z], predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) AND (t.x > 0:Int32) }
└─LogicalFilter { predicate: (t.y > 0:Int32) }
└─LogicalTopN { order: [t.y ASC], limit: 3, offset: 0, group_key: [t.x] }
└─LogicalScan { table: t, output_columns: [t.x, t.y], required_columns: [t.x, t.y, t.z], predicate: (t.z > 0:Int32) AND (t.x > 0:Int32) }
- name: mixed
sql: |
create table t (v1 bigint, v2 double precision, v3 int);
@@ -311,24 +311,24 @@
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [t.x, t.y, t.z, Sqrt(((sum::Decimal - (($expr4 * $expr4) / $expr5)) / $expr5)) as $expr6, Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) as $expr7] }
└─BatchProject { exprs: [t.x, t.y, t.z, $expr2, $expr1, $expr3, sum, sum, count, sum, sum, count, sum::Decimal as $expr4, count::Decimal as $expr5] }
└─BatchFilter { predicate: (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) }
└─BatchFilter { predicate: (t.x > 0:Int32) AND (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) }
└─BatchOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
└─BatchExchange { order: [t.z ASC, t.x ASC], dist: HashShard(t.z) }
└─BatchSort { order: [t.z ASC, t.x ASC] }
└─BatchProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3] }
└─BatchProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1] }
└─BatchFilter { predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) AND (t.x > 0:Int32) }
└─BatchFilter { predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) }
└─BatchScan { table: t, columns: [t.x, t.y, t.z], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [x, y, z, res0, res1, t._row_id(hidden)], stream_key: [t._row_id, z], pk_columns: [t._row_id, z], pk_conflict: NoCheck }
└─StreamProject { exprs: [t.x, t.y, t.z, Sqrt(((sum::Decimal - (($expr4 * $expr4) / $expr5)) / $expr5)) as $expr6, Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) as $expr7, t._row_id] }
└─StreamProject { exprs: [t.x, t.y, t.z, $expr2, $expr1, $expr3, sum, sum, count, sum, sum, count, sum::Decimal as $expr4, count::Decimal as $expr5, t._row_id] }
└─StreamFilter { predicate: (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) }
└─StreamFilter { predicate: (t.x > 0:Int32) AND (Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / count::Decimal)) <= 3.0:Decimal) AND (Case((count <= 1:Int64), null:Decimal, Sqrt(((sum::Decimal - ((sum::Decimal * sum::Decimal) / count::Decimal)) / (count - 1:Int64)::Decimal))) > 1.0:Decimal) }
└─StreamOverWindow { window_functions: [sum($expr2) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count($expr1) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum($expr3) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), sum(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), count(t.x) OVER(PARTITION BY t.z ORDER BY t.x ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
└─StreamExchange { dist: HashShard(t.z) }
└─StreamProject { exprs: [t.x, t.y, t.z, ($expr1 * $expr1) as $expr2, $expr1, (t.x * t.x) as $expr3, t._row_id] }
└─StreamProject { exprs: [t.x, t.y, t.z, (t.x - t.y) as $expr1, t._row_id] }
└─StreamFilter { predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) AND (t.x > 0:Int32) }
└─StreamFilter { predicate: (t.z > 0:Int32) AND (t.y > 0:Int32) }
└─StreamTableScan { table: t, columns: [t.x, t.y, t.z, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- id: aggregate with expression in func arguments and over clause
sql: |
@@ -488,19 +488,20 @@
└─LogicalProject { exprs: [t.x, t.y, t._row_id] }
└─LogicalScan { table: t, columns: [t.x, t.y, t._row_id] }
optimized_logical_plan_for_batch: |-
LogicalTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─LogicalScan { table: t, columns: [t.x, t.y], predicate: (t.x > t.y) }
LogicalFilter { predicate: (t.x > t.y) }
└─LogicalTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─LogicalScan { table: t, columns: [t.x, t.y] }
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchGroupTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─BatchExchange { order: [], dist: HashShard(t.y) }
└─BatchFilter { predicate: (t.x > t.y) }
└─BatchFilter { predicate: (t.x > t.y) }
└─BatchGroupTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─BatchExchange { order: [], dist: HashShard(t.y) }
└─BatchScan { table: t, columns: [t.x, t.y], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [x, y, t._row_id(hidden)], stream_key: [y, t._row_id], pk_columns: [y, t._row_id], pk_conflict: NoCheck }
└─StreamGroupTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─StreamExchange { dist: HashShard(t.y) }
└─StreamFilter { predicate: (t.x > t.y) }
└─StreamFilter { predicate: (t.x > t.y) }
└─StreamGroupTopN { order: [t.x ASC], limit: 2, offset: 0, group_key: [t.y] }
└─StreamExchange { dist: HashShard(t.y) }
└─StreamTableScan { table: t, columns: [t.x, t.y, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- id: TopN by rank without rank output
sql: |
@@ -1024,3 +1025,25 @@
└─StreamOverWindow { window_functions: [rank() OVER(PARTITION BY t.x ORDER BY t.y ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)] }
└─StreamExchange { dist: HashShard(t.x) }
└─StreamTableScan { table: t, columns: [t.x, t.y, t.z, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- sql: |
create table t (id int, cat varchar, rule varchar, at timestamptz);
select * from (select cat, rule, at, lag(rule) over (partition by cat order by at) as prev_rule from t) as with_prev
where rule = 'B' and cat is not null and at = '2023-11-23T12:00:42Z'::timestamptz;
optimized_logical_plan_for_batch: |-
LogicalFilter { predicate: (t.rule = 'B':Varchar) AND (t.at = '2023-11-23 12:00:42+00:00':Timestamptz) }
└─LogicalOverWindow { window_functions: [first_value(t.rule) OVER(PARTITION BY t.cat ORDER BY t.at ASC ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING)] }
└─LogicalScan { table: t, columns: [t.cat, t.rule, t.at], predicate: IsNotNull(t.cat) }
optimized_logical_plan_for_stream: |-
LogicalFilter { predicate: (t.rule = 'B':Varchar) AND (t.at = '2023-11-23 12:00:42+00:00':Timestamptz) }
└─LogicalOverWindow { window_functions: [first_value(t.rule) OVER(PARTITION BY t.cat ORDER BY t.at ASC ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING)] }
└─LogicalScan { table: t, columns: [t.cat, t.rule, t.at], predicate: IsNotNull(t.cat) }
- sql: |
create table t (id int, cat varchar, rule varchar, at timestamptz);
select cat, rule, at, lag(rule) over (partition by cat order by at) as prev_rule from t
where rule = 'B' and cat is not null and at = '2023-11-23T12:00:42Z'::timestamptz;
optimized_logical_plan_for_batch: |-
LogicalOverWindow { window_functions: [first_value(t.rule) OVER(PARTITION BY t.cat ORDER BY t.at ASC ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING)] }
└─LogicalScan { table: t, columns: [t.cat, t.rule, t.at], predicate: (t.rule = 'B':Varchar) AND IsNotNull(t.cat) AND (t.at = '2023-11-23 12:00:42+00:00':Timestamptz) }
optimized_logical_plan_for_stream: |-
LogicalOverWindow { window_functions: [first_value(t.rule) OVER(PARTITION BY t.cat ORDER BY t.at ASC ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING)] }
└─LogicalScan { table: t, columns: [t.cat, t.rule, t.at], predicate: (t.rule = 'B':Varchar) AND IsNotNull(t.cat) AND (t.at = '2023-11-23 12:00:42+00:00':Timestamptz) }
18 changes: 15 additions & 3 deletions src/frontend/src/optimizer/plan_node/logical_over_window.rs
@@ -698,9 +698,21 @@ impl PredicatePushdown for LogicalOverWindow {
predicate: Condition,
ctx: &mut PredicatePushdownContext,
) -> PlanRef {
let mut window_col = FixedBitSet::with_capacity(self.schema().len());
window_col.insert_range(self.core.input.schema().len()..self.schema().len());
let (window_pred, other_pred) = predicate.split_disjoint(&window_col);
let in_schema_len = self.core.input.schema().len();
let out_schema_len = self.schema().len();

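// Input columns referenced by the window functions (their arguments and ORDER BY
// columns); combined below with the window-function output columns to decide which
// conjuncts may be pushed past the OverWindow.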
let window_func_input_refs = self.window_functions().iter().flat_map(|func| {
func.args
.iter()
.map(|arg| arg.index)
.chain(func.order_by.iter().map(|o| o.column_index))
});
let mut over_window_related_cols: FixedBitSet = window_func_input_refs
.chain(in_schema_len..out_schema_len)
.collect();
over_window_related_cols.grow(out_schema_len);

let (window_pred, other_pred) = predicate.split_disjoint(&over_window_related_cols);
Contributor
Consider this case: even if we filter on id, a column that is not among the over_window_related_cols, we should not push this filter through the over window.

 select * from (select id, rule, at, row_number() over (partition by cat order by at) as rank from t) x where id = 2;

Member Author
Then it seems that we cannot push any predicate through the over window.

Contributor (@chenzl25), Nov 27, 2023
However, a filter that references only the partition columns could still be pushed down. For example:

select * from (select id, cat, rule, at, count(*) over (partition by cat order by at) as cnt from t) x where cat = 'foo';

Member Author
Indeed.

gen_filter_and_pushdown(self, window_pred, other_pred, ctx)
}
}
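The review discussion above converges on a stricter rule than the split in this hunk: a conjunct is only safe to push below the OverWindow when every column it references belongs to the shared PARTITION BY key, because filtering on any other column changes which rows each partition's frame sees. Below is a minimal, self-contained sketch of that splitting rule; it is illustrative only, and the Conjunct model and split_for_pushdown helper are assumptions, not RisingWave's planner API.

use std::collections::HashSet;

// A WHERE-clause conjunct, reduced to the set of input-column indices it references.
struct Conjunct {
    referenced_cols: HashSet<usize>,
}

// Split the conjuncts into (pushable, retained): a conjunct may move below the
// OverWindow only if every column it touches is part of the shared PARTITION BY key.
fn split_for_pushdown(
    conjuncts: Vec<Conjunct>,
    partition_cols: &HashSet<usize>,
) -> (Vec<Conjunct>, Vec<Conjunct>) {
    conjuncts
        .into_iter()
        .partition(|c| c.referenced_cols.is_subset(partition_cols))
}

fn main() {
    // Columns of t: 0 = id, 1 = cat (the partition key), 2 = rule, 3 = at.
    let partition_cols: HashSet<usize> = [1].into_iter().collect();
    let preds = vec![
        Conjunct { referenced_cols: [1].into_iter().collect() }, // cat = 'foo' -> pushable
        Conjunct { referenced_cols: [0].into_iter().collect() }, // id = 2      -> retained
    ];
    let (pushable, retained) = split_for_pushdown(preds, &partition_cols);
    assert_eq!(pushable.len(), 1);
    assert_eq!(retained.len(), 1);
}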
4 changes: 4 additions & 0 deletions src/frontend/src/optimizer/rule/over_window_to_topn_rule.rs
@@ -139,6 +139,10 @@ impl Rule for OverWindowToTopNRule {

/// Returns `None` if the conditions are too complex or invalid. `Some((limit, offset))` otherwise.
fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
if rank_preds.is_empty() {
return None;
}

// rank >= lb
let mut lb: Option<i64> = None;
// rank <= ub
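For context on the guard added above: handle_rank_preds converts predicates on the rank column into a TopN (limit, offset) pair, and with no rank predicates there is no bound to derive, so the rule must return None. Below is a simplified sketch of that mapping under assumed semantics; the real function operates on ExprImpl and takes a window_func_pos argument, which this toy version omits.

#[derive(Clone, Copy)]
enum RankPred {
    Le(u64), // rank <= n
    Ge(u64), // rank >= n
}

fn handle_rank_preds(rank_preds: &[RankPred]) -> Option<(u64, u64)> {
    if rank_preds.is_empty() {
        return None; // nothing constrains the rank, so no (limit, offset) can be derived
    }
    let mut lb: Option<u64> = None; // lower bound: rank >= lb
    let mut ub: Option<u64> = None; // upper bound: rank <= ub
    for pred in rank_preds {
        match *pred {
            RankPred::Le(n) => ub = Some(ub.map_or(n, |u| u.min(n))),
            RankPred::Ge(n) => lb = Some(lb.map_or(n, |l| l.max(n))),
        }
    }
    let ub = ub?; // without an upper bound the query cannot become a TopN
    let offset = lb.map_or(0, |l| l.saturating_sub(1));
    let limit = ub.saturating_sub(offset);
    Some((limit, offset))
}

fn main() {
    assert_eq!(handle_rank_preds(&[]), None);
    assert_eq!(handle_rank_preds(&[RankPred::Le(3)]), Some((3, 0)));
    assert_eq!(handle_rank_preds(&[RankPred::Ge(2), RankPred::Le(5)]), Some((4, 1)));
}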