diff --git a/dashboard/proto/gen/expr.ts b/dashboard/proto/gen/expr.ts index ff7ab1be8e391..033a82b40a999 100644 --- a/dashboard/proto/gen/expr.ts +++ b/dashboard/proto/gen/expr.ts @@ -546,6 +546,7 @@ export const TableFunction_Type = { GENERATE: "GENERATE", UNNEST: "UNNEST", REGEXP_MATCHES: "REGEXP_MATCHES", + RANGE: "RANGE", UNRECOGNIZED: "UNRECOGNIZED", } as const; @@ -565,6 +566,9 @@ export function tableFunction_TypeFromJSON(object: any): TableFunction_Type { case 3: case "REGEXP_MATCHES": return TableFunction_Type.REGEXP_MATCHES; + case 4: + case "RANGE": + return TableFunction_Type.RANGE; case -1: case "UNRECOGNIZED": default: @@ -582,6 +586,8 @@ export function tableFunction_TypeToJSON(object: TableFunction_Type): string { return "UNNEST"; case TableFunction_Type.REGEXP_MATCHES: return "REGEXP_MATCHES"; + case TableFunction_Type.RANGE: + return "RANGE"; case TableFunction_Type.UNRECOGNIZED: default: return "UNRECOGNIZED"; diff --git a/proto/expr.proto b/proto/expr.proto index 45855619867e6..6927a28a37e83 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -123,6 +123,7 @@ message TableFunction { GENERATE = 1; UNNEST = 2; REGEXP_MATCHES = 3; + RANGE = 4; } Type function_type = 1; repeated expr.ExprNode args = 2; diff --git a/src/expr/src/table_function/generate_series.rs b/src/expr/src/table_function/generate_series.rs index f2f562d03f5bd..2731cff0a7b95 100644 --- a/src/expr/src/table_function/generate_series.rs +++ b/src/expr/src/table_function/generate_series.rs @@ -27,7 +27,7 @@ use super::*; use crate::ExprError; #[derive(Debug)] -pub struct GenerateSeries { +pub struct GenerateSeries { start: BoxedExpression, stop: BoxedExpression, step: BoxedExpression, @@ -35,7 +35,7 @@ pub struct GenerateSeries { _phantom: std::marker::PhantomData<(T, S)>, } -impl GenerateSeries +impl GenerateSeries where T::OwnedItem: for<'a> PartialOrd>, T::OwnedItem: for<'a> CheckedAdd, Output = T::OwnedItem>, @@ -74,9 +74,15 @@ where let mut cur: T::OwnedItem = start.to_owned_scalar(); while if step.is_negative() { - cur >= stop - } else { + if STOP_INCLUSIVE { + cur >= stop + } else { + cur > stop + } + } else if STOP_INCLUSIVE { cur <= stop + } else { + cur < stop } { builder.append(Some(cur.as_scalar_ref())); cur = cur.checked_add(step).ok_or(ExprError::NumericOutOfRange)?; @@ -86,7 +92,8 @@ where } } -impl TableFunction for GenerateSeries +impl TableFunction + for GenerateSeries where T::OwnedItem: for<'a> PartialOrd>, T::OwnedItem: for<'a> CheckedAdd, Output = T::OwnedItem>, @@ -144,7 +151,7 @@ where } } -pub fn new_generate_series( +pub fn new_generate_series( prost: &TableFunctionProst, chunk_size: usize, ) -> Result { @@ -153,13 +160,16 @@ pub fn new_generate_series( let [start, stop, step]: [_; 3] = args.try_into().unwrap(); match return_type { - DataType::Timestamp => Ok(GenerateSeries::::new( + DataType::Timestamp => Ok(GenerateSeries::< + NaiveDateTimeArray, + IntervalArray, + STOP_INCLUSIVE, + >::new(start, stop, step, chunk_size) + .boxed()), + DataType::Int32 => Ok(GenerateSeries::::new( start, stop, step, chunk_size, ) .boxed()), - DataType::Int32 => { - Ok(GenerateSeries::::new(start, stop, step, chunk_size).boxed()) - } _ => Err(ExprError::Internal(anyhow!( "the return type of Generate Series Function is incorrect".to_string(), ))), @@ -189,13 +199,12 @@ mod tests { LiteralExpression::new(DataType::Int32, Some(v.into())).boxed() } - let function = GenerateSeries:: { - start: to_lit_expr(start), - stop: to_lit_expr(stop), - step: to_lit_expr(step), - chunk_size: CHUNK_SIZE, - _phantom: Default::default(), - } + let function = GenerateSeries::::new( + to_lit_expr(start), + to_lit_expr(stop), + to_lit_expr(step), + CHUNK_SIZE, + ) .boxed(); let expect_cnt = ((stop - start) / step + 1) as usize; @@ -229,13 +238,78 @@ mod tests { LiteralExpression::new(ty, Some(v)).boxed() } - let function = GenerateSeries:: { - start: to_lit_expr(DataType::Timestamp, start.into()), - stop: to_lit_expr(DataType::Timestamp, stop.into()), - step: to_lit_expr(DataType::Interval, step.into()), - chunk_size: CHUNK_SIZE, - _phantom: Default::default(), - }; + let function = GenerateSeries::::new( + to_lit_expr(DataType::Timestamp, start.into()), + to_lit_expr(DataType::Timestamp, stop.into()), + to_lit_expr(DataType::Interval, step.into()), + CHUNK_SIZE, + ); + + let dummy_chunk = DataChunk::new_dummy(1); + let arrays = function.eval(&dummy_chunk).unwrap(); + + let cnt: usize = arrays.iter().map(|a| a.len()).sum(); + assert_eq!(cnt, expect_cnt); + } + + #[test] + fn test_i32_range() { + range_test_case(2, 4, 1); + range_test_case(4, 2, -1); + range_test_case(0, 9, 2); + range_test_case(0, (CHUNK_SIZE * 2 + 3) as i32, 1); + } + + fn range_test_case(start: i32, stop: i32, step: i32) { + fn to_lit_expr(v: i32) -> BoxedExpression { + LiteralExpression::new(DataType::Int32, Some(v.into())).boxed() + } + + let function = GenerateSeries::::new( + to_lit_expr(start), + to_lit_expr(stop), + to_lit_expr(step), + CHUNK_SIZE, + ) + .boxed(); + let expect_cnt = ((stop - start - step.signum()) / step + 1) as usize; + + let dummy_chunk = DataChunk::new_dummy(1); + let arrays = function.eval(&dummy_chunk).unwrap(); + + let cnt: usize = arrays.iter().map(|a| a.len()).sum(); + assert_eq!(cnt, expect_cnt); + } + + #[test] + fn test_time_range() { + let start_time = str_to_timestamp("2008-03-01 00:00:00").unwrap(); + let stop_time = str_to_timestamp("2008-03-09 00:00:00").unwrap(); + let one_minute_step = IntervalUnit::from_minutes(1); + let one_hour_step = IntervalUnit::from_minutes(60); + let one_day_step = IntervalUnit::from_days(1); + time_range_test_case(start_time, stop_time, one_minute_step, 60 * 24 * 8); + time_range_test_case(start_time, stop_time, one_hour_step, 24 * 8); + time_range_test_case(start_time, stop_time, one_day_step, 8); + time_range_test_case(stop_time, start_time, -one_day_step, 8); + } + + fn time_range_test_case( + start: NaiveDateTimeWrapper, + stop: NaiveDateTimeWrapper, + step: IntervalUnit, + expect_cnt: usize, + ) { + fn to_lit_expr(ty: DataType, v: ScalarImpl) -> BoxedExpression { + LiteralExpression::new(ty, Some(v)).boxed() + } + + let function = GenerateSeries::::new( + to_lit_expr(DataType::Timestamp, start.into()), + to_lit_expr(DataType::Timestamp, stop.into()), + to_lit_expr(DataType::Interval, step.into()), + CHUNK_SIZE, + ); let dummy_chunk = DataChunk::new_dummy(1); let arrays = function.eval(&dummy_chunk).unwrap(); diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index 79cdb21af85e2..ae86082bf0b55 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -59,9 +59,10 @@ pub fn build_from_prost( use risingwave_pb::expr::table_function::Type::*; match prost.get_function_type().unwrap() { - Generate => new_generate_series(prost, chunk_size), + Generate => new_generate_series::(prost, chunk_size), Unnest => new_unnest(prost, chunk_size), RegexpMatches => new_regexp_matches(prost, chunk_size), + Range => new_generate_series::(prost, chunk_size), Unspecified => unreachable!(), } }