Skip to content

Commit

Permalink
feat(udf): support UDAF for Python and JS (#16874)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored May 24, 2024
1 parent 37810fe commit 7bd4e04
Show file tree
Hide file tree
Showing 65 changed files with 1,137 additions and 338 deletions.
34 changes: 22 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ arrow-flight = "50"
arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = "0.2"
arrow-udf-js = "0.3"
arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "fa36365" }
arrow-udf-wasm = { version = "0.2.2", features = ["build"] }
arrow-udf-python = "0.1"
arrow-udf-python = "0.2"
arrow-udf-flight = "0.1"
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
Expand Down
83 changes: 83 additions & 0 deletions e2e_test/udf/js_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,86 @@ wave 4

statement ok
drop function split;


# aggregate function
statement ok
create aggregate weighted_avg(value int, weight int) returns float language javascript as $$
export function create_state() {
return {sum: 0, weight: 0};
}
export function accumulate(state, value, weight) {
if (value == null || weight == null) {
return state;
}
state.sum += value * weight;
state.weight += weight;
return state;
}
export function retract(state, value, weight) {
if (value == null || weight == null) {
return state;
}
state.sum -= value * weight;
state.weight -= weight;
return state;
}
export function finish(state) {
if (state.weight == 0) {
return null;
}
return state.sum / state.weight;
}
$$;

# batch
query F
select weighted_avg(value, weight) from (values (1, 1), (null, 2), (3, 3)) as t(value, weight);
----
2.5

# streaming
statement ok
create table t(value int, weight int);

statement ok
create materialized view mv as select weighted_avg(value, weight) from t;

query F
select * from mv;
----
NULL

statement ok
insert into t values (1, 1), (null, 2), (3, 3);

statement ok
flush;

query F
select * from mv;
----
2.5

statement ok
delete from t where value = 3;

statement ok
flush;

query F
select * from mv;
----
1

statement ok
drop materialized view mv;

statement ok
drop table t;

statement error "weighted_avg" is an aggregate function
drop function weighted_avg;

statement ok
drop aggregate weighted_avg;
81 changes: 81 additions & 0 deletions e2e_test/udf/python_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,87 @@ wave 4
statement ok
drop function split;


# aggregate function
statement ok
create aggregate weighted_avg(value int, weight int) returns float language python as $$
def create_state():
return (0, 0)
def accumulate(state, value, weight):
if value is None or weight is None:
return state
(s, w) = state
s += value * weight
w += weight
return (s, w)
def retract(state, value, weight):
if value is None or weight is None:
return state
(s, w) = state
s -= value * weight
w -= weight
return (s, w)
def finish(state):
(sum, weight) = state
if weight == 0:
return None
else:
return sum / weight
$$;

# batch
query F
select weighted_avg(value, weight) from (values (1, 1), (null, 2), (3, 3)) as t(value, weight);
----
2.5

# streaming
statement ok
create table t(value int, weight int);

statement ok
create materialized view mv as select weighted_avg(value, weight) from t;

query F
select * from mv;
----
NULL

statement ok
insert into t values (1, 1), (null, 2), (3, 3);

statement ok
flush;

query F
select * from mv;
----
2.5

statement ok
delete from t where value = 3;

statement ok
flush;

query F
select * from mv;
----
1

statement ok
drop materialized view mv;

statement ok
drop table t;

statement error "weighted_avg" is an aggregate function
drop function weighted_avg;

statement ok
drop aggregate weighted_avg;


statement ok
create function mismatched_arguments() returns int language python as $$
def mismatched_arguments(x):
Expand Down
16 changes: 10 additions & 6 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,13 @@ message TableFunction {
JSONB_POPULATE_RECORDSET = 16;
JSONB_TO_RECORDSET = 17;
// User defined table function
UDTF = 100;
USER_DEFINED = 100;
}
Type function_type = 1;
repeated expr.ExprNode args = 2;
data.DataType return_type = 3;
// optional. only used when the type is UDTF.
UserDefinedTableFunction udtf = 4;
// optional. only used when the type is USER_DEFINED.
UserDefinedFunctionMetadata udf = 4;
}

// Reference to an upstream column, containing its index and data type.
Expand Down Expand Up @@ -428,6 +428,8 @@ message AggCall {
LAST_VALUE = 25;
GROUPING = 26;
INTERNAL_LAST_SEEN_VALUE = 27;
// user defined aggregate function
USER_DEFINED = 100;
}
Type type = 1;
repeated InputRef args = 2;
Expand All @@ -436,6 +438,8 @@ message AggCall {
repeated common.ColumnOrder order_by = 5;
ExprNode filter = 6;
repeated Constant direct_args = 7;
// optional. only used when the type is USER_DEFINED.
UserDefinedFunctionMetadata udf = 8;
}

message WindowFrame {
Expand Down Expand Up @@ -528,7 +532,7 @@ message WindowFunction {
}

// Note: due to historic reasons, UserDefinedFunction is a oneof variant parallel to FunctionCall,
// while UserDefinedTableFunction is embedded as a field in TableFunction.
// while UserDefinedFunctionMetadata is embedded as a field in TableFunction and AggCall.

message UserDefinedFunction {
repeated ExprNode children = 1;
Expand All @@ -554,8 +558,8 @@ message UserDefinedFunction {
optional string function_type = 12;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
// Additional information for user defined table/aggregate functions.
message UserDefinedFunctionMetadata {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
Expand Down
1 change: 1 addition & 0 deletions src/batch/benches/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ fn create_agg_call(
order_by: vec![],
filter: None,
direct_args: vec![],
udf: None,
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/batch/src/executor/aggregation/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ impl AggregateFunction for Distinct {
self.inner.return_type()
}

fn create_state(&self) -> AggregateState {
AggregateState::Any(Box::new(State {
inner: self.inner.create_state(),
fn create_state(&self) -> Result<AggregateState> {
Ok(AggregateState::Any(Box::new(State {
inner: self.inner.create_state()?,
exists: HashSet::new(),
exists_estimated_heap_size: 0,
}))
})))
}

async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
Expand Down Expand Up @@ -203,7 +203,7 @@ mod tests {

fn test_agg(pretty: &str, input: StreamChunk, expected: Datum) {
let agg = build(&AggCall::from_pretty(pretty)).unwrap();
let mut state = agg.create_state();
let mut state = agg.create_state().unwrap();
agg.update(&mut state, &input)
.now_or_never()
.unwrap()
Expand Down
Loading

0 comments on commit 7bd4e04

Please sign in to comment.