From 6a70a2efd4716aeaec5055571f61a6879aa16b88 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 5 Jan 2024 16:06:50 +0800 Subject: [PATCH] introduce `RowsFrameBounds` type Signed-off-by: Richard Chien --- src/expr/core/src/window_function/call.rs | 44 +++++++++++++------ .../core/src/window_function/state/buffer.rs | 14 +++--- src/frontend/src/binder/expr/function.rs | 4 +- .../executor/over_window/over_partition.rs | 16 +++---- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/expr/core/src/window_function/call.rs b/src/expr/core/src/window_function/call.rs index 797c27b963c95..1bb4dfa85f2bb 100644 --- a/src/expr/core/src/window_function/call.rs +++ b/src/expr/core/src/window_function/call.rs @@ -63,7 +63,7 @@ impl Display for Frame { impl Frame { pub fn rows(start: FrameBound, end: FrameBound) -> Self { Self { - bounds: FrameBounds::Rows(start, end), + bounds: FrameBounds::Rows(RowsFrameBounds { start, end }), exclusion: FrameExclusion::default(), } } @@ -74,7 +74,7 @@ impl Frame { exclusion: FrameExclusion, ) -> Self { Self { - bounds: FrameBounds::Rows(start, end), + bounds: FrameBounds::Rows(RowsFrameBounds { start, end }), exclusion, } } @@ -88,7 +88,7 @@ impl Frame { PbType::Rows => { let start = FrameBound::from_protobuf(frame.get_start()?)?; let end = FrameBound::from_protobuf(frame.get_end()?)?; - FrameBounds::Rows(start, end) + FrameBounds::Rows(RowsFrameBounds { start, end }) } }; let exclusion = FrameExclusion::from_protobuf(frame.get_exclusion()?)?; @@ -99,7 +99,7 @@ impl Frame { use risingwave_pb::expr::window_frame::PbType; let exclusion = self.exclusion.to_protobuf() as _; match &self.bounds { - FrameBounds::Rows(start, end) => PbWindowFrame { + FrameBounds::Rows(RowsFrameBounds { start, end }) => PbWindowFrame { r#type: PbType::Rows as _, start: Some(start.to_protobuf()), end: Some(end.to_protobuf()), @@ -112,19 +112,19 @@ impl Frame { impl FrameBounds { pub fn validate(&self) -> Result<()> { match self { - Self::Rows(start, end) => FrameBound::validate_bounds(start, end), + Self::Rows(bounds) => bounds.validate(), } } pub fn start_is_unbounded(&self) -> bool { match self { - Self::Rows(start, _) => matches!(start, FrameBound::UnboundedPreceding), + Self::Rows(RowsFrameBounds { start, .. }) => start.is_unbounded_preceding(), } } pub fn end_is_unbounded(&self) -> bool { match self { - Self::Rows(_, end) => matches!(end, FrameBound::UnboundedFollowing), + Self::Rows(RowsFrameBounds { end, .. }) => end.is_unbounded_following(), } } @@ -136,22 +136,38 @@ impl FrameBounds { impl Display for FrameBounds { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Rows(start, end) => { - write!(f, "ROWS BETWEEN {} AND {}", start, end)?; - } + Self::Rows(bounds) => bounds.fmt(f), } - Ok(()) } } #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum FrameBounds { - Rows(FrameBound, FrameBound), - // Groups(FrameBound, FrameBound), - // Range(FrameBound, FrameBound), + Rows(RowsFrameBounds), + // Groups(GroupsFrameBounds), + // Range(RangeFrameBounds), } #[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct RowsFrameBounds { + pub start: FrameBound, + pub end: FrameBound, +} + +impl Display for RowsFrameBounds { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ROWS BETWEEN {} AND {}", self.start, self.end)?; + Ok(()) + } +} + +impl RowsFrameBounds { + fn validate(&self) -> Result<()> { + FrameBound::validate_bounds(&self.start, &self.end) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)] pub enum FrameBound { UnboundedPreceding, Preceding(T), diff --git a/src/expr/core/src/window_function/state/buffer.rs b/src/expr/core/src/window_function/state/buffer.rs index 0a3e2fb27ab7c..2fbe5ed1d3790 100644 --- a/src/expr/core/src/window_function/state/buffer.rs +++ b/src/expr/core/src/window_function/state/buffer.rs @@ -67,8 +67,8 @@ impl WindowBuffer { fn preceding_saturated(&self) -> bool { self.curr_key().is_some() && match &self.frame.bounds { - FrameBounds::Rows(start, _) => { - let start_off = start.to_offset(); + FrameBounds::Rows(bounds) => { + let start_off = bounds.start.to_offset(); if let Some(start_off) = start_off { if start_off >= 0 { true // pure following frame, always preceding-saturated @@ -91,8 +91,8 @@ impl WindowBuffer { fn following_saturated(&self) -> bool { self.curr_key().is_some() && match &self.frame.bounds { - FrameBounds::Rows(_, end) => { - let end_off = end.to_offset(); + FrameBounds::Rows(bounds) => { + let end_off = bounds.end.to_offset(); if let Some(end_off) = end_off { if end_off <= 0 { true // pure preceding frame, always following-saturated @@ -177,9 +177,9 @@ impl WindowBuffer { } match &self.frame.bounds { - FrameBounds::Rows(start, end) => { - let start_off = start.to_offset(); - let end_off = end.to_offset(); + FrameBounds::Rows(bounds) => { + let start_off = bounds.start.to_offset(); + let end_off = bounds.end.to_offset(); if let Some(start_off) = start_off { let logical_left_idx = self.curr_idx as isize + start_off; if logical_left_idx >= 0 { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index a0545b81b17d6..49cff0a3ded74 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -27,7 +27,7 @@ use risingwave_common::types::{DataType, ScalarImpl, Timestamptz}; use risingwave_common::{bail_not_implemented, current_cluster_version, no_function}; use risingwave_expr::aggregate::{agg_kinds, AggKind}; use risingwave_expr::window_function::{ - Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind, + Frame, FrameBound, FrameBounds, FrameExclusion, RowsFrameBounds, WindowFuncKind, }; use risingwave_sqlparser::ast::{ self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr, @@ -670,7 +670,7 @@ impl Binder { } else { FrameBound::CurrentRow }; - FrameBounds::Rows(start, end) + FrameBounds::Rows(RowsFrameBounds { start, end }) } WindowFrameUnits::Range | WindowFrameUnits::Groups => { bail_not_implemented!( diff --git a/src/stream/src/executor/over_window/over_partition.rs b/src/stream/src/executor/over_window/over_partition.rs index 3a1b91380f78c..f0c230103408f 100644 --- a/src/stream/src/executor/over_window/over_partition.rs +++ b/src/stream/src/executor/over_window/over_partition.rs @@ -830,10 +830,10 @@ fn find_affected_ranges<'cache>( calls .iter() .map(|call| match &call.frame.bounds { - FrameBounds::Rows(_start, end) => { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta .lower_bound(Bound::Included(delta.first_key_value().unwrap().0)); - for _ in 0..end.n_following_rows().unwrap() { + for _ in 0..bounds.end.n_following_rows().unwrap() { // Note that we have to move before check, to handle situation where the // cursor is at ghost position at first. cursor.move_prev(); @@ -856,9 +856,9 @@ fn find_affected_ranges<'cache>( calls .iter() .map(|call| match &call.frame.bounds { - FrameBounds::Rows(start, _end) => { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta.find(first_curr_key).unwrap(); - for _ in 0..start.n_preceding_rows().unwrap() { + for _ in 0..bounds.start.n_preceding_rows().unwrap() { cursor.move_prev(); if cursor.position().is_ghost() { break; @@ -877,10 +877,10 @@ fn find_affected_ranges<'cache>( calls .iter() .map(|call| match &call.frame.bounds { - FrameBounds::Rows(start, _end) => { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta .upper_bound(Bound::Included(delta.last_key_value().unwrap().0)); - for _ in 0..start.n_preceding_rows().unwrap() { + for _ in 0..bounds.start.n_preceding_rows().unwrap() { cursor.move_next(); if cursor.position().is_ghost() { break; @@ -899,9 +899,9 @@ fn find_affected_ranges<'cache>( calls .iter() .map(|call| match &call.frame.bounds { - FrameBounds::Rows(_start, end) => { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta.find(last_curr_key).unwrap(); - for _ in 0..end.n_following_rows().unwrap() { + for _ in 0..bounds.end.n_following_rows().unwrap() { cursor.move_next(); if cursor.position().is_ghost() { break;