From 143e7c8dc7734c02b8c98718723cf3f1149b8de6 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 26 Jul 2023 14:54:49 +0800 Subject: [PATCH 1/6] stash Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 rfcs/0000-user-defined-aggregate-functions.md diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md new file mode 100644 index 00000000..d7953f8b --- /dev/null +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -0,0 +1,225 @@ +--- +feature: user_defined_aggregate_functions +authors: + - "Runji Wang" +start_date: "2023/07/25" +--- + +# User Defined Aggregate Functions (UDAF) + +## Summary + +This RFC proposes the user interface and implementation of user defined aggregate functions (UDAF). + +## User Interfaces + +This section describes how users create UDAFs in Python, Java, and SQL through examples. + +### Python API + +Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. +The difference is that below the decorator, we define a class instead of a function. +The class is an **accumulator** that accumulates the input rows and computes the aggregate result. +It can also optionally retracts a row and merges with another accumulator. + +```python +from risingwave.udf import udaf + +# The aggregate function is defined as a class. +# Specify the schema of the aggregate function in the `udaf` decorator. +@udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT') +class WeightedAvg: + # The internal state of the accumulator is defined as fields. + # They will be serialized into bytes before sending to kernel, + # and deserialized from bytes after receiving from kernel. + sum: int + count: int + + # Initialize the accumulator. + def __init__(self): + self.sum = 0 + self.count = 0 + + # Get the aggregate result. + # The return value should match `result_type`. 
+ def get_value(self) -> int: + if self.count == 0: + return None + else: + return self.sum / self.count + + # Accumulate a row. + # The arguments should match `input_types`. + def accumulate(self, value: int, weight: int): + self.sum += value * weight + self.count += weight + + # Retract a row. + # This method is optional. If not defined, the function is append-only and not retractable. + def retract(self, value: int, weight: int): + self.sum -= value * weight + self.count -= weight + + # Merge with another accumulator. + # This method is optional. If defined, the function can be optimized with two-phase aggregation. + def merge(self, other: WeightedAvg): + self.count += other.count + self.sum += other.sum +``` + +#### Alternative + +In Flink, the accumulator is a separate variable passed into functions as an argument. + +```python +class WeightedAvg(AggregateFunction): + def get_accumulator_type(self): + return 'ROW' + + def create_accumulator(self): + # Row(sum, count) + return Row(0, 0) + + def get_value(self, accumulator): + # ... + + def accumulate(self, accumulator, value, weight): + accumulator[0] += value * weight + accumulator[1] += weight +``` + +While in this proposal, the accumulator is the class itself. Users don't have to define the accumulator type explicitly. +The reason we can do this is that we serialize the state into bytes before storing it in the state table. +It is opaque to RisingWave. We don't care about the internal structure. + +Reference: [UDAF Python API in Flink](https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/python/table/udfs/python_udfs/#aggregate-functions) + +### Java API + +The Java API is similar to Python. +One of the differences is that users don't need to specify the input and output types, +because we can infer them from the function signatures. 
+ +```java +import com.risingwave.functions.AggregateFunction; + +// The aggregate function is defined as a class, which implements the `AggregateFunction` interface. +public static class WeightedAvg implements AggregateFunction { + // The internal state of the accumulator is defined as class fields. + public long sum = 0; + public int count = 0; + + // Create a new accumulator. + public WeightedAvg() {} + + // Get the aggregate result. + // The result type is inferred from the signature. (BIGINT) + // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. + public Long getValue() { + if (count == 0) { + return null; + } else { + return sum / count; + } + } + + // Accumulate a row. + // The input types are inferred from the signatures. (BIGINT, INT) + // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. + public void accumulate(long iValue, int iWeight) { + sum += iValue * iWeight; + count += iWeight; + } + + // Retract a row. (optional) + // The function signature should match `accumulate`. + public void retract(long iValue, int iWeight) { + sum -= iValue * iWeight; + count -= iWeight; + } + + // Merge with another accumulator. (optional) + public void merge(WeightedAvg a) { + count += a.count; + sum += a.sum; + } +} +``` + +Reference: [UDAF Java API in Flink](https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/table/functions/udfs/#aggregate-functions) + +### SQL API + +After setting up the UDF server, users can register the function through `create aggregate` in RisingWave. + +```sql +create aggregate weighted_avg(value bigint, weight int) returns bigint +as 'weighted_avg' using link 'http://localhost:8815'; +``` + +The full syntax is: + +```sql +CREATE [ OR REPLACE ] AGGREGATE name ( [ argname ] arg_data_type [ , ... 
] ) +[ RETURNS return_data_type ] [ APPEND ONLY ] +[ LANGUAGE language ] [ AS identifier ] [ USING LINK link ]; +``` + +The `RETURNS` clause is optional if the function has exactly one input argument. +In this case the return type is same as the input type. For example: + +```sql +create aggregate max(int) as ...; +``` + +The `APPEND ONLY` clause is required if the function doesn't have `retract` method. + +The remaining clauses are defined the same as for `create function` statement. + +## Implementations + +User defined functions are running in a separate process called UDF server. +The RisingWave kernel communicates with the UDF server through Arrow Flight RPC to exchange data. + +To avoid trouble in fault tolerance, the UDF server should be stateless even if the aggregate function is stateful. +However, attaching the state to each RPC call (like batch accumulation) is not efficient, especially when the state is large. +On the other hand, kernel doesn't need to know the state or the aggregate result before a barrier arrives. +Therefore, we create a streaming RPC to perform all operations within an epoch. + +```mermaid +sequenceDiagram + participant UP as Upstream + participant AG as Agg Executor + participant CP as UDF Server + UP ->> AG: init(epoch) + Note over AG: load state + AG ->>+ CP: doExchange(fid, state) + Note right of CP: decode state
create accumulator + loop epoch + loop chunks + UP ->> AG: chunk + AG ->> CP: chunk + Note right of CP: accumulate
retract + end + UP ->> AG: barrier(epoch+1) + AG ->> CP: finish + Note right of CP: encode state
get_value + CP -->> AG: state + output + Note over AG: store state + end + deactivate CP +``` + +## Unresolved questions + +* Are there some questions that haven't been resolved in the RFC? +* Can they be resolved in some future RFCs? +* Move some meaningful comments to here. + +## Alternatives + +What other designs have been considered and what is the rationale for not choosing them? + +## Future possibilities + +Some potential extensions or optimizations can be done in the future based on the RFC. From e247595694ccfa753358ebba85ef3430764d8ae7 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 2 Aug 2023 17:08:54 +0800 Subject: [PATCH 2/6] add design of state store Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 93 +++++++++++-------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md index d7953f8b..88d03298 100644 --- a/rfcs/0000-user-defined-aggregate-functions.md +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -2,14 +2,14 @@ feature: user_defined_aggregate_functions authors: - "Runji Wang" -start_date: "2023/07/25" +start_date: "2023/08/02" --- # User Defined Aggregate Functions (UDAF) ## Summary -This RFC proposes the user interface and implementation of user defined aggregate functions (UDAF). +This RFC proposes the user interface and internal design of user defined aggregate functions (UDAF). ## User Interfaces @@ -17,10 +17,7 @@ This section describes how users create UDAFs in Python, Java, and SQL through e ### Python API -Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. -The difference is that below the decorator, we define a class instead of a function. -The class is an **accumulator** that accumulates the input rows and computes the aggregate result. -It can also optionally retracts a row and merges with another accumulator. 
+Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. The difference is that below the decorator, we define a class instead of a function. The class is an **accumulator** that accumulates the input rows and computes the aggregate result. It can also optionally retracts a row and merges with another accumulator. ```python from risingwave.udf import udaf @@ -89,16 +86,12 @@ class WeightedAvg(AggregateFunction): ``` While in this proposal, the accumulator is the class itself. Users don't have to define the accumulator type explicitly. -The reason we can do this is that we serialize the state into bytes before storing it in the state table. -It is opaque to RisingWave. We don't care about the internal structure. Reference: [UDAF Python API in Flink](https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/python/table/udfs/python_udfs/#aggregate-functions) ### Java API -The Java API is similar to Python. -One of the differences is that users don't need to specify the input and output types, -because we can infer them from the function signatures. +The Java API is similar to Python. One of the differences is that users don't need to specify the input and output types, because we can infer them from the function signatures. ```java import com.risingwave.functions.AggregateFunction; @@ -176,50 +169,76 @@ The `APPEND ONLY` clause is required if the function doesn't have `retract` meth The remaining clauses are defined the same as for `create function` statement. -## Implementations +## Design -User defined functions are running in a separate process called UDF server. -The RisingWave kernel communicates with the UDF server through Arrow Flight RPC to exchange data. +### Dataflow -To avoid trouble in fault tolerance, the UDF server should be stateless even if the aggregate function is stateful. 
-However, attaching the state to each RPC call (like batch accumulation) is not efficient, especially when the state is large. -On the other hand, kernel doesn't need to know the state or the aggregate result before a barrier arrives. -Therefore, we create a streaming RPC to perform all operations within an epoch. +User defined functions are running in a separate process called UDF server. The RisingWave kernel communicates with the UDF server through Arrow Flight RPC to exchange data. + +To avoid trouble in fault tolerance, the UDF server should be **stateless** even if the aggregate function is stateful. This means the aggregate state should be maintained by the kernel. However, exchanging the state in each RPC call (batch aggregation) is not efficient, especially when the state is large. On the other hand, kernel doesn't need to know the state or the aggregate result before a barrier arrives. Therefore, we **create a streaming RPC `doExchange` to call the aggregate function, and sync the state every time a barrier arrives**. When the connection to the UDF server is broken, the executor raises an error and may retry or pause the stream. ```mermaid sequenceDiagram - participant UP as Upstream + participant UP as Upstream
Downstream participant AG as Agg Executor participant CP as UDF Server UP ->> AG: init(epoch) - Note over AG: load state + Note over AG: load state
from state table AG ->>+ CP: doExchange(fid, state) - Note right of CP: decode state
create accumulator - loop epoch - loop chunks - UP ->> AG: chunk - AG ->> CP: chunk - Note right of CP: accumulate
retract + CP -->> AG: output @ epoch + Note right of CP: decode state,
create accumulator + loop each epoch + loop each chunk + UP ->> AG: input chunk + AG ->> CP: input chunk + Note right of CP: accumulate,
retract end UP ->> AG: barrier(epoch+1) AG ->> CP: finish - Note right of CP: encode state
get_value - CP -->> AG: state + output - Note over AG: store state + Note right of CP: encode state,
get output + CP -->> AG: state + output @ epoch+1 + Note over AG: store state
into state table + AG ->> UP: - output @ epoch
+ output @ epoch+1 + AG ->> UP: barrier(epoch+1) end deactivate CP ``` -## Unresolved questions +### State Storage -* Are there some questions that haven't been resolved in the RFC? -* Can they be resolved in some future RFCs? -* Move some meaningful comments to here. +The state of UDAF is managed by the compute node as a single encoded BYTEA value. -## Alternatives +Currently, each aggregate operator has a **result table** to store the aggregate result. For most of our built-in aggregate functions, they have the same output as their state, so the result table is actually being used as the state table. However, for general UDAFs, their state may not be the same as their output. Such functions are not supported for now. -What other designs have been considered and what is the rationale for not choosing them? +Therefore, we propose to **transform the result table into state table**. The content of the table remains the same for existing functions. But for new functions whose state is different from output, only the state is stored. The output can be computed from the state when needed. -## Future possibilities +For example, given the input: + +| Epoch | op | id (pk) | v0 | w0 | v1 | +| ----- | ---- | ------- | ---- | ---- | ----- | +| 1 | + | 0 | 1 | 2 | false | +| 2 | U- | 0 | 1 | 2 | false | +| 2 | U+ | 0 | 2 | 1 | true | + +The new **state table** (derived from old result table) of the agg operator would be like: + +| Epoch | id | sum(v0) | bool_and(v1) | weighted_avg(v0, w0) | max(v0)* | +| ----- | ---- | ------- | ------------------- | -------------------- | ---------- | +| 1 | 0 | sum = 1 | false = 1, true = 0 | encode(1,2) = b'XXX' | output = 1 | +| 2 | 0 | sum = 2 | false = 0, true = 1 | encode(2,1) = b'YYY' | output = 2 | + +* For **append-only** aggregate functions (e.g. max, min, first, last, string_agg...), their states are all input values maintained in seperate "materialized input" tables. 
For backward compatibility, their values in the state table are still aggregate results. -Some potential extensions or optimizations can be done in the future based on the RFC. +The output would be: + +| Epoch | op | id | sum(v0) | bool_and(v1) | weighted_avg(v0, w0) | max(v0) | +| ----- | ---- | ---- | ------- | ------------ | -------------------- | ------- | +| 1 | + | 0 | 1 | false | 1 | 1 | +| 2 | U- | 0 | 1 | false | 1 | 1 | +| 2 | U+ | 0 | 2 | true | 2 | 2 | + +## Unresolved questions + +## Alternatives + +## Future possibilities From 9f4b2fc4ffed1b4509a11a1c2cb6cbe2c25ff4ca Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 2 Aug 2023 17:44:59 +0800 Subject: [PATCH 3/6] change the UDAF API to Flink style Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 136 +++++++++--------- 1 file changed, 67 insertions(+), 69 deletions(-) diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md index 88d03298..aff47b28 100644 --- a/rfcs/0000-user-defined-aggregate-functions.md +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -17,76 +17,68 @@ This section describes how users create UDAFs in Python, Java, and SQL through e ### Python API -Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. The difference is that below the decorator, we define a class instead of a function. The class is an **accumulator** that accumulates the input rows and computes the aggregate result. It can also optionally retracts a row and merges with another accumulator. +Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. The difference is that below the decorator, we define a class instead of a function. The class has an associated **accumulator** type, which is the intermediate state of the aggregate function. 
```python from risingwave.udf import udaf +# The accumulator (intermediate state) can be arbitrary class. +# It will be serialized into bytes before sending to kernel, +# and deserialized from bytes after receiving from kernel. +class State: + sum: int + count: int + # The aggregate function is defined as a class. # Specify the schema of the aggregate function in the `udaf` decorator. @udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT') class WeightedAvg: - # The internal state of the accumulator is defined as fields. - # They will be serialized into bytes before sending to kernel, - # and deserialized from bytes after receiving from kernel. - sum: int - count: int - - # Initialize the accumulator. - def __init__(self): - self.sum = 0 - self.count = 0 + # Create an empty accumulator. + def create_accumulator(self) -> State: + accumulator = State() + accumulator.sum = 0 + accumulator.count = 0 + return accumulator # Get the aggregate result. # The return value should match `result_type`. - def get_value(self) -> int: - if self.count == 0: + def get_value(self, accumulator: State) -> int: + if accumulator.count == 0: return None else: - return self.sum / self.count + return accumulator.sum / accumulator.count - # Accumulate a row. - # The arguments should match `input_types`. - def accumulate(self, value: int, weight: int): - self.sum += value * weight - self.count += weight + # Accumulate a row to the accumulator. + # The last arguments should match `input_types`. + def accumulate(self, accumulator: State, value: int, weight: int): + accumulator.sum += value * weight + accumulator.count += weight - # Retract a row. + # Retract a row from the accumulator. # This method is optional. If not defined, the function is append-only and not retractable. 
- def retract(self, value: int, weight: int): - self.sum -= value * weight - self.count -= weight + def retract(self, accumulator: State, value: int, weight: int): + accumulator.sum -= value * weight + accumulator.count -= weight - # Merge with another accumulator. + # Merge the accumulator with another one. # This method is optional. If defined, the function can be optimized with two-phase aggregation. - def merge(self, other: WeightedAvg): - self.count += other.count - self.sum += other.sum + def merge(self, accumulator: State, other: State): + accumulator.count += other.count + accumulator.sum += other.sum + + # Serialize the accumulator into bytes. + # This method is optional. If not defined, the accumulator would be serialized using pickle. + def serialize(self, accumulator: State) -> bytes: + # default implementation + return pickle.dumps(accumulator) + + # Deserialize the bytes into an accumulator. + # This method is optional. If not defined, the accumulator would be deserialized using pickle. + def deserialize(self, serialized: bytes) -> State: + # default implementation + return pickle.loads(serialized) ``` -#### Alternative - -In Flink, the accumulator is a separate variable passed into functions as an argument. - -```python -class WeightedAvg(AggregateFunction): - def get_accumulator_type(self): - return 'ROW' - - def create_accumulator(self): - # Row(sum, count) - return Row(0, 0) - - def get_value(self, accumulator): - # ... - - def accumulate(self, accumulator, value, weight): - accumulator[0] += value * weight - accumulator[1] += weight -``` - -While in this proposal, the accumulator is the class itself. Users don't have to define the accumulator type explicitly. 
- Reference: [UDAF Python API in Flink](https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/python/table/udfs/python_udfs/#aggregate-functions) ### Java API @@ -94,47 +86,53 @@ Reference: [UDAF Python API in Flink](https://nightlies.apache.org/flink/flink-d The Java API is similar to Python. One of the differences is that users don't need to specify the input and output types, because we can infer them from the function signatures. ```java +import java.io.Serializable; import com.risingwave.functions.AggregateFunction; -// The aggregate function is defined as a class, which implements the `AggregateFunction` interface. -public static class WeightedAvg implements AggregateFunction { - // The internal state of the accumulator is defined as class fields. +// Mutable accumulator for the aggregate function. +// This class should be Serializable. +public class WeightedAvgAccumulator implements Serializable { public long sum = 0; public int count = 0; +} +// The aggregate function is defined as a class, which implements the `AggregateFunction` interface. +public class WeightedAvg implements AggregateFunction { // Create a new accumulator. - public WeightedAvg() {} + public WeightedAvgAccumulator createAccumulator() { + return new WeightedAvgAccumulator(); + } // Get the aggregate result. // The result type is inferred from the signature. (BIGINT) // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. - public Long getValue() { - if (count == 0) { + public Long getValue(WeightedAvgAccumulator acc) { + if (acc.count == 0) { return null; } else { - return sum / count; + return acc.sum / acc.count; } } - // Accumulate a row. + // Accumulate a row to the accumulator. // The input types are inferred from the signatures. (BIGINT, INT) // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. 
- public void accumulate(long iValue, int iWeight) { - sum += iValue * iWeight; - count += iWeight; + public void accumulate(WeightedAvgAccumulator acc, long iValue, int iWeight) { + acc.sum += iValue * iWeight; + acc.count += iWeight; } - // Retract a row. (optional) + // Retract a row from the accumulator. (optional) // The function signature should match `accumulate`. - public void retract(long iValue, int iWeight) { - sum -= iValue * iWeight; - count -= iWeight; + public void retract(WeightedAvgAccumulator acc, long iValue, int iWeight) { + acc.sum -= iValue * iWeight; + acc.count -= iWeight; } - // Merge with another accumulator. (optional) - public void merge(WeightedAvg a) { - count += a.count; - sum += a.sum; + // Merge the accumulator with another one. (optional) + public void merge(WeightedAvgAccumulator acc, WeightedAvgAccumulator other) { + acc.count += other.count; + acc.sum += other.sum; } } ``` From d243fae632fcda10ee2fb552ef47d4bc5831a4b5 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 2 Aug 2023 18:02:58 +0800 Subject: [PATCH 4/6] add an alternative for accumulator Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md index aff47b28..e9d40e7f 100644 --- a/rfcs/0000-user-defined-aggregate-functions.md +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -239,4 +239,31 @@ The output would be: ## Alternatives +### Aggregate Function Class as the Accumulator + +This proposal follows the Flink API, which defines a separate class as the accumulator. There is an alternative design to define the accumulator as the aggregate function class itself. 
For example: + +```python +@udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT') +class WeightedAvg: + sum: int + count: int + + def __init__(self): + self.sum = 0 + self.count = 0 + + def get_value(self) -> int: + if self.count == 0: + return None + else: + return self.sum / self.count + + def accumulate(self, value: int, weight: int): + self.sum += value * weight + self.count += weight +``` + +This is simpler and more intuitive, but is also less flexible than the previous design. The previous design allows direct arguments (e.g. `percentile_cont(fraction)`, `rank(n)`) to be passed to the aggregate function, while this alternative requires the arguments to be stored in the state. + ## Future possibilities From 3ca95d1cffba37e6bc86e581b2be9cc345515307 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 3 Aug 2023 15:33:00 +0800 Subject: [PATCH 5/6] replace term `accumulator` with `intermediate state` Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md index e9d40e7f..7f0d0a42 100644 --- a/rfcs/0000-user-defined-aggregate-functions.md +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -17,12 +17,12 @@ This section describes how users create UDAFs in Python, Java, and SQL through e ### Python API -Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. The difference is that below the decorator, we define a class instead of a function. The class has an associated **accumulator** type, which is the intermediate state of the aggregate function. +Similar to scalar functions and table functions, we provide an `@udaf` decorator to define aggregate functions. The difference is that below the decorator, we define a class instead of a function. The class has an associated type for **intermediate state**. 
```python from risingwave.udf import udaf -# The accumulator (intermediate state) can be arbitrary class. +# The intermediate state can be arbitrary class. # It will be serialized into bytes before sending to kernel, # and deserialized from bytes after receiving from kernel. class State: @@ -33,47 +33,47 @@ class State: # Specify the schema of the aggregate function in the `udaf` decorator. @udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT') class WeightedAvg: - # Create an empty accumulator. - def create_accumulator(self) -> State: - accumulator = State() - accumulator.sum = 0 - accumulator.count = 0 - return accumulator + # Create an empty state. + def create_state(self) -> State: + state = State() + state.sum = 0 + state.count = 0 + return state # Get the aggregate result. # The return value should match `result_type`. - def get_value(self, accumulator: State) -> int: - if accumulator.count == 0: + def get_value(self, state: State) -> int: + if state.count == 0: return None else: - return accumulator.sum / accumulator.count + return state.sum / state.count - # Accumulate a row to the accumulator. + # Accumulate a row to the state. # The last arguments should match `input_types`. - def accumulate(self, accumulator: State, value: int, weight: int): - accumulator.sum += value * weight - accumulator.count += weight + def accumulate(self, state: State, value: int, weight: int): + state.sum += value * weight + state.count += weight - # Retract a row from the accumulator. + # Retract a row from the state. # This method is optional. If not defined, the function is append-only and not retractable. - def retract(self, accumulator: State, value: int, weight: int): - accumulator.sum -= value * weight - accumulator.count -= weight + def retract(self, state: State, value: int, weight: int): + state.sum -= value * weight + state.count -= weight - # Merge the accumulator with another one. + # Merge the state with another one. # This method is optional. 
If defined, the function can be optimized with two-phase aggregation. - def merge(self, accumulator: State, other: State): - accumulator.count += other.count - accumulator.sum += other.sum + def merge(self, state: State, other: State): + state.count += other.count + state.sum += other.sum - # Serialize the accumulator into bytes. - # This method is optional. If not defined, the accumulator would be serialized using pickle. - def serialize(self, accumulator: State) -> bytes: + # Serialize the state into bytes. + # This method is optional. If not defined, the state would be serialized using pickle. + def serialize(self, state: State) -> bytes: # default implementation - return pickle.dumps(accumulator) + return pickle.dumps(state) - # Deserialize the bytes into an accumulator. - # This method is optional. If not defined, the accumulator would be deserialized using pickle. + # Deserialize the bytes into an state. + # This method is optional. If not defined, the state would be deserialized using pickle. def deserialize(self, serialized: bytes) -> State: # default implementation return pickle.loads(serialized) @@ -89,24 +89,24 @@ The Java API is similar to Python. One of the differences is that users don't ne import java.io.Serializable; import com.risingwave.functions.AggregateFunction; -// Mutable accumulator for the aggregate function. +// Mutable intermediate state for the aggregate function. // This class should be Serializable. -public class WeightedAvgAccumulator implements Serializable { +public class WeightedAvgState implements Serializable { public long sum = 0; public int count = 0; } // The aggregate function is defined as a class, which implements the `AggregateFunction` interface. public class WeightedAvg implements AggregateFunction { - // Create a new accumulator. - public WeightedAvgAccumulator createAccumulator() { - return new WeightedAvgAccumulator(); + // Create a new state. 
+ public WeightedAvgState createState() { + return new WeightedAvgState(); } // Get the aggregate result. // The result type is inferred from the signature. (BIGINT) // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. - public Long getValue(WeightedAvgAccumulator acc) { + public Long getValue(WeightedAvgState acc) { if (acc.count == 0) { return null; } else { @@ -114,23 +114,23 @@ public class WeightedAvg implements AggregateFunction { } } - // Accumulate a row to the accumulator. + // Accumulate a row to the state. // The input types are inferred from the signatures. (BIGINT, INT) // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. - public void accumulate(WeightedAvgAccumulator acc, long iValue, int iWeight) { + public void accumulate(WeightedAvgState acc, long iValue, int iWeight) { acc.sum += iValue * iWeight; acc.count += iWeight; } - // Retract a row from the accumulator. (optional) + // Retract a row from the state. (optional) // The function signature should match `accumulate`. - public void retract(WeightedAvgAccumulator acc, long iValue, int iWeight) { + public void retract(WeightedAvgState acc, long iValue, int iWeight) { acc.sum -= iValue * iWeight; acc.count -= iWeight; } - // Merge the accumulator with another one. (optional) - public void merge(WeightedAvgAccumulator acc, WeightedAvgAccumulator other) { + // Merge the state with another one. (optional) + public void merge(WeightedAvgState acc, WeightedAvgState other) { acc.count += other.count; acc.sum += other.sum; } @@ -184,7 +184,7 @@ sequenceDiagram Note over AG: load state
from state table AG ->>+ CP: doExchange(fid, state) CP -->> AG: output @ epoch - Note right of CP: decode state,
create accumulator + Note right of CP: decode state,
create state loop each epoch loop each chunk UP ->> AG: input chunk @@ -239,9 +239,9 @@ The output would be: ## Alternatives -### Aggregate Function Class as the Accumulator +### Aggregate Function Class as the State -This proposal follows the Flink API, which defines a separate class as the accumulator. There is an alternative design to define the accumulator as the aggregate function class itself. For example: +This proposal follows the Flink API, which defines a separate class as the intermediate state. There is an alternative design to define the state as the aggregate function class itself. For example: ```python @udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT') From 7a9089768a5749fdb12c942e4648e7edbd0727cf Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 3 Aug 2023 15:40:50 +0800 Subject: [PATCH 6/6] refine based on comments Signed-off-by: Runji Wang --- rfcs/0000-user-defined-aggregate-functions.md | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/rfcs/0000-user-defined-aggregate-functions.md b/rfcs/0000-user-defined-aggregate-functions.md index 7f0d0a42..6852cf1c 100644 --- a/rfcs/0000-user-defined-aggregate-functions.md +++ b/rfcs/0000-user-defined-aggregate-functions.md @@ -26,8 +26,8 @@ from risingwave.udf import udaf # It will be serialized into bytes before sending to kernel, # and deserialized from bytes after receiving from kernel. class State: - sum: int - count: int + sum: int = 0 + count: int = 0 # The aggregate function is defined as a class. # Specify the schema of the aggregate function in the `udaf` decorator. @@ -35,14 +35,11 @@ class State: class WeightedAvg: # Create an empty state. def create_state(self) -> State: - state = State() - state.sum = 0 - state.count = 0 - return state + return State() # Get the aggregate result. # The return value should match `result_type`. 
- def get_value(self, state: State) -> int: + def get_result(self, state: State) -> int: if state.count == 0: return None else: @@ -106,7 +103,7 @@ public class WeightedAvg implements AggregateFunction { // Get the aggregate result. // The result type is inferred from the signature. (BIGINT) // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`. - public Long getValue(WeightedAvgState acc) { + public Long getResult(WeightedAvgState acc) { if (acc.count == 0) { return null; } else { @@ -152,7 +149,7 @@ The full syntax is: ```sql CREATE [ OR REPLACE ] AGGREGATE name ( [ argname ] arg_data_type [ , ... ] ) -[ RETURNS return_data_type ] [ APPEND ONLY ] +[ RETURNS result_data_type ] [ APPEND ONLY ] [ LANGUAGE language ] [ AS identifier ] [ USING LINK link ]; ``` @@ -206,9 +203,9 @@ sequenceDiagram The state of UDAF is managed by the compute node as a single encoded BYTEA value. -Currently, each aggregate operator has a **result table** to store the aggregate result. For most of our built-in aggregate functions, they have the same output as their state, so the result table is actually being used as the state table. However, for general UDAFs, their state may not be the same as their output. Such functions are not supported for now. +Currently, each aggregate operator has a **result table** to store the aggregate result. For most of our built-in aggregate functions, they have the same output as their state, so the result table is actually being used as the **intermediate state table**. However, for general UDAFs, their state may not be the same as their output. Such functions are not supported for now. -Therefore, we propose to **transform the result table into state table**. The content of the table remains the same for existing functions. But for new functions whose state is different from output, only the state is stored. The output can be computed from the state when needed. 
+Therefore, we propose to **transform the result table into an intermediate state table**. The content of the table remains the same for existing functions. But for new functions whose state is different from output, only the state is stored. The output can be computed from the state when needed. For example, given the input: @@ -218,14 +215,14 @@ For example, given the input: | 2 | U- | 0 | 1 | 2 | false | | 2 | U+ | 0 | 2 | 1 | true | -The new **state table** (derived from old result table) of the agg operator would be like: +The new **intermediate state table** (inherited from the old result table) of the agg operator would be like: | Epoch | id | sum(v0) | bool_and(v1) | weighted_avg(v0, w0) | max(v0)* | | ----- | ---- | ------- | ------------------- | -------------------- | ---------- | | 1 | 0 | sum = 1 | false = 1, true = 0 | encode(1,2) = b'XXX' | output = 1 | | 2 | 0 | sum = 2 | false = 0, true = 1 | encode(2,1) = b'YYY' | output = 2 | -* For **append-only** aggregate functions (e.g. max, min, first, last, string_agg...), their states are all input values maintained in seperate "materialized input" tables. For backward compatibility, their values in the state table are still aggregate results. +Note: for **append-only** aggregate functions (e.g. max, min, first, last, string_agg...), their states are all input values maintained in separate "materialized input" tables. For backward compatibility, their values in the state table are still aggregate results. The output would be: @@ -253,7 +250,7 @@ class WeightedAvg: self.sum = 0 self.count = 0 - def get_value(self) -> int: + def get_result(self) -> int: if self.count == 0: return None else: