Skip to content

Commit

Permalink
extract range_except and range_diff out of buffer.rs
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc committed Jan 10, 2024
1 parent 6a4c33b commit d00a704
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 139 deletions.
139 changes: 2 additions & 137 deletions src/expr/core/src/window_function/state/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use std::collections::VecDeque;
use std::ops::Range;

use risingwave_common::array::Op;
use smallvec::{smallvec, SmallVec};

use super::range_utils::range_except;
use crate::window_function::state::range_utils::range_diff;
use crate::window_function::{Frame, FrameBounds, FrameExclusion};

struct Entry<K: Ord, V> {
Expand Down Expand Up @@ -259,149 +260,13 @@ impl<K: Ord, V: Clone> WindowBuffer<K, V> {
}
}

/// Computes the set difference `a - b` of two half-open index ranges.
///
/// The result is a pair of ranges: cutting `b` out of the middle of `a` leaves two pieces, every
/// other case leaves at most one. Unused slots are filled with `0..0`, and a single surviving
/// piece is always placed in the first slot.
fn range_except(a: Range<usize>, b: Range<usize>) -> (Range<usize>, Range<usize>) {
    const NOTHING: Range<usize> = 0..0;

    if a.is_empty() {
        // There is nothing to subtract from.
        return (NOTHING, NOTHING);
    }

    let overlapping = !b.is_empty() && b.start < a.end && a.start < b.end;
    if !overlapping {
        // `b` is empty or lies entirely to one side of `a`:
        // a: [   )                  a:         [   )
        // b:       [   )     or     b: [   )
        return (a, NOTHING);
    }

    // From here on `a` and `b` share at least one index.
    let keeps_head = a.start < b.start;
    let keeps_tail = b.end < a.end;
    match (keeps_head, keeps_tail) {
        // a:   [   )
        // b: [       )
        (false, false) => (NOTHING, NOTHING),
        // a: [       )
        // b:   [   )
        (true, true) => (a.start..b.start, b.end..a.end),
        // a: [     )
        // b:    [     )
        (true, false) => (a.start..b.start, NOTHING),
        // a:    [     )
        // b: [     )
        (false, true) => (b.end..a.end, NOTHING),
    }
}

