Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(over window): move window state impls to expr_impl crate #17047

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@ use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common_estimate_size::EstimateSize;
use smallvec::SmallVec;

use super::{WindowFuncCall, WindowFuncKind};
use super::WindowFuncCall;
use crate::{ExprError, Result};

mod aggregate;
mod buffer;
mod range_utils;
mod rank;

/// Unique and ordered identifier for a row in internal states.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, EstimateSize)]
pub struct StateKey {
Expand Down Expand Up @@ -110,22 +105,21 @@ pub trait WindowState: EstimateSize {

pub type BoxedWindowState = Box<dyn WindowState + Send + Sync>;

pub fn create_window_state(call: &WindowFuncCall) -> Result<BoxedWindowState> {
assert!(call.frame.bounds.validate().is_ok());
#[linkme::distributed_slice]
pub static WINDOW_STATE_BUILDERS: [fn(&WindowFuncCall) -> Result<BoxedWindowState>];

use WindowFuncKind::*;
Ok(match call.kind {
RowNumber => Box::new(rank::RankState::<rank::RowNumber>::new(call)),
Rank => Box::new(rank::RankState::<rank::Rank>::new(call)),
DenseRank => Box::new(rank::RankState::<rank::DenseRank>::new(call)),
Aggregate(_) => aggregate::new(call)?,
kind => {
return Err(ExprError::UnsupportedFunction(format!(
pub fn create_window_state(call: &WindowFuncCall) -> Result<BoxedWindowState> {
// we expect only one builder function in `expr_impl/window_function/mod.rs`
let builder = WINDOW_STATE_BUILDERS.iter().next();
builder.map_or_else(
|| {
Err(ExprError::UnsupportedFunction(format!(
"{}({}) -> {}",
kind,
call.kind,
call.args.arg_types().iter().format(", "),
&call.return_type,
)));
}
})
)))
},
|f| f(call),
)
}
2 changes: 2 additions & 0 deletions src/expr/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ chrono = { version = "0.4", default-features = false, features = [
"std",
] }
chrono-tz = { version = "0.9", features = ["case-insensitive"] }
educe = "0.5"
fancy-regex = "0.13"
futures-async-stream = { workspace = true }
futures-util = "0.3"
Expand All @@ -65,6 +66,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
sha1 = "0.10"
sha2 = "0.10"
smallvec = "1"
sql-json-path = { version = "0.1", features = ["jsonbb"] }
thiserror = "1"
thiserror-ext = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions src/expr/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mod aggregate;
mod scalar;
mod table_function;
mod udf;
mod window_function;

/// Enable functions in this crate.
#[macro_export]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::{bail, must_match};
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_expr::aggregate::{
AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
};
use risingwave_expr::sig::FUNCTION_REGISTRY;
use risingwave_expr::window_function::{
BoxedWindowState, FrameBounds, StateEvictHint, StateKey, StatePos, WindowFuncCall,
WindowFuncKind, WindowState,
};
use risingwave_expr::Result;
use smallvec::SmallVec;

use super::buffer::{RangeWindow, RowsWindow, WindowBuffer, WindowImpl};
use super::{BoxedWindowState, StateEvictHint, StateKey, StatePos, WindowState};
use crate::aggregate::{
AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::window_function::{FrameBounds, WindowFuncCall, WindowFuncKind};
use crate::Result;

type StateValue = SmallVec<[Datum; 2]>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ use educe::Educe;
use risingwave_common::array::Op;
use risingwave_common::types::Sentinelled;
use risingwave_common::util::memcmp_encoding;
use risingwave_expr::window_function::{
FrameExclusion, RangeFrameBounds, RowsFrameBounds, StateKey,
};

use super::range_utils::range_except;
use super::StateKey;
use crate::window_function::state::range_utils::range_diff;
use crate::window_function::{FrameExclusion, RangeFrameBounds, RowsFrameBounds};
use super::range_utils::{range_diff, range_except};

/// A common sliding window buffer.
pub(super) struct WindowBuffer<W: WindowImpl> {
Expand Down Expand Up @@ -439,12 +439,12 @@ impl<V: Clone> WindowImpl for RangeWindow<V> {
#[cfg(test)]
mod tests {
use itertools::Itertools;

use super::*;
use crate::window_function::FrameBound::{
use risingwave_expr::window_function::FrameBound::{
CurrentRow, Following, Preceding, UnboundedFollowing, UnboundedPreceding,
};

use super::*;

#[test]
fn test_rows_frame_unbounded_preceding_to_current_row() {
let mut buffer = WindowBuffer::<RowsWindow<_, _>>::new(
Expand Down
45 changes: 45 additions & 0 deletions src/expr/impl/src/window_function/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use risingwave_expr::window_function::{
BoxedWindowState, WindowFuncCall, WindowFuncKind, WINDOW_STATE_BUILDERS,
};
use risingwave_expr::{ExprError, Result};

mod aggregate;
mod buffer;
mod range_utils;
mod rank;

#[linkme::distributed_slice(WINDOW_STATE_BUILDERS)]
fn create_window_state_impl(call: &WindowFuncCall) -> Result<BoxedWindowState> {
assert!(call.frame.bounds.validate().is_ok());

use WindowFuncKind::*;
Ok(match call.kind {
RowNumber => Box::new(rank::RankState::<rank::RowNumber>::new(call)),
Rank => Box::new(rank::RankState::<rank::Rank>::new(call)),
DenseRank => Box::new(rank::RankState::<rank::DenseRank>::new(call)),
Aggregate(_) => aggregate::new(call)?,
kind => {
return Err(ExprError::UnsupportedFunction(format!(
"{}({}) -> {}",
kind,
call.args.arg_types().iter().format(", "),
&call.return_type,
)));
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ use risingwave_common::types::Datum;
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common_estimate_size::collections::EstimatedVecDeque;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::window_function::{
StateEvictHint, StateKey, StatePos, WindowFuncCall, WindowState,
};
use risingwave_expr::Result;
use smallvec::SmallVec;

use self::private::RankFuncCount;
use super::{StateEvictHint, StateKey, StatePos, WindowState};
use crate::window_function::WindowFuncCall;
use crate::Result;

mod private {
use super::*;
Expand All @@ -34,7 +35,7 @@ mod private {
}

#[derive(Default, EstimateSize)]
pub struct RowNumber {
pub(super) struct RowNumber {
prev_rank: i64,
}

Expand All @@ -47,7 +48,7 @@ impl RankFuncCount for RowNumber {
}

#[derive(EstimateSize)]
pub struct Rank {
pub(super) struct Rank {
prev_order_key: Option<MemcmpEncoded>,
prev_rank: i64,
prev_pos_in_peer_group: i64,
Expand Down Expand Up @@ -83,7 +84,7 @@ impl RankFuncCount for Rank {
}

#[derive(Default, EstimateSize)]
pub struct DenseRank {
pub(super) struct DenseRank {
prev_order_key: Option<MemcmpEncoded>,
prev_rank: i64,
}
Expand All @@ -107,7 +108,7 @@ impl RankFuncCount for DenseRank {

/// Generic state for rank window functions including `row_number`, `rank` and `dense_rank`.
#[derive(EstimateSize)]
pub struct RankState<RF: RankFuncCount> {
pub(super) struct RankState<RF: RankFuncCount> {
/// First state key of the partition.
first_key: Option<StateKey>,
/// State keys that are waiting to be outputted.
Expand Down Expand Up @@ -176,10 +177,10 @@ mod tests {
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::memcmp_encoding;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::aggregate::AggArgs;
use risingwave_expr::window_function::{Frame, FrameBound, WindowFuncKind};

use super::*;
use crate::aggregate::AggArgs;
use crate::window_function::{Frame, FrameBound, WindowFuncKind};

fn create_state_key(order: i64, pk: i64) -> StateKey {
StateKey {
Expand Down
Loading