diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 725f6cff43b37..a0089c2e4b1b0 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -84,6 +84,11 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]: return dec +@udf(input_types=["FLOAT8"], result_type="DECIMAL") +def float_to_decimal(f: float) -> Decimal: + return Decimal(f) + + @udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL") def decimal_add(a: Decimal, b: Decimal) -> Decimal: return a + b @@ -217,6 +222,7 @@ def return_all_arrays( server.add_function(split) server.add_function(extract_tcp_info) server.add_function(hex_to_dec) + server.add_function(float_to_decimal) server.add_function(decimal_add) server.add_function(array_access) server.add_function(jsonb_access) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index 0b397cd05d284..d3f88e8f5b5d8 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -39,6 +39,9 @@ create function series(int) returns table (x int) as series using link 'http://l statement ok create function split(varchar) returns table (word varchar, length int) as split using link 'http://localhost:8815'; +statement ok +create function float_to_decimal(float8) returns decimal as float_to_decimal using link 'http://localhost:8815'; + statement ok create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815'; @@ -77,6 +80,7 @@ show functions array_access character varying[], integer character varying (empty) http://localhost:8815 decimal_add numeric, numeric numeric (empty) http://localhost:8815 extract_tcp_info bytea struct (empty) http://localhost:8815 +float_to_decimal double precision numeric (empty) http://localhost:8815 gcd integer, integer integer (empty) http://localhost:8815 gcd integer, integer, integer integer (empty) http://localhost:8815 hex_to_dec character varying numeric (empty) http://localhost:8815 @@ -110,6 +114,11 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90 ---- 233276425899864771438119478 +query I +select float_to_decimal('-1e-10'::float8); +---- +-0.0000000001000000000000000036 + query R select decimal_add(1.11, 2.22); ---- @@ -343,6 +352,9 @@ drop function series; statement ok drop function split; +statement ok +drop function float_to_decimal; + statement ok drop function hex_to_dec; diff --git a/java/udf-example/pom.xml b/java/udf-example/pom.xml index 49de72ab3fac7..8bf51cd108128 100644 --- a/java/udf-example/pom.xml +++ b/java/udf-example/pom.xml @@ -31,7 +31,7 @@ com.risingwave risingwave-udf - 0.1.2-SNAPSHOT + 0.1.3-SNAPSHOT com.google.code.gson diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/java/udf-example/src/main/java/com/example/UdfExample.java index 4cdbaebb93294..52b7d1df6f120 100644 --- a/java/udf-example/src/main/java/com/example/UdfExample.java +++ b/java/udf-example/src/main/java/com/example/UdfExample.java @@ -35,6 +35,7 @@ public class UdfExample { public static void main(String[] args) throws IOException { try (var server = new UdfServer("0.0.0.0", 8815)) { server.addFunction("int_42", new Int42()); + server.addFunction("float_to_decimal", new FloatToDecimal()); server.addFunction("sleep", new Sleep()); server.addFunction("gcd", new Gcd()); server.addFunction("gcd3", new Gcd3()); @@ -64,6 +65,12 @@ public int eval() { } } + public static class FloatToDecimal implements ScalarFunction { + public BigDecimal eval(Double f) { + return new BigDecimal(f); + } + } + public static class Sleep implements ScalarFunction { public int eval(int x) { try { diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md index c64dbd1427737..fb1f055783225 100644 --- a/java/udf/CHANGELOG.md +++ b/java/udf/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.3] - 2023-12-06 + +### Fixed + +- Fix decimal type output. + ## [0.1.2] - 2023-12-04 ### Fixed diff --git a/java/udf/pom.xml b/java/udf/pom.xml index ea19d85234dbc..f747603ca8429 100644 --- a/java/udf/pom.xml +++ b/java/udf/pom.xml @@ -6,7 +6,7 @@ com.risingwave risingwave-udf jar - 0.1.2-SNAPSHOT + 0.1.3-SNAPSHOT risingwave-java-root diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java index f09e52f9548fa..19edaabf811d0 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -245,7 +245,8 @@ static void fillVector(FieldVector fieldVector, Object[] values) { vector.allocateNew(values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { - vector.set(i, ((BigDecimal) values[i]).toString().getBytes()); + // use `toPlainString` to avoid scientific notation + vector.set(i, ((BigDecimal) values[i]).toPlainString().getBytes()); } } } else if (fieldVector instanceof DateDayVector) { diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 234df130e22dd..5df8f9d3a6b61 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -144,8 +144,7 @@ impl UdfExpression { ); } - let data_chunk = - DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk"); + let data_chunk = DataChunk::try_from(&output)?; let output = data_chunk.uncompact(vis.clone()); let Some(array) = output.columns().first() else { diff --git a/src/udf/python/CHANGELOG.md b/src/udf/python/CHANGELOG.md index 9255d727c1e1d..a20411e69d83e 100644 --- a/src/udf/python/CHANGELOG.md +++ b/src/udf/python/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.1] - 2023-12-06 + +### Fixed + +- Fix decimal type output. + ## [0.1.0] - 2023-12-01 ### Fixed diff --git a/src/udf/python/pyproject.toml b/src/udf/python/pyproject.toml index 97003bc29157a..b535355168363 100644 --- a/src/udf/python/pyproject.toml +++ b/src/udf/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "risingwave" -version = "0.1.0" +version = "0.1.1" authors = [{ name = "RisingWave Labs" }] description = "RisingWave Python API" readme = "README.md" diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index fc8166bcd83b8..803ab1acbcbfb 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -120,7 +120,14 @@ def _process_func(type: pa.DataType, output: bool) -> Callable: if type.equals(UNCONSTRAINED_DECIMAL): if output: - return lambda v: str(v).encode("utf-8") + + def decimal_to_str(v): + if not isinstance(v, Decimal): + raise ValueError(f"Expected Decimal, got {v}") + # use `f` format to avoid scientific notation, e.g. `1e10` + return format(v, "f").encode("utf-8") + + return decimal_to_str else: return lambda v: Decimal(v.decode("utf-8"))