diff --git a/src/expr/core/src/window_function/state/buffer.rs b/src/expr/core/src/window_function/state/buffer.rs index fc32d51ea70ab..3edb6d7adc164 100644 --- a/src/expr/core/src/window_function/state/buffer.rs +++ b/src/expr/core/src/window_function/state/buffer.rs @@ -16,8 +16,9 @@ use std::collections::VecDeque; use std::ops::Range; use risingwave_common::array::Op; -use smallvec::{smallvec, SmallVec}; +use super::range_utils::range_except; +use crate::window_function::state::range_utils::range_diff; use crate::window_function::{Frame, FrameBounds, FrameExclusion}; struct Entry { @@ -259,149 +260,13 @@ impl WindowBuffer { } } -/// Calculate range (A - B), the result might be the union of two ranges when B is totally included -/// in the A. -fn range_except(a: Range, b: Range) -> (Range, Range) { - #[allow(clippy::if_same_then_else)] // for better readability - if a.is_empty() { - (0..0, 0..0) - } else if b.is_empty() { - (a, 0..0) - } else if a.end <= b.start || b.end <= a.start { - // a: [ ) - // b: [ ) - // or - // a: [ ) - // b: [ ) - (a, 0..0) - } else if b.start <= a.start && a.end <= b.end { - // a: [ ) - // b: [ ) - (0..0, 0..0) - } else if a.start < b.start && b.end < a.end { - // a: [ ) - // b: [ ) - (a.start..b.start, b.end..a.end) - } else if a.end <= b.end { - // a: [ ) - // b: [ ) - (a.start..b.start, 0..0) - } else if b.start <= a.start { - // a: [ ) - // b: [ ) - (b.end..a.end, 0..0) - } else { - unreachable!() - } -} - -/// Calculate the difference of two ranges A and B, return (removed ranges, added ranges). -/// Note this is quite different from [`range_except`]. -#[allow(clippy::type_complexity)] // looks complex but it's not -fn range_diff( - a: Range, - b: Range, -) -> (SmallVec<[Range; 2]>, SmallVec<[Range; 2]>) { - if a.start == b.start { - match a.end.cmp(&b.end) { - std::cmp::Ordering::Equal => { - // a: [ ) - // b: [ ) - (smallvec![], smallvec![]) - } - std::cmp::Ordering::Less => { - // a: [ ) - // b: [ ) - (smallvec![], smallvec![a.end..b.end]) - } - std::cmp::Ordering::Greater => { - // a: [ ) - // b: [ ) - (smallvec![b.end..a.end], smallvec![]) - } - } - } else if a.end == b.end { - debug_assert!(a.start != b.start); - if a.start < b.start { - // a: [ ) - // b: [ ) - (smallvec![a.start..b.start], smallvec![]) - } else { - // a: [ ) - // b: [ ) - (smallvec![], smallvec![b.start..a.start]) - } - } else { - debug_assert!(a.start != b.start && a.end != b.end); - if a.end <= b.start || b.end <= a.start { - // a: [ ) - // b: [ [ ) - // or - // a: [ ) - // b: [ ) ) - (smallvec![a], smallvec![b]) - } else if b.start < a.start && a.end < b.end { - // a: [ ) - // b: [ ) - (smallvec![], smallvec![b.start..a.start, a.end..b.end]) - } else if a.start < b.start && b.end < a.end { - // a: [ ) - // b: [ ) - (smallvec![a.start..b.start, b.end..a.end], smallvec![]) - } else if a.end < b.end { - // a: [ ) - // b: [ ) - (smallvec![a.start..b.start], smallvec![a.end..b.end]) - } else { - // a: [ ) - // b: [ ) - (smallvec![b.end..a.end], smallvec![b.start..a.start]) - } - } -} - #[cfg(test)] mod tests { - use std::collections::HashSet; - use itertools::Itertools; use super::*; use crate::window_function::{Frame, FrameBound}; - #[test] - fn test_range_diff() { - fn test( - a: Range, - b: Range, - expected_removed: impl IntoIterator, - expected_added: impl IntoIterator, - ) { - let (removed, added) = range_diff(a, b); - let removed_set = removed.into_iter().flatten().collect::>(); - let added_set = added.into_iter().flatten().collect::>(); - let expected_removed_set = expected_removed.into_iter().collect::>(); - let expected_added_set = expected_added.into_iter().collect::>(); - assert_eq!(removed_set, expected_removed_set); - assert_eq!(added_set, expected_added_set); - } - - test(0..0, 0..0, [], []); - test(0..1, 0..1, [], []); - test(0..1, 0..2, [], [1]); - test(0..2, 0..1, [1], []); - test(0..2, 1..2, [0], []); - test(1..2, 0..2, [], [0]); - test(0..1, 1..2, [0], [1]); - test(0..1, 2..3, [0], [2]); - test(1..2, 0..1, [1], [0]); - test(2..3, 0..1, [2], [0]); - test(0..3, 1..2, [0, 2], []); - test(1..2, 0..3, [], [0, 2]); - test(0..3, 2..4, [0, 1], [3]); - test(2..4, 0..3, [3], [0, 1]); - } - #[test] fn test_rows_frame_unbounded_preceding_to_current_row() { let mut buffer = WindowBuffer::new( diff --git a/src/expr/core/src/window_function/state/mod.rs b/src/expr/core/src/window_function/state/mod.rs index 805688f22b3b1..37ee086ca7ba4 100644 --- a/src/expr/core/src/window_function/state/mod.rs +++ b/src/expr/core/src/window_function/state/mod.rs @@ -24,9 +24,9 @@ use smallvec::SmallVec; use super::{WindowFuncCall, WindowFuncKind}; use crate::{ExprError, Result}; -mod buffer; - mod aggregate; +mod buffer; +mod range_utils; mod rank; /// Unique and ordered identifier for a row in internal states. diff --git a/src/expr/core/src/window_function/state/range_utils.rs b/src/expr/core/src/window_function/state/range_utils.rs new file mode 100644 index 0000000000000..256eb1ce1227a --- /dev/null +++ b/src/expr/core/src/window_function/state/range_utils.rs @@ -0,0 +1,158 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Range; + +use smallvec::{smallvec, SmallVec}; + +/// Calculate range (A - B), the result might be the union of two ranges when B is totally included +/// in the A. +pub(super) fn range_except(a: Range, b: Range) -> (Range, Range) { + #[allow(clippy::if_same_then_else)] // for better readability + if a.is_empty() { + (0..0, 0..0) + } else if b.is_empty() { + (a, 0..0) + } else if a.end <= b.start || b.end <= a.start { + // a: [ ) + // b: [ ) + // or + // a: [ ) + // b: [ ) + (a, 0..0) + } else if b.start <= a.start && a.end <= b.end { + // a: [ ) + // b: [ ) + (0..0, 0..0) + } else if a.start < b.start && b.end < a.end { + // a: [ ) + // b: [ ) + (a.start..b.start, b.end..a.end) + } else if a.end <= b.end { + // a: [ ) + // b: [ ) + (a.start..b.start, 0..0) + } else if b.start <= a.start { + // a: [ ) + // b: [ ) + (b.end..a.end, 0..0) + } else { + unreachable!() + } +} + +/// Calculate the difference of two ranges A and B, return (removed ranges, added ranges). +/// Note this is quite different from [`range_except`]. +#[allow(clippy::type_complexity)] // looks complex but it's not +pub(super) fn range_diff( + a: Range, + b: Range, +) -> (SmallVec<[Range; 2]>, SmallVec<[Range; 2]>) { + if a.start == b.start { + match a.end.cmp(&b.end) { + std::cmp::Ordering::Equal => { + // a: [ ) + // b: [ ) + (smallvec![], smallvec![]) + } + std::cmp::Ordering::Less => { + // a: [ ) + // b: [ ) + (smallvec![], smallvec![a.end..b.end]) + } + std::cmp::Ordering::Greater => { + // a: [ ) + // b: [ ) + (smallvec![b.end..a.end], smallvec![]) + } + } + } else if a.end == b.end { + debug_assert!(a.start != b.start); + if a.start < b.start { + // a: [ ) + // b: [ ) + (smallvec![a.start..b.start], smallvec![]) + } else { + // a: [ ) + // b: [ ) + (smallvec![], smallvec![b.start..a.start]) + } + } else { + debug_assert!(a.start != b.start && a.end != b.end); + if a.end <= b.start || b.end <= a.start { + // a: [ ) + // b: [ [ ) + // or + // a: [ ) + // b: [ ) ) + (smallvec![a], smallvec![b]) + } else if b.start < a.start && a.end < b.end { + // a: [ ) + // b: [ ) + (smallvec![], smallvec![b.start..a.start, a.end..b.end]) + } else if a.start < b.start && b.end < a.end { + // a: [ ) + // b: [ ) + (smallvec![a.start..b.start, b.end..a.end], smallvec![]) + } else if a.end < b.end { + // a: [ ) + // b: [ ) + (smallvec![a.start..b.start], smallvec![a.end..b.end]) + } else { + // a: [ ) + // b: [ ) + (smallvec![b.end..a.end], smallvec![b.start..a.start]) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_range_diff() { + fn test( + a: Range, + b: Range, + expected_removed: impl IntoIterator, + expected_added: impl IntoIterator, + ) { + let (removed, added) = range_diff(a, b); + let removed_set = removed.into_iter().flatten().collect::>(); + let added_set = added.into_iter().flatten().collect::>(); + let expected_removed_set = expected_removed.into_iter().collect::>(); + let expected_added_set = expected_added.into_iter().collect::>(); + assert_eq!(removed_set, expected_removed_set); + assert_eq!(added_set, expected_added_set); + } + + test(0..0, 0..0, [], []); + test(0..1, 0..1, [], []); + test(0..1, 0..2, [], [1]); + test(0..2, 0..1, [1], []); + test(0..2, 1..2, [0], []); + test(1..2, 0..2, [], [0]); + test(0..1, 1..2, [0], [1]); + test(0..1, 2..3, [0], [2]); + test(1..2, 0..1, [1], [0]); + test(2..3, 0..1, [2], [0]); + test(0..3, 1..2, [0, 2], []); + test(1..2, 0..3, [], [0, 2]); + test(0..3, 2..4, [0, 1], [3]); + test(2..4, 0..3, [3], [0, 1]); + } +}