Skip to content

Commit

Permalink
feat(expr): Implement range table function (risingwavelabs#6604)
Browse files Browse the repository at this point in the history
* add range table_func

* add RANGE in proto

* reuse GenerateSeries for Range

* add comments

* dashboard gen proto

* const generics

* remove range.rs

* minor: remove pub(super)

Co-authored-by: lmatz <[email protected]>
  • Loading branch information
yang-han and lmatz authored Nov 27, 2022
1 parent a551b3b commit 9783eae
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 25 deletions.
6 changes: 6 additions & 0 deletions dashboard/proto/gen/expr.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ message TableFunction {
GENERATE = 1;
UNNEST = 2;
REGEXP_MATCHES = 3;
RANGE = 4;
}
Type function_type = 1;
repeated expr.ExprNode args = 2;
Expand Down
122 changes: 98 additions & 24 deletions src/expr/src/table_function/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use super::*;
use crate::ExprError;

#[derive(Debug)]
pub struct GenerateSeries<T: Array, S: Array> {
pub struct GenerateSeries<T: Array, S: Array, const STOP_INCLUSIVE: bool> {
start: BoxedExpression,
stop: BoxedExpression,
step: BoxedExpression,
chunk_size: usize,
_phantom: std::marker::PhantomData<(T, S)>,
}

impl<T: Array, S: Array> GenerateSeries<T, S>
impl<T: Array, S: Array, const STOP_INCLUSIVE: bool> GenerateSeries<T, S, STOP_INCLUSIVE>
where
T::OwnedItem: for<'a> PartialOrd<T::RefItem<'a>>,
T::OwnedItem: for<'a> CheckedAdd<S::RefItem<'a>, Output = T::OwnedItem>,
Expand Down Expand Up @@ -74,9 +74,15 @@ where
let mut cur: T::OwnedItem = start.to_owned_scalar();

while if step.is_negative() {
cur >= stop
} else {
if STOP_INCLUSIVE {
cur >= stop
} else {
cur > stop
}
} else if STOP_INCLUSIVE {
cur <= stop
} else {
cur < stop
} {
builder.append(Some(cur.as_scalar_ref()));
cur = cur.checked_add(step).ok_or(ExprError::NumericOutOfRange)?;
Expand All @@ -86,7 +92,8 @@ where
}
}

impl<T: Array, S: Array> TableFunction for GenerateSeries<T, S>
impl<T: Array, S: Array, const STOP_INCLUSIVE: bool> TableFunction
for GenerateSeries<T, S, STOP_INCLUSIVE>
where
T::OwnedItem: for<'a> PartialOrd<T::RefItem<'a>>,
T::OwnedItem: for<'a> CheckedAdd<S::RefItem<'a>, Output = T::OwnedItem>,
Expand Down Expand Up @@ -144,7 +151,7 @@ where
}
}

