Skip to content

Commit

Permalink
fix(expr): reject inf/-inf/nan decimal in generate_series and range (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
KeXiangWang authored Oct 12, 2023
1 parent 2607ae7 commit f9e3d99
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 6 deletions.
21 changes: 21 additions & 0 deletions e2e_test/batch/basic/generate_series.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,24 @@ SELECT * FROM generate_series(0.1::numeric, 2.1::numeric, 0.5::numeric)
1.1
1.6
2.1

statement error start value cannot be infinity
SELECT * FROM generate_series('infinity'::numeric,10::numeric);

statement error stop value cannot be infinity
SELECT * FROM generate_series(0::numeric,'-infinity'::numeric);

statement error stop value cannot be NaN
SELECT * FROM generate_series(0::numeric,'nan'::numeric);

statement error start value cannot be infinity
SELECT * FROM generate_series('infinity'::numeric,10::numeric,0::numeric);

statement error stop value cannot be infinity
SELECT * FROM generate_series(0::numeric,'-infinity'::numeric,0::numeric);

statement error step value cannot be NaN
SELECT * FROM generate_series(0::numeric,10::numeric,'nan'::numeric);

statement error start value cannot be infinity
SELECT * FROM generate_series('-infinity'::numeric,'infinity'::numeric,'nan'::numeric);
21 changes: 21 additions & 0 deletions e2e_test/batch/basic/range.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ SELECT * FROM range(0.1::numeric, 2.1::numeric, 0.5::numeric)
1.1
1.6

statement error start value cannot be infinity
SELECT * FROM range('infinity'::numeric,10::numeric);

statement error stop value cannot be infinity
SELECT * FROM range(0::numeric,'-infinity'::numeric);

statement error stop value cannot be NaN
SELECT * FROM range(0::numeric,'nan'::numeric);

statement error start value cannot be infinity
SELECT * FROM range('infinity'::numeric,10::numeric,0::numeric);

statement error stop value cannot be infinity
SELECT * FROM range(0::numeric,'-infinity'::numeric,0::numeric);

statement error step value cannot be NaN
SELECT * FROM range(0::numeric,10::numeric,'nan'::numeric);

statement error start value cannot be infinity
SELECT * FROM range('-infinity'::numeric,'infinity'::numeric,'nan'::numeric);

# test table function with aliases
query I
SELECT alias from range(1,2) alias;
Expand Down
128 changes: 122 additions & 6 deletions src/expr/impl/src/table_function/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,31 @@
// limitations under the License.

use num_traits::One;
use risingwave_common::types::{CheckedAdd, IsNegative};
use risingwave_common::types::{CheckedAdd, Decimal, IsNegative};
use risingwave_expr::{function, ExprError, Result};

#[function("generate_series(int4, int4) -> setof int4")]
#[function("generate_series(int8, int8) -> setof int8")]
#[function("generate_series(decimal, decimal) -> setof decimal")]
fn generate_series<T>(start: T, stop: T) -> Result<impl Iterator<Item = Result<T>>>
where
T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
{
range_generic::<_, _, true>(start, stop, T::one())
}

#[function("generate_series(decimal, decimal) -> setof decimal")]
fn generate_series_decimal(
start: Decimal,
stop: Decimal,
) -> Result<impl Iterator<Item = Result<Decimal>>>
where
{
validate_range_parameters(start, stop, Decimal::one())?;
range_generic::<Decimal, Decimal, true>(start, stop, Decimal::one())
}

#[function("generate_series(int4, int4, int4) -> setof int4")]
#[function("generate_series(int8, int8, int8) -> setof int8")]
#[function("generate_series(decimal, decimal, decimal) -> setof decimal")]
#[function("generate_series(timestamp, timestamp, interval) -> setof timestamp")]
fn generate_series_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = Result<T>>>
where
Expand All @@ -38,19 +47,35 @@ where
range_generic::<_, _, true>(start, stop, step)
}

#[function("generate_series(decimal, decimal, decimal) -> setof decimal")]
fn generate_series_step_decimal(
start: Decimal,
stop: Decimal,
step: Decimal,
) -> Result<impl Iterator<Item = Result<Decimal>>> {
validate_range_parameters(start, stop, step)?;
range_generic::<_, _, true>(start, stop, step)
}

#[function("range(int4, int4) -> setof int4")]
#[function("range(int8, int8) -> setof int8")]
#[function("range(decimal, decimal) -> setof decimal")]
fn range<T>(start: T, stop: T) -> Result<impl Iterator<Item = Result<T>>>
where
T: CheckedAdd<Output = T> + PartialOrd + Copy + One + IsNegative,
{
range_generic::<_, _, false>(start, stop, T::one())
}

