Skip to content

Commit

Permalink
feat(expr): support using scalar functions as aggregates in batch que…
Browse files Browse the repository at this point in the history
…ry (#17622)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Jul 23, 2024
1 parent 285afdb commit 063db16
Show file tree
Hide file tree
Showing 53 changed files with 794 additions and 582 deletions.
9 changes: 9 additions & 0 deletions e2e_test/batch/aggregate/scalar_as_agg.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
query T
select aggregate:array_sum(v) from (values (3), (2), (1)) as t(v);
----
6

query T
select aggregate:array_sort(v) from (values (3), (2), (1)) as t(v);
----
{1,2,3}
26 changes: 26 additions & 0 deletions e2e_test/udf/python_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,32 @@ statement ok
drop aggregate weighted_avg;


# UDF as aggregate function
statement ok
create function median(int[]) returns float language python as $$
def median(values):
values.sort()
n = len(values)
if n % 2 == 0:
return (values[n // 2 - 1] + values[n // 2]) / 2
else:
return values[n // 2]
$$;

query F
select aggregate:median(x) from (values (1), (2), (3), (4), (5)) as t(x);
----
3

query F
select aggregate:median(x) from (values (4), (3), (2), (1)) as t(x);
----
2.5

statement ok
drop function median;


statement ok
create function mismatched_arguments() returns int language python as $$
def mismatched_arguments(x):
Expand Down
5 changes: 5 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ message AggCall {
INTERNAL_LAST_SEEN_VALUE = 27;
// user defined aggregate function
USER_DEFINED = 100;
// wraps a scalar function that takes a list as input as an aggregate function.
WRAP_SCALAR = 101;
}
Type type = 1;
repeated InputRef args = 2;
Expand All @@ -448,6 +450,8 @@ message AggCall {
repeated Constant direct_args = 7;
// optional. only used when the type is USER_DEFINED.
UserDefinedFunctionMetadata udf = 8;
// optional. only used when the type is WRAP_SCALAR.
ExprNode scalar = 9;
}

message WindowFrame {
Expand Down Expand Up @@ -579,6 +583,7 @@ message UserDefinedFunction {
message UserDefinedFunctionMetadata {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
data.DataType return_type = 13;
string language = 4;
optional string link = 5;
optional string identifier = 6;
Expand Down
23 changes: 12 additions & 11 deletions src/batch/benches/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use risingwave_common::catalog::{Field, Schema};
use risingwave_common::memory::MemoryContext;
use risingwave_common::types::DataType;
use risingwave_common::{enable_jemalloc, hash};
use risingwave_expr::aggregate::{AggCall, AggKind};
use risingwave_expr::aggregate::{AggCall, AggKind, PbAggKind};
use risingwave_pb::expr::{PbAggCall, PbInputRef};
use tokio::runtime::Runtime;
use utils::{create_input, execute_executor};
Expand Down Expand Up @@ -53,6 +53,7 @@ fn create_agg_call(
filter: None,
direct_args: vec![],
udf: None,
scalar: None,
}
}

Expand Down Expand Up @@ -119,15 +120,15 @@ fn bench_hash_agg(c: &mut Criterion) {

let bench_variants = [
// (group by, agg, args, return type)
(vec![0], AggKind::Sum, vec![1], DataType::Int64),
(vec![0], AggKind::Count, vec![], DataType::Int64),
(vec![0], AggKind::Count, vec![2], DataType::Int64),
(vec![0], AggKind::Min, vec![1], DataType::Int64),
(vec![0], AggKind::StringAgg, vec![2], DataType::Varchar),
(vec![0, 2], AggKind::Sum, vec![1], DataType::Int64),
(vec![0, 2], AggKind::Count, vec![], DataType::Int64),
(vec![0, 2], AggKind::Count, vec![2], DataType::Int64),
(vec![0, 2], AggKind::Min, vec![1], DataType::Int64),
(vec![0], PbAggKind::Sum, vec![1], DataType::Int64),
(vec![0], PbAggKind::Count, vec![], DataType::Int64),
(vec![0], PbAggKind::Count, vec![2], DataType::Int64),
(vec![0], PbAggKind::Min, vec![1], DataType::Int64),
(vec![0], PbAggKind::StringAgg, vec![2], DataType::Varchar),
(vec![0, 2], PbAggKind::Sum, vec![1], DataType::Int64),
(vec![0, 2], PbAggKind::Count, vec![], DataType::Int64),
(vec![0, 2], PbAggKind::Count, vec![2], DataType::Int64),
(vec![0, 2], PbAggKind::Min, vec![1], DataType::Int64),
];

for (group_key_columns, agg_kind, arg_columns, return_type) in bench_variants {
Expand All @@ -141,7 +142,7 @@ fn bench_hash_agg(c: &mut Criterion) {
|| {
create_hash_agg_executor(
group_key_columns.clone(),
agg_kind,
agg_kind.into(),
arg_columns.clone(),
return_type.clone(),
chunk_size,
Expand Down
4 changes: 4 additions & 0 deletions src/batch/src/executor/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ mod tests {
filter: None,
direct_args: vec![],
udf: None,
scalar: None,
};

let agg_prost = HashAggNode {
Expand Down Expand Up @@ -883,6 +884,7 @@ mod tests {
filter: None,
direct_args: vec![],
udf: None,
scalar: None,
};

let agg_prost = HashAggNode {
Expand Down Expand Up @@ -1000,6 +1002,7 @@ mod tests {
filter: None,
direct_args: vec![],
udf: None,
scalar: None,
};

let agg_prost = HashAggNode {
Expand Down Expand Up @@ -1092,6 +1095,7 @@ mod tests {
filter: None,
direct_args: vec![],
udf: None,
scalar: None,
};

let agg_prost = HashAggNode {
Expand Down
Loading

0 comments on commit 063db16

Please sign in to comment.