pub fn new_generate_series(
pub fn new_generate_series<const STOP_INCLUSIVE: bool>(
prost: &TableFunctionProst,
chunk_size: usize,
) -> Result<BoxedTableFunction> {
Expand All @@ -153,13 +160,16 @@ pub fn new_generate_series(
let [start, stop, step]: [_; 3] = args.try_into().unwrap();

match return_type {
DataType::Timestamp => Ok(GenerateSeries::<NaiveDateTimeArray, IntervalArray>::new(
DataType::Timestamp => Ok(GenerateSeries::<
NaiveDateTimeArray,
IntervalArray,
STOP_INCLUSIVE,
>::new(start, stop, step, chunk_size)
.boxed()),
DataType::Int32 => Ok(GenerateSeries::<I32Array, I32Array, STOP_INCLUSIVE>::new(
start, stop, step, chunk_size,
)
.boxed()),
DataType::Int32 => {
Ok(GenerateSeries::<I32Array, I32Array>::new(start, stop, step, chunk_size).boxed())
}
_ => Err(ExprError::Internal(anyhow!(
"the return type of Generate Series Function is incorrect".to_string(),
))),
Expand Down Expand Up @@ -189,13 +199,12 @@ mod tests {
LiteralExpression::new(DataType::Int32, Some(v.into())).boxed()
}

let function = GenerateSeries::<I32Array, I32Array> {
start: to_lit_expr(start),
stop: to_lit_expr(stop),
step: to_lit_expr(step),
chunk_size: CHUNK_SIZE,
_phantom: Default::default(),
}
let function = GenerateSeries::<I32Array, I32Array, true>::new(
to_lit_expr(start),
to_lit_expr(stop),
to_lit_expr(step),
CHUNK_SIZE,
)
.boxed();
let expect_cnt = ((stop - start) / step + 1) as usize;

Expand Down Expand Up @@ -229,13 +238,78 @@ mod tests {
LiteralExpression::new(ty, Some(v)).boxed()
}

let function = GenerateSeries::<NaiveDateTimeArray, IntervalArray> {
start: to_lit_expr(DataType::Timestamp, start.into()),
stop: to_lit_expr(DataType::Timestamp, stop.into()),
step: to_lit_expr(DataType::Interval, step.into()),
chunk_size: CHUNK_SIZE,
_phantom: Default::default(),
};
let function = GenerateSeries::<NaiveDateTimeArray, IntervalArray, true>::new(
to_lit_expr(DataType::Timestamp, start.into()),
to_lit_expr(DataType::Timestamp, stop.into()),
to_lit_expr(DataType::Interval, step.into()),
CHUNK_SIZE,
);

let dummy_chunk = DataChunk::new_dummy(1);
let arrays = function.eval(&dummy_chunk).unwrap();

let cnt: usize = arrays.iter().map(|a| a.len()).sum();
assert_eq!(cnt, expect_cnt);
}

#[test]
fn test_i32_range() {
range_test_case(2, 4, 1);
range_test_case(4, 2, -1);
range_test_case(0, 9, 2);
range_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1);
}

fn range_test_case(start: i32, stop: i32, step: i32) {
fn to_lit_expr(v: i32) -> BoxedExpression {
LiteralExpression::new(DataType::Int32, Some(v.into())).boxed()
}

let function = GenerateSeries::<I32Array, I32Array, false>::new(
to_lit_expr(start),
to_lit_expr(stop),
to_lit_expr(step),
CHUNK_SIZE,
)
.boxed();
let expect_cnt = ((stop - start - step.signum()) / step + 1) as usize;

let dummy_chunk = DataChunk::new_dummy(1);
let arrays = function.eval(&dummy_chunk).unwrap();

let cnt: usize = arrays.iter().map(|a| a.len()).sum();
assert_eq!(cnt, expect_cnt);
}

#[test]
fn test_time_range() {
let start_time = str_to_timestamp("2008-03-01 00:00:00").unwrap();
let stop_time = str_to_timestamp("2008-03-09 00:00:00").unwrap();
let one_minute_step = IntervalUnit::from_minutes(1);
let one_hour_step = IntervalUnit::from_minutes(60);
let one_day_step = IntervalUnit::from_days(1);
time_range_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8);
time_range_test_case(start_time, stop_time, one_hour_step, 24 * 8);
time_range_test_case(start_time, stop_time, one_day_step, 8);
time_range_test_case(stop_time, start_time, -one_day_step, 8);
}

fn time_range_test_case(
start: NaiveDateTimeWrapper,
stop: NaiveDateTimeWrapper,
step: IntervalUnit,
expect_cnt: usize,
) {
fn to_lit_expr(ty: DataType, v: ScalarImpl) -> BoxedExpression {
LiteralExpression::new(ty, Some(v)).boxed()
}

let function = GenerateSeries::<NaiveDateTimeArray, IntervalArray, false>::new(
to_lit_expr(DataType::Timestamp, start.into()),
to_lit_expr(DataType::Timestamp, stop.into()),
to_lit_expr(DataType::Interval, step.into()),
CHUNK_SIZE,
);

let dummy_chunk = DataChunk::new_dummy(1);
let arrays = function.eval(&dummy_chunk).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ pub fn build_from_prost(
use risingwave_pb::expr::table_function::Type::*;

match prost.get_function_type().unwrap() {
Generate => new_generate_series(prost, chunk_size),
Generate => new_generate_series::<true>(prost, chunk_size),
Unnest => new_unnest(prost, chunk_size),
RegexpMatches => new_regexp_matches(prost, chunk_size),
Range => new_generate_series::<false>(prost, chunk_size),
Unspecified => unreachable!(),
}
}
Expand Down

0 comments on commit 9783eae

Please sign in to comment.