#[function("range(decimal, decimal) -> setof decimal")]
fn range_decimal(start: Decimal, stop: Decimal) -> Result<impl Iterator<Item = Result<Decimal>>>
where
{
validate_range_parameters(start, stop, Decimal::one())?;
range_generic::<Decimal, Decimal, false>(start, stop, Decimal::one())
}

#[function("range(int4, int4, int4) -> setof int4")]
#[function("range(int8, int8, int8) -> setof int8")]
#[function("range(decimal, decimal, decimal) -> setof decimal")]
#[function("range(timestamp, timestamp, interval) -> setof timestamp")]
fn range_step<T, S>(start: T, stop: T, step: S) -> Result<impl Iterator<Item = Result<T>>>
where
Expand All @@ -60,6 +85,16 @@ where
range_generic::<_, _, false>(start, stop, step)
}

#[function("range(decimal, decimal, decimal) -> setof decimal")]
fn range_step_decimal(
start: Decimal,
stop: Decimal,
step: Decimal,
) -> Result<impl Iterator<Item = Result<Decimal>>> {
validate_range_parameters(start, stop, step)?;
range_generic::<_, _, false>(start, stop, step)
}

#[inline]
fn range_generic<T, S, const INCLUSIVE: bool>(
start: T,
Expand Down Expand Up @@ -93,16 +128,40 @@ where
Ok(std::iter::from_fn(move || next().transpose()))
}

#[inline]
fn validate_range_parameters(start: Decimal, stop: Decimal, step: Decimal) -> Result<()> {
validate_decimal(start, "start")?;
validate_decimal(stop, "stop")?;
validate_decimal(step, "step")?;
Ok(())
}

#[inline]
fn validate_decimal(decimal: Decimal, name: &'static str) -> Result<()> {
match decimal {
Decimal::Normalized(_) => Ok(()),
Decimal::PositiveInf | Decimal::NegativeInf => Err(ExprError::InvalidParam {
name,
reason: format!("{} value cannot be infinity", name).into(),
}),
Decimal::NaN => Err(ExprError::InvalidParam {
name,
reason: format!("{} value cannot be NaN", name).into(),
}),
}
}

#[cfg(test)]
mod tests {
use std::str::FromStr;

use futures_util::StreamExt;
use risingwave_common::array::DataChunk;
use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::{DataType, Interval, ScalarImpl, Timestamp};
use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp};
use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression};
use risingwave_expr::table_function::build;
use risingwave_expr::ExprError;
use risingwave_pb::expr::table_function::PbType;

const CHUNK_SIZE: usize = 1024;
Expand Down Expand Up @@ -247,4 +306,61 @@ mod tests {
}
assert_eq!(actual_cnt, expect_cnt);
}

#[tokio::test]
async fn test_generate_series_decimal() {
let start = Decimal::from_str("1").unwrap();
let start_inf = Decimal::from_str("infinity").unwrap();
let stop = Decimal::from_str("5").unwrap();
let stop_inf = Decimal::from_str("-infinity").unwrap();

let step = Decimal::from_str("1").unwrap();
let step_nan = Decimal::from_str("nan").unwrap();
let step_inf = Decimal::from_str("infinity").unwrap();
generate_series_decimal(start, stop, step, true).await;
generate_series_decimal(start_inf, stop, step, false).await;
generate_series_decimal(start_inf, stop_inf, step, false).await;
generate_series_decimal(start, stop_inf, step, false).await;
generate_series_decimal(start, stop, step_nan, false).await;
generate_series_decimal(start, stop, step_inf, false).await;
generate_series_decimal(start, stop_inf, step_nan, false).await;
}

async fn generate_series_decimal(
start: Decimal,
stop: Decimal,
step: Decimal,
expect_ok: bool,
) {
fn literal(ty: DataType, v: ScalarImpl) -> BoxedExpression {
LiteralExpression::new(ty, Some(v)).boxed()
}
let function = build(
PbType::GenerateSeries,
DataType::Decimal,
CHUNK_SIZE,
vec![
literal(DataType::Decimal, start.into()),
literal(DataType::Decimal, stop.into()),
literal(DataType::Decimal, step.into()),
],
)
.unwrap();

let dummy_chunk = DataChunk::new_dummy(1);
let mut output = function.eval(&dummy_chunk).await;
while let Some(res) = output.next().await {
match res {
Ok(_) => {
assert!(expect_ok);
}
Err(ExprError::InvalidParam { .. }) => {
assert!(!expect_ok);
}
Err(_) => {
unreachable!();
}
}
}
}
}

0 comments on commit f9e3d99

Please sign in to comment.