From 1c2e5aaba56ccd3cf99c7228573cad22b5fe9809 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:06:11 +0800 Subject: [PATCH] feat(expr): implement greatest and least function (#12838) --- .../batch/functions/greatest_least.slt.part | 73 ++++ proto/expr.proto | 2 + src/expr/impl/src/scalar/cmp.rs | 340 +++++++++++------- src/expr/impl/tests/sig.rs | 2 + src/frontend/src/binder/expr/function.rs | 3 + src/frontend/src/expr/pure.rs | 4 +- src/frontend/src/expr/type_inference/func.rs | 4 + 7 files changed, 303 insertions(+), 125 deletions(-) create mode 100644 e2e_test/batch/functions/greatest_least.slt.part diff --git a/e2e_test/batch/functions/greatest_least.slt.part b/e2e_test/batch/functions/greatest_least.slt.part new file mode 100644 index 0000000000000..43a7996fe7df8 --- /dev/null +++ b/e2e_test/batch/functions/greatest_least.slt.part @@ -0,0 +1,73 @@ +statement ok +create table t(id int, v1 int2, v2 int4, v3 int8); + +statement ok +insert into t values (1, 1, 2, 3), (2, 2, NULL, 5), (3, NULL, NULL, 8), (4, NULL, NULL, NULL); + +statement ok +flush; + +statement error +select greatest(v1, '123'); + +statement error +select greatest(); + +statement error +select least(); + +query I +select greatest(1, 2, 3); +---- +3 + +query I +select greatest(2); +---- +2 + +query I +select least(1, 2, 3); +---- +1 + +query I +select least(2); +---- +2 + +query I +select greatest(v1, v2, v3) from t order by id; +---- +3 +5 +8 +NULL + +query I +select least(v1, v2, v3) from t order by id; +---- +1 +2 +8 +NULL + +query I +select greatest(7, v3) from t order by id; +---- +7 +7 +8 +7 + +query I +select least(NULL, v1, 2) from t order by id; +---- +1 +2 +2 +2 + + +statement ok +drop table t; \ No newline at end of file diff --git a/proto/expr.proto b/proto/expr.proto index 2086554e78975..769532d8dbe19 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -33,6 +33,8 @@ message ExprNode { LESS_THAN_OR_EQUAL = 11; GREATER_THAN = 12; GREATER_THAN_OR_EQUAL = 13; + GREATEST = 14; + LEAST = 15; // logical operators AND = 21; OR = 22; diff --git a/src/expr/impl/src/scalar/cmp.rs b/src/expr/impl/src/scalar/cmp.rs index b5a38af0d44a6..beccbc9e6766d 100644 --- a/src/expr/impl/src/scalar/cmp.rs +++ b/src/expr/impl/src/scalar/cmp.rs @@ -16,6 +16,8 @@ use std::fmt::Debug; use risingwave_common::array::{Array, BoolArray}; use risingwave_common::buffer::Bitmap; +use risingwave_common::row::Row; +use risingwave_common::types::{Scalar, ScalarRef, ScalarRefImpl}; use risingwave_expr::function; #[function("equal(boolean, boolean) -> boolean", batch_fn = "boolarray_eq")] @@ -287,6 +289,66 @@ fn is_not_null(v: Option) -> bool { v.is_some() } +#[function("greatest(...) -> boolean")] +#[function("greatest(...) -> *int")] +#[function("greatest(...) -> decimal")] +#[function("greatest(...) -> *float")] +#[function("greatest(...) -> serial")] +#[function("greatest(...) -> int256")] +#[function("greatest(...) -> date")] +#[function("greatest(...) -> time")] +#[function("greatest(...) -> interval")] +#[function("greatest(...) -> timestamp")] +#[function("greatest(...) -> timestamptz")] +#[function("greatest(...) -> varchar")] +#[function("greatest(...) -> bytea")] +pub fn general_variadic_greatest(row: impl Row) -> Option +where + T: Scalar, + for<'a> ::ScalarRefType<'a>: TryFrom> + Ord + Debug, +{ + row.iter() + .flatten() + .map( + |scalar| match <::ScalarRefType<'_>>::try_from(scalar) { + Ok(v) => v, + Err(_) => unreachable!("all input type should have been aligned in the frontend"), + }, + ) + .max() + .map(|v| v.to_owned_scalar()) +} + +#[function("least(...) -> boolean")] +#[function("least(...) -> *int")] +#[function("least(...) -> decimal")] +#[function("least(...) -> *float")] +#[function("least(...) -> serial")] +#[function("least(...) -> int256")] +#[function("least(...) -> date")] +#[function("least(...) -> time")] +#[function("least(...) -> interval")] +#[function("least(...) -> timestamp")] +#[function("least(...) -> timestamptz")] +#[function("least(...) -> varchar")] +#[function("least(...) -> bytea")] +pub fn general_variadic_least(row: impl Row) -> Option +where + T: Scalar, + for<'a> ::ScalarRefType<'a>: TryFrom> + Ord + Debug, +{ + row.iter() + .flatten() + .map( + |scalar| match <::ScalarRefType<'_>>::try_from(scalar) { + Ok(v) => v, + Err(_) => unreachable!("all input type should have been aligned in the frontend"), + }, + ) + .min() + .map(|v| v.to_owned_scalar()) +} + // optimized functions for bool arrays fn boolarray_eq(l: &BoolArray, r: &BoolArray) -> BoolArray { @@ -365,7 +427,7 @@ fn batch_is_not_null(a: &impl Array) -> BoolArray { mod tests { use std::str::FromStr; - use risingwave_common::types::{Decimal, F32, F64}; + use risingwave_common::types::{Decimal, Timestamp, F32, F64}; use risingwave_expr::expr::build_from_pretty; use super::*; @@ -472,6 +534,26 @@ mod tests { test_binary_i32::(|x, y| x >= y, Type::GreaterThanOrEqual).await; test_binary_i32::(|x, y| x < y, Type::LessThan).await; test_binary_i32::(|x, y| x <= y, Type::LessThanOrEqual).await; + test_binary_inner::( + reduce(std::cmp::max::), + Type::Greatest, + ) + .await; + test_binary_inner::( + reduce(std::cmp::min::), + Type::Least, + ) + .await; + test_binary_inner::( + reduce(std::cmp::max::), + Type::Greatest, + ) + .await; + test_binary_inner::( + reduce(std::cmp::min::), + Type::Least, + ) + .await; test_binary_decimal::(|x, y| x + y, Type::Add).await; test_binary_decimal::(|x, y| x - y, Type::Subtract).await; test_binary_decimal::(|x, y| x * y, Type::Multiply).await; @@ -494,46 +576,132 @@ mod tests { .await; } - async fn test_binary_i32(f: F, kind: Type) + trait TestFrom: Copy { + const NAME: &'static str; + fn test_from(i: usize) -> Self; + } + + impl TestFrom for i32 { + const NAME: &'static str = "int4"; + + fn test_from(i: usize) -> Self { + i as i32 + } + } + + impl TestFrom for Decimal { + const NAME: &'static str = "decimal"; + + fn test_from(i: usize) -> Self { + i.into() + } + } + + impl TestFrom for bool { + const NAME: &'static str = "boolean"; + + fn test_from(i: usize) -> Self { + i % 2 == 0 + } + } + + impl TestFrom for Timestamp { + const NAME: &'static str = "timestamp"; + + fn test_from(_: usize) -> Self { + unimplemented!("not implemented as input yet") + } + } + + impl TestFrom for Interval { + const NAME: &'static str = "interval"; + + fn test_from(i: usize) -> Self { + Interval::from_ymd(0, i as _, i as _) + } + } + + impl TestFrom for Date { + const NAME: &'static str = "date"; + + fn test_from(i: usize) -> Self { + Date::from_num_days_from_ce_uncheck(i as i32) + } + } + + #[expect(clippy::type_complexity)] + fn gen_test_data( + count: usize, + f: impl Fn(Option, Option) -> Option, + ) -> (Vec>, Vec>, Vec>) { + let mut lhs = Vec::>::new(); + let mut rhs = Vec::>::new(); + let mut target = Vec::>::new(); + for i in 0..count { + let (l, r) = if i % 2 == 0 { + (Some(i), None) + } else if i % 3 == 0 { + (Some(i), Some(i + 1)) + } else if i % 5 == 0 { + (Some(i + 1), Some(i)) + } else if i % 7 == 0 { + (None, Some(i)) + } else { + (Some(i), Some(i)) + }; + let l = l.map(TestFrom::test_from); + let r = r.map(TestFrom::test_from); + lhs.push(l); + rhs.push(r); + target.push(f(l, r)); + } + (lhs, rhs, target) + } + + fn arithmetic(f: impl Fn(L, R) -> O) -> impl Fn(Option, Option) -> Option { + move |l, r| match (l, r) { + (Some(l), Some(r)) => Some(f(l, r)), + _ => None, + } + } + + fn reduce(f: impl Fn(I, I) -> I) -> impl Fn(Option, Option) -> Option { + move |l, r| match (l, r) { + (Some(l), Some(r)) => Some(f(l, r)), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + (None, None) => None, + } + } + + async fn test_binary_inner(f: F, kind: Type) where + L: Array, + L: for<'a> FromIterator<&'a Option<::OwnedItem>>, + ::OwnedItem: TestFrom, + R: Array, + R: for<'a> FromIterator<&'a Option<::OwnedItem>>, + ::OwnedItem: TestFrom, A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, for<'a> ::RefItem<'a>: PartialEq, - F: Fn(i32, i32) -> ::OwnedItem, + ::OwnedItem: TestFrom, + F: Fn( + Option<::OwnedItem>, + Option<::OwnedItem>, + ) -> Option<::OwnedItem>, { - let mut lhs = Vec::>::new(); - let mut rhs = Vec::>::new(); - let mut target = Vec::::OwnedItem>>::new(); - for i in 0..100 { - if i % 2 == 0 { - lhs.push(Some(i)); - rhs.push(None); - target.push(None); - } else if i % 3 == 0 { - lhs.push(Some(i)); - rhs.push(Some(i + 1)); - target.push(Some(f(i, i + 1))); - } else if i % 5 == 0 { - lhs.push(Some(i + 1)); - rhs.push(Some(i)); - target.push(Some(f(i + 1, i))); - } else { - lhs.push(Some(i)); - rhs.push(Some(i)); - target.push(Some(f(i, i))); - } - } + let (lhs, rhs, target) = gen_test_data(100, f); - let col1 = I32Array::from_iter(&lhs).into_ref(); - let col2 = I32Array::from_iter(&rhs).into_ref(); + let col1 = L::from_iter(&lhs).into_ref(); + let col2 = R::from_iter(&rhs).into_ref(); let data_chunk = DataChunk::new(vec![col1, col2], 100); - let ty = match kind { - Type::Add | Type::Subtract | Type::Multiply | Type::Divide => "int4", - _ => "boolean", - }; + let l_name = <::OwnedItem as TestFrom>::NAME; + let r_name = <::OwnedItem as TestFrom>::NAME; + let output_name = <::OwnedItem as TestFrom>::NAME; let expr = build_from_pretty(format!( - "({name}:{ty} $0:int4 $1:int4)", - name = kind.as_str_name() + "({name}:{output_name} $0:{l_name} $1:{r_name})", + name = kind.as_str_name(), )); let res = expr.eval(&data_chunk).await.unwrap(); let arr: &A = res.as_ref().into(); @@ -553,54 +721,26 @@ mod tests { } } + async fn test_binary_i32(f: F, kind: Type) + where + A: Array, + for<'a> &'a A: std::convert::From<&'a ArrayImpl>, + for<'a> ::RefItem<'a>: PartialEq, + ::OwnedItem: TestFrom, + F: Fn(i32, i32) -> ::OwnedItem, + { + test_binary_inner::(arithmetic(f), kind).await + } + async fn test_binary_interval(f: F, kind: Type) where A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, for<'a> ::RefItem<'a>: PartialEq, + ::OwnedItem: TestFrom, F: Fn(Date, Interval) -> ::OwnedItem, { - let mut lhs = Vec::>::new(); - let mut rhs = Vec::>::new(); - let mut target = Vec::::OwnedItem>>::new(); - for i in 0..100 { - if i % 2 == 0 { - rhs.push(Some(Interval::from_ymd(0, i, i))); - lhs.push(None); - target.push(None); - } else { - rhs.push(Some(Interval::from_ymd(0, i, i))); - lhs.push(Some(Date::from_num_days_from_ce_uncheck(i))); - target.push(Some(f( - Date::from_num_days_from_ce_uncheck(i), - Interval::from_ymd(0, i, i), - ))); - } - } - - let col1 = DateArray::from_iter(&lhs).into_ref(); - let col2 = IntervalArray::from_iter(&rhs).into_ref(); - let data_chunk = DataChunk::new(vec![col1, col2], 100); - let expr = build_from_pretty(format!( - "({name}:timestamp $0:date $1:interval)", - name = kind.as_str_name() - )); - let res = expr.eval(&data_chunk).await.unwrap(); - let arr: &A = res.as_ref().into(); - for (idx, item) in arr.iter().enumerate() { - let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); - assert_eq!(x, item); - } - - for i in 0..lhs.len() { - let row = OwnedRow::new(vec![ - lhs[i].map(|date| date.to_scalar_value()), - rhs[i].map(|date| date.to_scalar_value()), - ]); - let result = expr.eval_row(&row).await.unwrap(); - let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); - assert_eq!(result, expected); - } + test_binary_inner::(arithmetic(f), kind).await } async fn test_binary_decimal(f: F, kind: Type) @@ -608,57 +748,9 @@ mod tests { A: Array, for<'a> &'a A: std::convert::From<&'a ArrayImpl>, for<'a> ::RefItem<'a>: PartialEq, + ::OwnedItem: TestFrom, F: Fn(Decimal, Decimal) -> ::OwnedItem, { - let mut lhs = Vec::>::new(); - let mut rhs = Vec::>::new(); - let mut target = Vec::::OwnedItem>>::new(); - for i in 0..100 { - if i % 2 == 0 { - lhs.push(Some(i.into())); - rhs.push(None); - target.push(None); - } else if i % 3 == 0 { - lhs.push(Some(i.into())); - rhs.push(Some((i + 1).into())); - target.push(Some(f((i).into(), (i + 1).into()))); - } else if i % 5 == 0 { - lhs.push(Some((i + 1).into())); - rhs.push(Some((i).into())); - target.push(Some(f((i + 1).into(), (i).into()))); - } else { - lhs.push(Some((i).into())); - rhs.push(Some((i).into())); - target.push(Some(f((i).into(), (i).into()))); - } - } - - let col1 = DecimalArray::from_iter(&lhs).into_ref(); - let col2 = DecimalArray::from_iter(&rhs).into_ref(); - let data_chunk = DataChunk::new(vec![col1, col2], 100); - let ty = match kind { - Type::Add | Type::Subtract | Type::Multiply | Type::Divide => "decimal", - _ => "boolean", - }; - let expr = build_from_pretty(format!( - "({name}:{ty} $0:decimal $1:decimal)", - name = kind.as_str_name() - )); - let res = expr.eval(&data_chunk).await.unwrap(); - let arr: &A = res.as_ref().into(); - for (idx, item) in arr.iter().enumerate() { - let x = target[idx].as_ref().map(|x| x.as_scalar_ref()); - assert_eq!(x, item); - } - - for i in 0..lhs.len() { - let row = OwnedRow::new(vec![ - lhs[i].map(|dec| dec.to_scalar_value()), - rhs[i].map(|dec| dec.to_scalar_value()), - ]); - let result = expr.eval_row(&row).await.unwrap(); - let expected = target[i].as_ref().cloned().map(|x| x.to_scalar_value()); - assert_eq!(result, expected); - } + test_binary_inner::(arithmetic(f), kind).await } } diff --git a/src/expr/impl/tests/sig.rs b/src/expr/impl/tests/sig.rs index 1a227e9472042..95798d8284929 100644 --- a/src/expr/impl/tests/sig.rs +++ b/src/expr/impl/tests/sig.rs @@ -74,6 +74,8 @@ fn test_func_sig_map() { "cast(smallint) -> rw_int256/numeric/double precision/real/bigint/integer/character varying", "cast(time without time zone) -> interval/character varying", "cast(timestamp without time zone) -> time without time zone/date/character varying", + "greatest() -> bytea/character varying/timestamp with time zone/timestamp without time zone/interval/time without time zone/date/rw_int256/serial/real/double precision/numeric/smallint/integer/bigint/boolean", + "least() -> bytea/character varying/timestamp with time zone/timestamp without time zone/interval/time without time zone/date/rw_int256/serial/real/double precision/numeric/smallint/integer/bigint/boolean", ] "#]]; expected.assert_debug_eq(&duplicated); diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index a4cb82528cb37..18438b28c0a98 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -866,6 +866,9 @@ impl Binder { ("jsonb_array_length", raw_call(ExprType::JsonbArrayLength)), // Functions that return a constant value ("pi", pi()), + // greatest and least + ("greatest", raw_call(ExprType::Greatest)), + ("least", raw_call(ExprType::Least)), // System information operations. ( "pg_typeof", diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 71b3a2e20f475..44300073fd678 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -198,7 +198,9 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::ArrayPositions | expr_node::Type::StringToArray | expr_node::Type::Format - | expr_node::Type::ArrayTransform => + | expr_node::Type::ArrayTransform + | expr_node::Type::Greatest + | expr_node::Type::Least => // expression output is deterministic(same result for the same input) { let x = func_call diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 2e7eebf42362f..84e315dacae45 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -537,6 +537,10 @@ fn infer_type_for_special( ensure_arity!("vnode", 1 <= | inputs |); Ok(Some(DataType::Int16)) } + ExprType::Greatest | ExprType::Least => { + ensure_arity!("greatest/least", 1 <= | inputs |); + Ok(Some(align_types(inputs.iter_mut())?)) + } _ => Ok(None), } }