diff --git a/e2e_test/batch/basic/generate_series.slt.part b/e2e_test/batch/basic/generate_series.slt.part index 9d75de1733d42..35efce3a741f1 100644 --- a/e2e_test/batch/basic/generate_series.slt.part +++ b/e2e_test/batch/basic/generate_series.slt.part @@ -127,3 +127,24 @@ SELECT * FROM generate_series(0.1::numeric, 2.1::numeric, 0.5::numeric) 1.1 1.6 2.1 + +statement error start value cannot be infinity +SELECT * FROM generate_series('infinity'::numeric,10::numeric); + +statement error stop value cannot be infinity +SELECT * FROM generate_series(0::numeric,'-infinity'::numeric); + +statement error stop value cannot be NaN +SELECT * FROM generate_series(0::numeric,'nan'::numeric); + +statement error start value cannot be infinity +SELECT * FROM generate_series('infinity'::numeric,10::numeric,0::numeric); + +statement error stop value cannot be infinity +SELECT * FROM generate_series(0::numeric,'-infinity'::numeric,0::numeric); + +statement error step value cannot be NaN +SELECT * FROM generate_series(0::numeric,10::numeric,'nan'::numeric); + +statement error start value cannot be infinity +SELECT * FROM generate_series('-infinity'::numeric,'infinity'::numeric,'nan'::numeric); diff --git a/e2e_test/batch/basic/range.slt.part b/e2e_test/batch/basic/range.slt.part index 9d404e5f5b97b..e8f43a17811e5 100644 --- a/e2e_test/batch/basic/range.slt.part +++ b/e2e_test/batch/basic/range.slt.part @@ -114,6 +114,27 @@ SELECT * FROM range(0.1::numeric, 2.1::numeric, 0.5::numeric) 1.1 1.6 +statement error start value cannot be infinity +SELECT * FROM range('infinity'::numeric,10::numeric); + +statement error stop value cannot be infinity +SELECT * FROM range(0::numeric,'-infinity'::numeric); + +statement error stop value cannot be NaN +SELECT * FROM range(0::numeric,'nan'::numeric); + +statement error start value cannot be infinity +SELECT * FROM range('infinity'::numeric,10::numeric,0::numeric); + +statement error stop value cannot be infinity +SELECT * FROM range(0::numeric,'-infinity'::numeric,0::numeric); + +statement error step value cannot be NaN +SELECT * FROM range(0::numeric,10::numeric,'nan'::numeric); + +statement error start value cannot be infinity +SELECT * FROM range('-infinity'::numeric,'infinity'::numeric,'nan'::numeric); + # test table function with aliases query I SELECT alias from range(1,2) alias; diff --git a/src/expr/impl/src/table_function/generate_series.rs b/src/expr/impl/src/table_function/generate_series.rs index cf222aa66db49..586fa60de02c2 100644 --- a/src/expr/impl/src/table_function/generate_series.rs +++ b/src/expr/impl/src/table_function/generate_series.rs @@ -13,12 +13,11 @@ // limitations under the License. use num_traits::One; -use risingwave_common::types::{CheckedAdd, IsNegative}; +use risingwave_common::types::{CheckedAdd, Decimal, IsNegative}; use risingwave_expr::{function, ExprError, Result}; #[function("generate_series(int4, int4) -> setof int4")] #[function("generate_series(int8, int8) -> setof int8")] -#[function("generate_series(decimal, decimal) -> setof decimal")] fn generate_series(start: T, stop: T) -> Result>> where T: CheckedAdd + PartialOrd + Copy + One + IsNegative, @@ -26,9 +25,19 @@ where range_generic::<_, _, true>(start, stop, T::one()) } +#[function("generate_series(decimal, decimal) -> setof decimal")] +fn generate_series_decimal( + start: Decimal, + stop: Decimal, +) -> Result>> +where +{ + validate_range_parameters(start, stop, Decimal::one())?; + range_generic::(start, stop, Decimal::one()) +} + #[function("generate_series(int4, int4, int4) -> setof int4")] #[function("generate_series(int8, int8, int8) -> setof int8")] -#[function("generate_series(decimal, decimal, decimal) -> setof decimal")] #[function("generate_series(timestamp, timestamp, interval) -> setof timestamp")] fn generate_series_step(start: T, stop: T, step: S) -> Result>> where @@ -38,9 +47,18 @@ where range_generic::<_, _, true>(start, stop, step) } +#[function("generate_series(decimal, decimal, decimal) -> setof decimal")] +fn generate_series_step_decimal( + start: Decimal, + stop: Decimal, + step: Decimal, +) -> Result>> { + validate_range_parameters(start, stop, step)?; + range_generic::<_, _, true>(start, stop, step) +} + #[function("range(int4, int4) -> setof int4")] #[function("range(int8, int8) -> setof int8")] -#[function("range(decimal, decimal) -> setof decimal")] fn range(start: T, stop: T) -> Result>> where T: CheckedAdd + PartialOrd + Copy + One + IsNegative, @@ -48,9 +66,16 @@ where range_generic::<_, _, false>(start, stop, T::one()) } +#[function("range(decimal, decimal) -> setof decimal")] +fn range_decimal(start: Decimal, stop: Decimal) -> Result>> +where +{ + validate_range_parameters(start, stop, Decimal::one())?; + range_generic::(start, stop, Decimal::one()) +} + #[function("range(int4, int4, int4) -> setof int4")] #[function("range(int8, int8, int8) -> setof int8")] -#[function("range(decimal, decimal, decimal) -> setof decimal")] #[function("range(timestamp, timestamp, interval) -> setof timestamp")] fn range_step(start: T, stop: T, step: S) -> Result>> where @@ -60,6 +85,16 @@ where range_generic::<_, _, false>(start, stop, step) } +#[function("range(decimal, decimal, decimal) -> setof decimal")] +fn range_step_decimal( + start: Decimal, + stop: Decimal, + step: Decimal, +) -> Result>> { + validate_range_parameters(start, stop, step)?; + range_generic::<_, _, false>(start, stop, step) +} + #[inline] fn range_generic( start: T, @@ -93,6 +128,29 @@ where Ok(std::iter::from_fn(move || next().transpose())) } +#[inline] +fn validate_range_parameters(start: Decimal, stop: Decimal, step: Decimal) -> Result<()> { + validate_decimal(start, "start")?; + validate_decimal(stop, "stop")?; + validate_decimal(step, "step")?; + Ok(()) +} + +#[inline] +fn validate_decimal(decimal: Decimal, name: &'static str) -> Result<()> { + match decimal { + Decimal::Normalized(_) => Ok(()), + Decimal::PositiveInf | Decimal::NegativeInf => Err(ExprError::InvalidParam { + name, + reason: format!("{} value cannot be infinity", name).into(), + }), + Decimal::NaN => Err(ExprError::InvalidParam { + name, + reason: format!("{} value cannot be NaN", name).into(), + }), + } +} + #[cfg(test)] mod tests { use std::str::FromStr; @@ -100,9 +158,10 @@ mod tests { use futures_util::StreamExt; use risingwave_common::array::DataChunk; use risingwave_common::types::test_utils::IntervalTestExt; - use risingwave_common::types::{DataType, Interval, ScalarImpl, Timestamp}; + use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp}; use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression}; use risingwave_expr::table_function::build; + use risingwave_expr::ExprError; use risingwave_pb::expr::table_function::PbType; const CHUNK_SIZE: usize = 1024; @@ -247,4 +306,61 @@ mod tests { } assert_eq!(actual_cnt, expect_cnt); } + + #[tokio::test] + async fn test_generate_series_decimal() { + let start = Decimal::from_str("1").unwrap(); + let start_inf = Decimal::from_str("infinity").unwrap(); + let stop = Decimal::from_str("5").unwrap(); + let stop_inf = Decimal::from_str("-infinity").unwrap(); + + let step = Decimal::from_str("1").unwrap(); + let step_nan = Decimal::from_str("nan").unwrap(); + let step_inf = Decimal::from_str("infinity").unwrap(); + generate_series_decimal(start, stop, step, true).await; + generate_series_decimal(start_inf, stop, step, false).await; + generate_series_decimal(start_inf, stop_inf, step, false).await; + generate_series_decimal(start, stop_inf, step, false).await; + generate_series_decimal(start, stop, step_nan, false).await; + generate_series_decimal(start, stop, step_inf, false).await; + generate_series_decimal(start, stop_inf, step_nan, false).await; + } + + async fn generate_series_decimal( + start: Decimal, + stop: Decimal, + step: Decimal, + expect_ok: bool, + ) { + fn literal(ty: DataType, v: ScalarImpl) -> BoxedExpression { + LiteralExpression::new(ty, Some(v)).boxed() + } + let function = build( + PbType::GenerateSeries, + DataType::Decimal, + CHUNK_SIZE, + vec![ + literal(DataType::Decimal, start.into()), + literal(DataType::Decimal, stop.into()), + literal(DataType::Decimal, step.into()), + ], + ) + .unwrap(); + + let dummy_chunk = DataChunk::new_dummy(1); + let mut output = function.eval(&dummy_chunk).await; + while let Some(res) = output.next().await { + match res { + Ok(_) => { + assert!(expect_ok); + } + Err(ExprError::InvalidParam { .. }) => { + assert!(!expect_ok); + } + Err(_) => { + unreachable!(); + } + } + } + } }