Skip to content

Commit

Permalink
feat(expr): add array_sum (#12162)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: Runji Wang <[email protected]>
  • Loading branch information
xzhseh and wangrunji0408 authored Sep 14, 2023
1 parent 28bbf10 commit a566cfe
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 0 deletions.
46 changes: 46 additions & 0 deletions e2e_test/batch/functions/array_sum.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
query I
select array_sum(array[1, 2, 3]);
----
6

# Testing for int16 with positive numbers
query I
select array_sum(array[10, 20, 30]);
----
60

# Testing for int16 with a mix of positive and negative numbers
query I
select array_sum(array[-10, 20, -30]);
----
-20

# Testing for int16 with all zeros
query I
select array_sum(array[0, 0, 0]);
----
0

# Testing for int32 with larger positive numbers
query I
select array_sum(array[1000, 2000, 3000]);
----
6000

# Testing for int32 with a mix of larger positive and negative numbers
query I
select array_sum(array[-1000, 2000, -3000]);
----
-2000

# Testing for int64 with much larger numbers
query I
select array_sum(array[1000000000, 2000000000, 3000000000]);
----
6000000000

# Testing for int64 with a mix of much larger positive and negative numbers
query I
select array_sum(array[-1000000000, 2000000000, -3000000000]);
----
-2000000000
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ message ExprNode {
ARRAY_TRANSFORM = 545;
ARRAY_MIN = 546;
ARRAY_MAX = 547;
ARRAY_SUM = 548;
ARRAY_SORT = 549;

// Int256 functions
Expand Down
3 changes: 3 additions & 0 deletions src/expr/src/sig/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ mod tests {
ArrayMax: [
"array_max(list) -> bytea/varchar/timestamptz/timestamp/time/date/int256/serial/decimal/float32/float64/int16/int32/int64",
],
ArraySum: [
"array_sum(list) -> interval/decimal/float64/float32/int64",
],
}
"#]];
expected.assert_debug_eq(&duplicated);
Expand Down
84 changes: 84 additions & 0 deletions src/expr/src/vector_op/array_sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::{ArrayError, ListRef};
use risingwave_common::types::{CheckedAdd, Decimal, Scalar, ScalarImpl, ScalarRefImpl};
use risingwave_expr_macro::function;

use crate::{ExprError, Result};

/// `array_sum(int16`[]) -> int64
/// `array_sum(int32`[]) -> int64
#[function("array_sum(list) -> int64")]
#[function("array_sum(list) -> float32")]
#[function("array_sum(list) -> float64")]
/// `array_sum(int64`[]) -> decimal
/// `array_sum(decimal`[]) -> decimal
#[function("array_sum(list) -> decimal")]
#[function("array_sum(list) -> interval")]
fn array_sum<T: Scalar>(list: ListRef<'_>) -> Result<Option<T>>
where
T: Default + for<'a> TryFrom<ScalarRefImpl<'a>, Error = ArrayError> + CheckedAdd<Output = T>,
{
let flag = match list.iter().flatten().next() {
Some(v) => match v {
ScalarRefImpl::Int16(_) | ScalarRefImpl::Int32(_) => 1,
ScalarRefImpl::Int64(_) => 2,
_ => 0,
},
None => return Ok(None),
};

if flag != 0 {
match flag {
1 => {
let mut sum = 0;
for e in list.iter().flatten() {
sum = sum
.checked_add(match e {
ScalarRefImpl::Int16(v) => v as i64,
ScalarRefImpl::Int32(v) => v as i64,
_ => panic!("Expect ScalarRefImpl::Int16 or ScalarRefImpl::Int32"),
})
.ok_or_else(|| ExprError::NumericOutOfRange)?;
}
Ok(Some(ScalarImpl::from(sum).try_into()?))
}
2 => {
let mut sum = Decimal::Normalized(0.into());
for e in list.iter().flatten() {
sum = sum
.checked_add(match e {
ScalarRefImpl::Int64(v) => Decimal::Normalized(v.into()),
ScalarRefImpl::Decimal(v) => v,
// FIXME: We can't panic here due to the macro expansion
_ => Decimal::Normalized(0.into()),
})
.ok_or_else(|| ExprError::NumericOutOfRange)?;
}
Ok(Some(ScalarImpl::from(sum).try_into()?))
}
_ => Ok(None),
}
} else {
let mut sum = T::default();
for e in list.iter().flatten() {
let v = e.try_into()?;
sum = sum
.checked_add(v)
.ok_or_else(|| ExprError::NumericOutOfRange)?;
}
Ok(Some(sum))
}
}
1 change: 1 addition & 0 deletions src/expr/src/vector_op/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod array_range_access;
pub mod array_remove;
pub mod array_replace;
pub mod array_sort;
pub mod array_sum;
pub mod array_to_string;
pub mod ascii;
pub mod bitwise_op;
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@ impl Binder {
("array_remove", raw_call(ExprType::ArrayRemove)),
("array_replace", raw_call(ExprType::ArrayReplace)),
("array_max", raw_call(ExprType::ArrayMax)),
("array_sum", raw_call(ExprType::ArraySum)),
("array_position", raw_call(ExprType::ArrayPosition)),
("array_positions", raw_call(ExprType::ArrayPositions)),
("trim_array", raw_call(ExprType::TrimArray)),
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl ExprVisitor<bool> for ImpureAnalyzer {
| expr_node::Type::ArrayToString
| expr_node::Type::ArrayCat
| expr_node::Type::ArrayMax
| expr_node::Type::ArraySum
| expr_node::Type::ArraySort
| expr_node::Type::ArrayAppend
| expr_node::Type::ArrayPrepend
Expand Down
15 changes: 15 additions & 0 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,21 @@ fn infer_type_for_special(

Ok(Some(inputs[0].return_type().as_list().clone()))
}
ExprType::ArraySum => {
ensure_arity!("array_sum", | inputs | == 1);
inputs[0].ensure_array_type()?;

let return_type = match inputs[0].return_type().as_list().clone() {
DataType::Int16 | DataType::Int32 => DataType::Int64,
DataType::Int64 | DataType::Decimal => DataType::Decimal,
DataType::Float32 => DataType::Float32,
DataType::Float64 => DataType::Float64,
DataType::Interval => DataType::Interval,
_ => return Err(ErrorCode::InvalidParameterValue("".to_string()).into()),
};

Ok(Some(return_type))
}
ExprType::StringToArray => {
ensure_arity!("string_to_array", 2 <= | inputs | <= 3);

Expand Down

0 comments on commit a566cfe

Please sign in to comment.