Skip to content

Commit

Permalink
fix(udf): avoid panic on invalid output and fix decimal output scale …
Browse files Browse the repository at this point in the history
…lost (#13828)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 7, 2023
1 parent 2fba274 commit fcea158
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 7 deletions.
6 changes: 6 additions & 0 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]:
return dec


@udf(input_types=["FLOAT8"], result_type="DECIMAL")
def float_to_decimal(f: float) -> Decimal:
return Decimal(f)


@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL")
def decimal_add(a: Decimal, b: Decimal) -> Decimal:
return a + b
Expand Down Expand Up @@ -217,6 +222,7 @@ def return_all_arrays(
server.add_function(split)
server.add_function(extract_tcp_info)
server.add_function(hex_to_dec)
server.add_function(float_to_decimal)
server.add_function(decimal_add)
server.add_function(array_access)
server.add_function(jsonb_access)
Expand Down
12 changes: 12 additions & 0 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ create function series(int) returns table (x int) as series using link 'http://l
statement ok
create function split(varchar) returns table (word varchar, length int) as split using link 'http://localhost:8815';

statement ok
create function float_to_decimal(float8) returns decimal as float_to_decimal using link 'http://localhost:8815';

statement ok
create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815';

Expand Down Expand Up @@ -77,6 +80,7 @@ show functions
array_access character varying[], integer character varying (empty) http://localhost:8815
decimal_add numeric, numeric numeric (empty) http://localhost:8815
extract_tcp_info bytea struct<src_ip character varying, dst_ip character varying, src_port smallint, dst_port smallint> (empty) http://localhost:8815
float_to_decimal double precision numeric (empty) http://localhost:8815
gcd integer, integer integer (empty) http://localhost:8815
gcd integer, integer, integer integer (empty) http://localhost:8815
hex_to_dec character varying numeric (empty) http://localhost:8815
Expand Down Expand Up @@ -110,6 +114,11 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
----
233276425899864771438119478

query I
select float_to_decimal('-1e-10'::float8);
----
-0.0000000001000000000000000036

query R
select decimal_add(1.11, 2.22);
----
Expand Down Expand Up @@ -343,6 +352,9 @@ drop function series;
statement ok
drop function split;

statement ok
drop function float_to_decimal;

statement ok
drop function hex_to_dec;

Expand Down
2 changes: 1 addition & 1 deletion java/udf-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<dependency>
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<version>0.1.2-SNAPSHOT</version>
<version>0.1.3-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
Expand Down
7 changes: 7 additions & 0 deletions java/udf-example/src/main/java/com/example/UdfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class UdfExample {
public static void main(String[] args) throws IOException {
try (var server = new UdfServer("0.0.0.0", 8815)) {
server.addFunction("int_42", new Int42());
server.addFunction("float_to_decimal", new FloatToDecimal());
server.addFunction("sleep", new Sleep());
server.addFunction("gcd", new Gcd());
server.addFunction("gcd3", new Gcd3());
Expand Down Expand Up @@ -64,6 +65,12 @@ public int eval() {
}
}

public static class FloatToDecimal implements ScalarFunction {
public BigDecimal eval(Double f) {
return new BigDecimal(f);
}
}

public static class Sleep implements ScalarFunction {
public int eval(int x) {
try {
Expand Down
6 changes: 6 additions & 0 deletions java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.3] - 2023-12-06

### Fixed

- Fix decimal type output.

## [0.1.2] - 2023-12-04

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion java/udf/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<packaging>jar</packaging>
<version>0.1.2-SNAPSHOT</version>
<version>0.1.3-SNAPSHOT</version>

<parent>
<artifactId>risingwave-java-root</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
vector.allocateNew(values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, ((BigDecimal) values[i]).toString().getBytes());
// use `toPlainString` to avoid scientific notation
vector.set(i, ((BigDecimal) values[i]).toPlainString().getBytes());
}
}
} else if (fieldVector instanceof DateDayVector) {
Expand Down
3 changes: 1 addition & 2 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ impl UdfExpression {
);
}

let data_chunk =
DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk");
let data_chunk = DataChunk::try_from(&output)?;
let output = data_chunk.uncompact(vis.clone());

let Some(array) = output.columns().first() else {
Expand Down
6 changes: 6 additions & 0 deletions src/udf/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 2023-12-06

### Fixed

- Fix decimal type output.

## [0.1.0] - 2023-12-01

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/udf/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "risingwave"
version = "0.1.0"
version = "0.1.1"
authors = [{ name = "RisingWave Labs" }]
description = "RisingWave Python API"
readme = "README.md"
Expand Down
9 changes: 8 additions & 1 deletion src/udf/python/risingwave/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,14 @@ def _process_func(type: pa.DataType, output: bool) -> Callable:

if type.equals(UNCONSTRAINED_DECIMAL):
if output:
return lambda v: str(v).encode("utf-8")

def decimal_to_str(v):
if not isinstance(v, Decimal):
raise ValueError(f"Expected Decimal, got {v}")
# use `f` format to avoid scientific notation, e.g. `1e10`
return format(v, "f").encode("utf-8")

return decimal_to_str
else:
return lambda v: Decimal(v.decode("utf-8"))

Expand Down

0 comments on commit fcea158

Please sign in to comment.