/// Describes how to turn range `a` into range `b`, as `(removed ranges, added ranges)`.
///
/// "Removed" ranges cover indices that are in `a` but not in `b`; "added" ranges cover indices
/// that are in `b` but not in `a`. Each side holds at most two ranges.
/// Note this is quite different from [`range_except`], which only subtracts one range from another.
#[allow(clippy::type_complexity)] // a pair of small vectors, nothing more
fn range_diff(
    a: Range<usize>,
    b: Range<usize>,
) -> (SmallVec<[Range<usize>; 2]>, SmallVec<[Range<usize>; 2]>) {
    use std::cmp::Ordering::{Equal, Greater, Less};

    // Classify the pair by how the two starts and the two ends compare.
    match (a.start.cmp(&b.start), a.end.cmp(&b.end)) {
        // a: [   )
        // b: [   )
        (Equal, Equal) => (smallvec![], smallvec![]),
        // a: [   )
        // b: [     )
        (Equal, Less) => (smallvec![], smallvec![a.end..b.end]),
        // a: [     )
        // b: [   )
        (Equal, Greater) => (smallvec![b.end..a.end], smallvec![]),
        // a: [     )
        // b:   [   )
        (Less, Equal) => (smallvec![a.start..b.start], smallvec![]),
        // a:   [   )
        // b: [     )
        (Greater, Equal) => (smallvec![], smallvec![b.start..a.start]),
        // No shared boundary and no shared index: drop all of `a`, add all of `b`.
        // a: [   )                  a:         [   )
        // b:       [   )     or     b: [   )
        _ if a.end <= b.start || b.end <= a.start => (smallvec![a], smallvec![b]),
        // a:   [   )
        // b: [       )
        (Greater, Less) => (smallvec![], smallvec![b.start..a.start, a.end..b.end]),
        // a: [       )
        // b:   [   )
        (Less, Greater) => (smallvec![a.start..b.start, b.end..a.end], smallvec![]),
        // a: [     )
        // b:    [     )
        (Less, Less) => (smallvec![a.start..b.start], smallvec![a.end..b.end]),
        // a:    [     )
        // b: [     )
        (Greater, Greater) => (smallvec![b.end..a.end], smallvec![b.start..a.start]),
    }
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use itertools::Itertools;

use super::*;
use crate::window_function::{Frame, FrameBound};

#[test]
fn test_range_diff() {
    /// Expands a collection of ranges into the set of indices they cover.
    fn covered(ranges: impl IntoIterator<Item = Range<usize>>) -> HashSet<usize> {
        ranges.into_iter().flatten().collect()
    }

    /// Diffs `a` against `b` and compares the covered indices with the expectation.
    /// Only the covered indices are compared, not how they are split into ranges.
    fn check(a: Range<usize>, b: Range<usize>, want_removed: &[usize], want_added: &[usize]) {
        let (removed, added) = range_diff(a, b);
        assert_eq!(
            covered(removed),
            want_removed.iter().copied().collect::<HashSet<_>>()
        );
        assert_eq!(
            covered(added),
            want_added.iter().copied().collect::<HashSet<_>>()
        );
    }

    // identical ranges
    check(0..0, 0..0, &[], &[]);
    check(0..1, 0..1, &[], &[]);
    // shared start
    check(0..1, 0..2, &[], &[1]);
    check(0..2, 0..1, &[1], &[]);
    // shared end
    check(0..2, 1..2, &[0], &[]);
    check(1..2, 0..2, &[], &[0]);
    // disjoint (adjacent and separated)
    check(0..1, 1..2, &[0], &[1]);
    check(0..1, 2..3, &[0], &[2]);
    check(1..2, 0..1, &[1], &[0]);
    check(2..3, 0..1, &[2], &[0]);
    // one range strictly inside the other
    check(0..3, 1..2, &[0, 2], &[]);
    check(1..2, 0..3, &[], &[0, 2]);
    // partial overlap
    check(0..3, 2..4, &[0, 1], &[3]);
    check(2..4, 0..3, &[3], &[0, 1]);
}

#[test]
fn test_rows_frame_unbounded_preceding_to_current_row() {
let mut buffer = WindowBuffer::new(
Expand Down
4 changes: 2 additions & 2 deletions src/expr/core/src/window_function/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ use smallvec::SmallVec;
use super::{WindowFuncCall, WindowFuncKind};
use crate::{ExprError, Result};

mod buffer;

mod aggregate;
mod buffer;
mod range_utils;
mod rank;

/// Unique and ordered identifier for a row in internal states.
Expand Down
158 changes: 158 additions & 0 deletions src/expr/core/src/window_function/state/range_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Range;

use smallvec::{smallvec, SmallVec};

/// Subtracts range `b` from range `a` (i.e. `a - b`), returning what is left of `a`.
///
/// Up to two pieces can remain (when `b` sits strictly inside `a`), hence the pair. A slot that
/// is not needed holds `0..0`; when exactly one piece remains it occupies the first slot.
pub(super) fn range_except(a: Range<usize>, b: Range<usize>) -> (Range<usize>, Range<usize>) {
    if a.is_empty() {
        // Subtracting from nothing leaves nothing.
        return (0..0, 0..0);
    }
    if b.is_empty() || b.start >= a.end || b.end <= a.start {
        // `b` removes nothing because it is empty or does not touch `a`:
        // a: [   )                  a:         [   )
        // b:       [   )     or     b: [   )
        return (a, 0..0);
    }

    // `a` and `b` overlap. Work out which ends of `a` stick out past `b`.
    // a: [ head )[ overlap )[ tail )
    let head = (a.start < b.start).then(|| a.start..b.start);
    let tail = (b.end < a.end).then(|| b.end..a.end);

    match (head, tail) {
        // a: [       )
        // b:   [   )
        (Some(head), Some(tail)) => (head, tail),
        // a: [     )                a:    [     )
        // b:    [     )      or     b: [     )
        (Some(rest), None) | (None, Some(rest)) => (rest, 0..0),
        // a:   [   )
        // b: [       )
        (None, None) => (0..0, 0..0),
    }
}

/// Calculate the difference of two ranges A and B, return (removed ranges, added ranges).
/// Note this is quite different from [`range_except`].
#[allow(clippy::type_complexity)] // looks complex but it's not
pub(super) fn range_diff(
a: Range<usize>,
b: Range<usize>,
) -> (SmallVec<[Range<usize>; 2]>, SmallVec<[Range<usize>; 2]>) {
if a.start == b.start {
match a.end.cmp(&b.end) {
std::cmp::Ordering::Equal => {
// a: [ )
// b: [ )
(smallvec![], smallvec![])
}
std::cmp::Ordering::Less => {
// a: [ )
// b: [ )
(smallvec![], smallvec![a.end..b.end])
}
std::cmp::Ordering::Greater => {
// a: [ )
// b: [ )
(smallvec![b.end..a.end], smallvec![])
}
}
} else if a.end == b.end {
debug_assert!(a.start != b.start);
if a.start < b.start {
// a: [ )
// b: [ )
(smallvec![a.start..b.start], smallvec![])
} else {
// a: [ )
// b: [ )
(smallvec![], smallvec![b.start..a.start])
}
} else {
debug_assert!(a.start != b.start && a.end != b.end);
if a.end <= b.start || b.end <= a.start {
// a: [ )
// b: [ [ )
// or
// a: [ )
// b: [ ) )
(smallvec![a], smallvec![b])
} else if b.start < a.start && a.end < b.end {
// a: [ )
// b: [ )
(smallvec![], smallvec![b.start..a.start, a.end..b.end])
} else if a.start < b.start && b.end < a.end {
// a: [ )
// b: [ )
(smallvec![a.start..b.start, b.end..a.end], smallvec![])
} else if a.end < b.end {
// a: [ )
// b: [ )
(smallvec![a.start..b.start], smallvec![a.end..b.end])
} else {
// a: [ )
// b: [ )
(smallvec![b.end..a.end], smallvec![b.start..a.start])
}
}
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Expands a collection of ranges into the set of indices they cover.
    fn covered_indices(ranges: impl IntoIterator<Item = Range<usize>>) -> HashSet<usize> {
        ranges.into_iter().flatten().collect()
    }

    #[test]
    fn test_range_diff() {
        // Each row is `(a, b, indices expected to be removed, indices expected to be added)`.
        // Only the covered indices are compared, not how they are split into ranges.
        let cases: [(Range<usize>, Range<usize>, &[usize], &[usize]); 14] = [
            // identical ranges
            (0..0, 0..0, &[], &[]),
            (0..1, 0..1, &[], &[]),
            // shared start
            (0..1, 0..2, &[], &[1]),
            (0..2, 0..1, &[1], &[]),
            // shared end
            (0..2, 1..2, &[0], &[]),
            (1..2, 0..2, &[], &[0]),
            // disjoint (adjacent and separated)
            (0..1, 1..2, &[0], &[1]),
            (0..1, 2..3, &[0], &[2]),
            (1..2, 0..1, &[1], &[0]),
            (2..3, 0..1, &[2], &[0]),
            // one range strictly inside the other
            (0..3, 1..2, &[0, 2], &[]),
            (1..2, 0..3, &[], &[0, 2]),
            // partial overlap
            (0..3, 2..4, &[0, 1], &[3]),
            (2..4, 0..3, &[3], &[0, 1]),
        ];

        for (a, b, expected_removed, expected_added) in cases {
            let (removed, added) = range_diff(a, b);
            assert_eq!(
                covered_indices(removed),
                expected_removed.iter().copied().collect::<HashSet<_>>()
            );
            assert_eq!(
                covered_indices(added),
                expected_added.iter().copied().collect::<HashSet<_>>()
            );
        }
    }
}

0 comments on commit d00a704

Please sign in to comment.