From bdb7aab996c49f6cda8fe6670ee97c3eae155c27 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 7 Dec 2023 15:09:15 +0800 Subject: [PATCH] fix: cherry-pick several udf fixes (#13852) Signed-off-by: Runji Wang --- ci/scripts/build-other.sh | 1 + e2e_test/udf/test.py | 12 ++ e2e_test/udf/udf.slt | 54 +++++- java/udf-example/pom.xml | 2 +- .../src/main/java/com/example/UdfExample.java | 14 ++ java/udf/CHANGELOG.md | 19 +- java/udf/pom.xml | 2 +- .../functions/ScalarFunctionBatch.java | 7 +- .../functions/TableFunctionBatch.java | 10 +- .../com/risingwave/functions/TypeUtils.java | 165 ++++++------------ .../com/risingwave/functions/UdfProducer.java | 25 +-- .../functions/UserDefinedFunctionBatch.java | 5 +- .../risingwave/functions/TestUdfServer.java | 6 +- src/common/src/array/arrow.rs | 91 ++++------ src/expr/core/src/expr/expr_udf.rs | 3 +- src/udf/python/CHANGELOG.md | 12 ++ src/udf/python/pyproject.toml | 2 +- src/udf/python/risingwave/test_udf.py | 4 +- src/udf/python/risingwave/udf.py | 34 +++- 19 files changed, 257 insertions(+), 211 deletions(-) diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index f5b6692c6cbf9..9cd44dc78c95a 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -11,6 +11,7 @@ cd java mvn -B package -Dmaven.test.skip=true mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test +mvn -B test --pl udf cd .. echo "--- Build rust binary for java binding integration test" diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 6f1aa115bf953..a0089c2e4b1b0 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -84,6 +84,16 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]: return dec +@udf(input_types=["FLOAT8"], result_type="DECIMAL") +def float_to_decimal(f: float) -> Decimal: + return Decimal(f) + + +@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL") +def decimal_add(a: Decimal, b: Decimal) -> Decimal: + return a + b + + @udf(input_types=["VARCHAR[]", "INT"], result_type="VARCHAR") def array_access(list: List[str], idx: int) -> Optional[str]: if idx == 0 or idx > len(list): @@ -212,6 +222,8 @@ def return_all_arrays( server.add_function(split) server.add_function(extract_tcp_info) server.add_function(hex_to_dec) + server.add_function(float_to_decimal) + server.add_function(decimal_add) server.add_function(array_access) server.add_function(jsonb_access) server.add_function(jsonb_concat) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index f5ab290a69a8a..d3f88e8f5b5d8 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -39,9 +39,15 @@ create function series(int) returns table (x int) as series using link 'http://l statement ok create function split(varchar) returns table (word varchar, length int) as split using link 'http://localhost:8815'; +statement ok +create function float_to_decimal(float8) returns decimal as float_to_decimal using link 'http://localhost:8815'; + statement ok create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815'; +statement ok +create function decimal_add(decimal, decimal) returns decimal as decimal_add using link 'http://localhost:8815'; + statement ok create function array_access(varchar[], int) returns varchar as array_access using link 'http://localhost:8815'; @@ -72,7 +78,9 @@ query TTTTT rowsort show functions ---- array_access character varying[], integer character varying (empty) http://localhost:8815 +decimal_add numeric, numeric numeric (empty) http://localhost:8815 extract_tcp_info bytea struct (empty) http://localhost:8815 +float_to_decimal double precision numeric (empty) http://localhost:8815 gcd integer, integer integer (empty) http://localhost:8815 gcd integer, integer, integer integer (empty) http://localhost:8815 hex_to_dec character varying numeric (empty) http://localhost:8815 @@ -106,6 +114,16 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90 ---- 233276425899864771438119478 +query I +select float_to_decimal('-1e-10'::float8); +---- +-0.0000000001000000000000000036 + +query R +select decimal_add(1.11, 2.22); +---- +3.33 + query T select array_access(ARRAY['a', 'b', 'c'], 2); ---- @@ -142,7 +160,7 @@ select (return_all( 1 ::bigint, 1 ::float4, 1 ::float8, - 1234567890123456789012345678 ::decimal, + 12345678901234567890.12345678 ::decimal, date '2023-06-01', time '01:02:03.456789', timestamp '2023-06-01 01:02:03.456789', @@ -153,7 +171,7 @@ select (return_all( row(1, 2)::struct )).*; ---- -t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2) +t 1 1 1 1 1 12345678901234567890.12345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2) query T select (return_all_arrays( @@ -163,7 +181,7 @@ select (return_all_arrays( array[null, 1 ::bigint], array[null, 1 ::float4], array[null, 1 ::float8], - array[null, 1234567890123456789012345678 ::decimal], + array[null, 12345678901234567890.12345678 ::decimal], array[null, date '2023-06-01'], array[null, time '01:02:03.456789'], array[null, timestamp '2023-06-01 01:02:03.456789'], @@ -174,7 +192,29 @@ select (return_all_arrays( array[null, row(1, 2)::struct] )).*; ---- -{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"} +{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,12345678901234567890.12345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"} + +# test large string output +query I +select length((return_all( + null::boolean, + null::smallint, + null::int, + null::bigint, + null::float4, + null::float8, + null::decimal, + null::date, + null::time, + null::timestamp, + null::interval, + repeat('a', 100000)::varchar, + repeat('a', 100000)::bytea, + null::jsonb, + null::struct +)).varchar); +---- +100000 query I select series(5); @@ -312,9 +352,15 @@ drop function series; statement ok drop function split; +statement ok +drop function float_to_decimal; + statement ok drop function hex_to_dec; +statement ok +drop function decimal_add; + statement ok drop function array_access; diff --git a/java/udf-example/pom.xml b/java/udf-example/pom.xml index e47ff0263e61d..8bf51cd108128 100644 --- a/java/udf-example/pom.xml +++ b/java/udf-example/pom.xml @@ -31,7 +31,7 @@ com.risingwave risingwave-udf - 0.1.1-SNAPSHOT + 0.1.3-SNAPSHOT com.google.code.gson diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/java/udf-example/src/main/java/com/example/UdfExample.java index ef673e61c1d91..52b7d1df6f120 100644 --- a/java/udf-example/src/main/java/com/example/UdfExample.java +++ b/java/udf-example/src/main/java/com/example/UdfExample.java @@ -35,11 +35,13 @@ public class UdfExample { public static void main(String[] args) throws IOException { try (var server = new UdfServer("0.0.0.0", 8815)) { server.addFunction("int_42", new Int42()); + server.addFunction("float_to_decimal", new FloatToDecimal()); server.addFunction("sleep", new Sleep()); server.addFunction("gcd", new Gcd()); server.addFunction("gcd3", new Gcd3()); server.addFunction("extract_tcp_info", new ExtractTcpInfo()); server.addFunction("hex_to_dec", new HexToDec()); + server.addFunction("decimal_add", new DecimalAdd()); server.addFunction("array_access", new ArrayAccess()); server.addFunction("jsonb_access", new JsonbAccess()); server.addFunction("jsonb_concat", new JsonbConcat()); @@ -63,6 +65,12 @@ public int eval() { } } + public static class FloatToDecimal implements ScalarFunction { + public BigDecimal eval(Double f) { + return new BigDecimal(f); + } + } + public static class Sleep implements ScalarFunction { public int eval(int x) { try { @@ -126,6 +134,12 @@ public BigDecimal eval(String hex) { } } + public static class DecimalAdd implements ScalarFunction { + public BigDecimal eval(BigDecimal a, BigDecimal b) { + return a.add(b); + } + } + public static class ArrayAccess implements ScalarFunction { public String eval(String[] array, int index) { return array[index - 1]; diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md index 5f46b7ae339b5..fb1f055783225 100644 --- a/java/udf/CHANGELOG.md +++ b/java/udf/CHANGELOG.md @@ -7,7 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.1.1] - 2023-11-28 +## [0.1.3] - 2023-12-06 + +### Fixed + +- Fix decimal type output. + +## [0.1.2] - 2023-12-04 + +### Fixed + +- Fix index-out-of-bound error when string or string list is large. +- Fix memory leak. + +## [0.1.1] - 2023-12-03 ### Added @@ -17,6 +30,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Bump Arrow version to 14. +### Fixed + +- Fix unconstrained decimal type. + ## [0.1.0] - 2023-09-01 - Initial release. \ No newline at end of file diff --git a/java/udf/pom.xml b/java/udf/pom.xml index 7e9814b4af41e..f747603ca8429 100644 --- a/java/udf/pom.xml +++ b/java/udf/pom.xml @@ -6,7 +6,7 @@ com.risingwave risingwave-udf jar - 0.1.1-SNAPSHOT + 0.1.3-SNAPSHOT risingwave-java-root diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java index 95b0dede5d99b..47a43a8dc2132 100644 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java @@ -27,9 +27,8 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch { MethodHandle methodHandle; Function[] processInputs; - ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) { + ScalarFunctionBatch(ScalarFunction function) { this.function = function; - this.allocator = allocator; var method = Reflection.getEvalMethod(function); this.methodHandle = Reflection.getMethodHandle(method); this.inputSchema = TypeUtils.methodToInputSchema(method); @@ -38,7 +37,7 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch { } @Override - Iterator evalBatch(VectorSchemaRoot batch) { + Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { var row = new Object[batch.getSchema().getFields().size() + 1]; row[0] = this.function; var outputValues = new Object[batch.getRowCount()]; @@ -55,7 +54,7 @@ Iterator evalBatch(VectorSchemaRoot batch) { } var outputVector = TypeUtils.createVector( - this.outputSchema.getFields().get(0), this.allocator, outputValues); + this.outputSchema.getFields().get(0), allocator, outputValues); var outputBatch = VectorSchemaRoot.of(outputVector); return Collections.singleton(outputBatch).iterator(); } diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java index 5dbe5db470dfd..1480b8de2b9f2 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java @@ -28,9 +28,8 @@ class TableFunctionBatch extends UserDefinedFunctionBatch { Function[] processInputs; int chunkSize = 1024; - TableFunctionBatch(TableFunction function, BufferAllocator allocator) { + TableFunctionBatch(TableFunction function) { this.function = function; - this.allocator = allocator; var method = Reflection.getEvalMethod(function); this.methodHandle = Reflection.getMethodHandle(method); this.inputSchema = TypeUtils.methodToInputSchema(method); @@ -39,7 +38,7 @@ class TableFunctionBatch extends UserDefinedFunctionBatch { } @Override - Iterator evalBatch(VectorSchemaRoot batch) { + Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { var outputs = new ArrayList(); var row = new Object[batch.getSchema().getFields().size() + 1]; row[0] = this.function; @@ -49,10 +48,9 @@ Iterator evalBatch(VectorSchemaRoot batch) { () -> { var fields = this.outputSchema.getFields(); var indexVector = - TypeUtils.createVector( - fields.get(0), this.allocator, indexes.toArray()); + TypeUtils.createVector(fields.get(0), allocator, indexes.toArray()); var valueVector = - TypeUtils.createVector(fields.get(1), this.allocator, values.toArray()); + TypeUtils.createVector(fields.get(1), allocator, values.toArray()); indexes.clear(); values.clear(); var outputBatch = VectorSchemaRoot.of(indexVector, valueVector); diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java index aedba8abf9823..19edaabf811d0 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -53,8 +53,8 @@ static Field stringToField(String typeStr, String name) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); } else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (typeStr.startsWith("DECIMAL") || typeStr.startsWith("NUMERIC")) { - return Field.nullable(name, new ArrowType.Decimal(38, 0, 128)); + } else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) { + return Field.nullable(name, new ArrowType.LargeBinary()); } else if (typeStr.equals("DATE")) { return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); } else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) { @@ -110,7 +110,7 @@ static Field classToField(Class param, DataTypeHint hint, String name) { } else if (param == Double.class || param == double.class) { return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); } else if (param == BigDecimal.class) { - return Field.nullable(name, new ArrowType.Decimal(38, 0, 128)); + return Field.nullable(name, new ArrowType.LargeBinary()); } else if (param == LocalDate.class) { return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); } else if (param == LocalTime.class) { @@ -240,12 +240,13 @@ static void fillVector(FieldVector fieldVector, Object[] values) { vector.set(i, (double) values[i]); } } - } else if (fieldVector instanceof DecimalVector) { - var vector = (DecimalVector) fieldVector; + } else if (fieldVector instanceof LargeVarBinaryVector) { + var vector = (LargeVarBinaryVector) fieldVector; vector.allocateNew(values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { - vector.set(i, (BigDecimal) values[i]); + // use `toPlainString` to avoid scientific notation + vector.set(i, ((BigDecimal) values[i]).toPlainString().getBytes()); } } } else if (fieldVector instanceof DateDayVector) { @@ -286,7 +287,13 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } } else if (fieldVector instanceof VarCharVector) { var vector = (VarCharVector) fieldVector; - vector.allocateNew(values.length); + int totalBytes = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + totalBytes += ((String) values[i]).length(); + } + } + vector.allocateNew(totalBytes, values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { vector.set(i, ((String) values[i]).getBytes()); @@ -294,7 +301,13 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } } else if (fieldVector instanceof LargeVarCharVector) { var vector = (LargeVarCharVector) fieldVector; - vector.allocateNew(values.length); + int totalBytes = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + totalBytes += ((String) values[i]).length(); + } + } + vector.allocateNew(totalBytes, values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { vector.set(i, ((String) values[i]).getBytes()); @@ -302,7 +315,13 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } } else if (fieldVector instanceof VarBinaryVector) { var vector = (VarBinaryVector) fieldVector; - vector.allocateNew(values.length); + int totalBytes = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + totalBytes += ((byte[]) values[i]).length; + } + } + vector.allocateNew(totalBytes, values.length); for (int i = 0; i < values.length; i++) { if (values[i] != null) { vector.set(i, (byte[]) values[i]); @@ -311,83 +330,30 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } else if (fieldVector instanceof ListVector) { var vector = (ListVector) fieldVector; vector.allocateNew(); - if (vector.getDataVector() instanceof BitVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val ? 1 : 0)); - } else if (vector.getDataVector() instanceof SmallIntVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof IntVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof BigIntVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof Float4Vector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof Float8Vector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof DecimalVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof DateDayVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, (int) val.toEpochDay())); - } else if (vector.getDataVector() instanceof TimeMicroVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val.toNanoOfDay() / 1000)); - } else if (vector.getDataVector() instanceof TimeStampMicroVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, timestampToMicros(val))); - } else if (vector.getDataVector() instanceof IntervalMonthDayNanoVector) { - TypeUtils.fillListVector( - vector, - values, - (vec, i, val) -> { - var months = (int) val.getPeriod().toTotalMonths(); - var days = val.getPeriod().getDays(); - var nanos = val.getDuration().toNanos(); - vec.set(i, months, days, nanos); - }); - } else if (vector.getDataVector() instanceof VarCharVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val.getBytes())); - } else if (vector.getDataVector() instanceof LargeVarCharVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val.getBytes())); - } else if (vector.getDataVector() instanceof VarBinaryVector) { - TypeUtils.fillListVector( - vector, values, (vec, i, val) -> vec.set(i, val)); - } else if (vector.getDataVector() instanceof StructVector) { - // flatten the `values` - var flattenLength = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] == null) { - continue; - } - var len = Array.getLength(values[i]); - vector.startNewValue(i); - vector.endValue(i, len); - flattenLength += len; + // flatten the `values` + var flattenLength = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] == null) { + continue; } - var flattenValues = new Object[flattenLength]; - var ii = 0; - for (var list : values) { - if (list == null) { - continue; - } - var length = Array.getLength(list); - for (int i = 0; i < length; i++) { - flattenValues[ii++] = Array.get(list, i); - } + var len = Array.getLength(values[i]); + vector.startNewValue(i); + vector.endValue(i, len); + flattenLength += len; + } + var flattenValues = new Object[flattenLength]; + var ii = 0; + for (var list : values) { + if (list == null) { + continue; + } + var length = Array.getLength(list); + for (int i = 0; i < length; i++) { + flattenValues[ii++] = Array.get(list, i); } - fillVector(vector.getDataVector(), flattenValues); - } else { - throw new IllegalArgumentException( - "Unsupported type: " + vector.getDataVector().getClass()); } + // fill the inner vector + fillVector(vector.getDataVector(), flattenValues); } else if (fieldVector instanceof StructVector) { var vector = (StructVector) fieldVector; vector.allocateNew(); @@ -430,33 +396,6 @@ static void fillVector(FieldVector fieldVector, Object[] values) { fieldVector.setValueCount(values.length); } - @FunctionalInterface - interface TriFunction { - void apply(T t, U u, V v); - } - - @SuppressWarnings("unchecked") - static void fillListVector( - ListVector vector, Object[] values, TriFunction set) { - var innerVector = (V) vector.getDataVector(); - int ii = 0; - for (int i = 0; i < values.length; i++) { - var array = (T[]) values[i]; - if (array == null) { - continue; - } - vector.startNewValue(i); - for (T v : array) { - if (v == null) { - innerVector.setNull(ii++); - } else { - set.apply(innerVector, ii++, v); - } - } - vector.endValue(i, array.length); - } - } - static long timestampToMicros(LocalDateTime timestamp) { var date = timestamp.toLocalDate().toEpochDay(); var time = timestamp.toLocalTime().toNanoOfDay(); @@ -476,6 +415,10 @@ static Function processFunc0(Field field, Class targetClass) } else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) { // object is org.apache.arrow.vector.util.Text return obj -> obj.toString(); + } else if (field.getType() instanceof ArrowType.LargeBinary + && targetClass == BigDecimal.class) { + // object is byte[] + return obj -> new BigDecimal(new String((byte[]) obj)); } else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) { // object is Integer return obj -> LocalDate.ofEpochDay((int) obj); @@ -504,7 +447,7 @@ static Function processFunc0(Field field, Class targetClass) } else if (subfield.getType() .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { return obj -> ((List) obj).stream().map(subfunc).toArray(Double[]::new); - } else if (subfield.getType() instanceof ArrowType.Decimal) { + } else if (subfield.getType() instanceof ArrowType.LargeBinary) { return obj -> ((List) obj).stream().map(subfunc).toArray(BigDecimal[]::new); } else if (subfield.getType() instanceof ArrowType.Date) { return obj -> ((List) obj).stream().map(subfunc).toArray(LocalDate[]::new); diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java index 71b810726c67c..9f73ab8f9b86c 100644 --- a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java +++ b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java @@ -40,9 +40,9 @@ class UdfProducer extends NoOpFlightProducer { void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { UserDefinedFunctionBatch udf; if (function instanceof ScalarFunction) { - udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator); + udf = new ScalarFunctionBatch((ScalarFunction) function); } else if (function instanceof TableFunction) { - udf = new TableFunctionBatch((TableFunction) function, this.allocator); + udf = new TableFunctionBatch((TableFunction) function); } else { throw new IllegalArgumentException( "Unknown function type: " + function.getClass().getName()); @@ -76,21 +76,26 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor @Override public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - try { + try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) { var functionName = reader.getDescriptor().getPath().get(0); logger.debug("call function: " + functionName); var udf = this.functions.get(functionName); - try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) { + try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) { var loader = new VectorLoader(root); writer.start(root); while (reader.next()) { - var outputBatches = udf.evalBatch(reader.getRoot()); - while (outputBatches.hasNext()) { - var outputRoot = outputBatches.next(); - var unloader = new VectorUnloader(outputRoot); - loader.load(unloader.getRecordBatch()); - writer.putNext(); + try (var input = reader.getRoot()) { + var outputBatches = udf.evalBatch(input, allocator); + while (outputBatches.hasNext()) { + try (var outputRoot = outputBatches.next()) { + var unloader = new VectorUnloader(outputRoot); + try (var outputBatch = unloader.getRecordBatch()) { + loader.load(outputBatch); + } + } + writer.putNext(); + } } } writer.completed(); diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java index d271ac1245ee4..a32aa85d509db 100644 --- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java @@ -28,7 +28,6 @@ abstract class UserDefinedFunctionBatch { protected Schema inputSchema; protected Schema outputSchema; - protected BufferAllocator allocator; /** Get the input schema of the function. */ Schema getInputSchema() { @@ -44,9 +43,11 @@ Schema getOutputSchema() { * Evaluate the function by processing a batch of input data. * * @param batch the input data batch to process + * @param allocator the allocator to use for allocating output data * @return an iterator over the output data batches */ - abstract Iterator evalBatch(VectorSchemaRoot batch); + abstract Iterator evalBatch( + VectorSchemaRoot batch, BufferAllocator allocator); } /** Utility class for reflection. */ diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java index c72774c641aed..25977ec15e45c 100644 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java @@ -186,9 +186,9 @@ public void all_types() throws Exception { c5.set(0, 1); c5.setValueCount(2); - var c6 = new DecimalVector("", allocator, 38, 0); + var c6 = new LargeVarBinaryVector("", allocator); c6.allocateNew(2); - c6.set(0, BigDecimal.valueOf(10).pow(37)); + c6.set(0, "1.234".getBytes()); c6.setValueCount(2); var c7 = new DateDayVector("", allocator); @@ -255,7 +255,7 @@ public void all_types() throws Exception { var output = stream.getRoot(); assertTrue(stream.next()); assertEquals( - "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", + "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", output.contentToTSVString().trim()); } } diff --git a/src/common/src/array/arrow.rs b/src/common/src/array/arrow.rs index a04584fbce005..7fc6d277bce9e 100644 --- a/src/common/src/array/arrow.rs +++ b/src/common/src/array/arrow.rs @@ -143,7 +143,10 @@ converts_generic! { { arrow_array::Float64Array, Float64, ArrayImpl::Float64 }, { arrow_array::StringArray, Utf8, ArrayImpl::Utf8 }, { arrow_array::BooleanArray, Boolean, ArrayImpl::Bool }, - { arrow_array::Decimal128Array, Decimal128(_, _), ArrayImpl::Decimal }, + // Arrow doesn't have a data type to represent unconstrained numeric (`DECIMAL` in RisingWave and + // Postgres). So we pick a special type `LargeBinary` for it. + // Values stored in the array are the string representation of the decimal. e.g. b"1.234", b"+inf" + { arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal }, { arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 }, { arrow_array::Date32Array, Date32, ArrayImpl::Date }, { arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp }, @@ -169,7 +172,7 @@ impl From<&arrow_schema::DataType> for DataType { Int64 => Self::Int64, Float32 => Self::Float32, Float64 => Self::Float64, - Decimal128(_, _) => Self::Decimal, + LargeBinary => Self::Decimal, Decimal256(_, _) => Self::Int256, Date32 => Self::Date, Time64(Microsecond) => Self::Time, @@ -237,7 +240,7 @@ impl TryFrom<&DataType> for arrow_schema::DataType { DataType::Varchar => Ok(Self::Utf8), DataType::Jsonb => Ok(Self::LargeUtf8), DataType::Bytea => Ok(Self::Binary), - DataType::Decimal => Ok(Self::Decimal128(38, 0)), // arrow precision can not be 0 + DataType::Decimal => Ok(Self::LargeBinary), DataType::Struct(struct_type) => Ok(Self::Struct( struct_type .iter() @@ -462,53 +465,33 @@ impl FromIntoArrow for Interval { } } -// RisingWave Decimal type is self-contained, but Arrow is not. -// In Arrow DecimalArray, the scale is stored in data type as metadata, and the mantissa is stored -// as i128 in the array. -impl From<&DecimalArray> for arrow_array::Decimal128Array { +impl From<&DecimalArray> for arrow_array::LargeBinaryArray { fn from(array: &DecimalArray) -> Self { - let max_scale = array - .iter() - .filter_map(|o| o.map(|v| v.scale().unwrap_or(0))) - .max() - .unwrap_or(0) as u32; - let mut builder = arrow_array::builder::Decimal128Builder::with_capacity(array.len()) - .with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)); + let mut builder = + arrow_array::builder::LargeBinaryBuilder::with_capacity(array.len(), array.len() * 8); for value in array.iter() { - builder.append_option(value.map(|d| decimal_to_i128(d, max_scale))); + builder.append_option(value.map(|d| d.to_string())); } builder.finish() } } -fn decimal_to_i128(value: Decimal, scale: u32) -> i128 { - match value { - Decimal::Normalized(mut d) => { - d.rescale(scale); - d.mantissa() - } - Decimal::NaN => i128::MIN + 1, - Decimal::PositiveInf => i128::MAX, - Decimal::NegativeInf => i128::MIN, - } -} +impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray { + type Error = ArrayError; -impl From<&arrow_array::Decimal128Array> for DecimalArray { - fn from(array: &arrow_array::Decimal128Array) -> Self { - assert!(array.scale() >= 0, "todo: support negative scale"); - let from_arrow = |value| { - const NAN: i128 = i128::MIN + 1; - match value { - NAN => Decimal::NaN, - i128::MAX => Decimal::PositiveInf, - i128::MIN => Decimal::NegativeInf, - _ => Decimal::Normalized(rust_decimal::Decimal::from_i128_with_scale( - value, - array.scale() as u32, - )), - } - }; - array.iter().map(|o| o.map(from_arrow)).collect() + fn try_from(array: &arrow_array::LargeBinaryArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + let s = std::str::from_utf8(s) + .map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}")))?; + s.parse() + .map_err(|_| ArrayError::FromArrow(format!("invalid decimal: {s:?}"))) + }) + .transpose() + }) + .try_collect() } } @@ -631,20 +614,12 @@ impl TryFrom<&ListArray> for arrow_array::ListArray { b.append_option(v) }) } - ArrayImpl::Decimal(a) => { - let max_scale = a - .iter() - .filter_map(|o| o.map(|v| v.scale().unwrap_or(0))) - .max() - .unwrap_or(0) as u32; - build( - array, - a, - Decimal128Builder::with_capacity(a.len()) - .with_data_type(arrow_schema::DataType::Decimal128(38, max_scale as i8)), - |b, v| b.append_option(v.map(|d| decimal_to_i128(d, max_scale))), - ) - } + ArrayImpl::Decimal(a) => build( + array, + a, + LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8), + |b, v| b.append_option(v.map(|d| d.to_string())), + ), ArrayImpl::Interval(a) => build( array, a, @@ -842,8 +817,8 @@ mod tests { Some(Decimal::Normalized("123.4".parse().unwrap())), Some(Decimal::Normalized("123.456".parse().unwrap())), ]); - let arrow = arrow_array::Decimal128Array::from(&array); - assert_eq!(DecimalArray::from(&arrow), array); + let arrow = arrow_array::LargeBinaryArray::from(&array); + assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); } #[test] diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 234df130e22dd..5df8f9d3a6b61 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -144,8 +144,7 @@ impl UdfExpression { ); } - let data_chunk = - DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk"); + let data_chunk = DataChunk::try_from(&output)?; let output = data_chunk.uncompact(vis.clone()); let Some(array) = output.columns().first() else { diff --git a/src/udf/python/CHANGELOG.md b/src/udf/python/CHANGELOG.md index e035aab4ebb9e..a20411e69d83e 100644 --- a/src/udf/python/CHANGELOG.md +++ b/src/udf/python/CHANGELOG.md @@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.1] - 2023-12-06 + +### Fixed + +- Fix decimal type output. + +## [0.1.0] - 2023-12-01 + +### Fixed + +- Fix unconstrained decimal type. + ## [0.0.12] - 2023-11-28 ### Changed diff --git a/src/udf/python/pyproject.toml b/src/udf/python/pyproject.toml index 67d17db55dadc..b535355168363 100644 --- a/src/udf/python/pyproject.toml +++ b/src/udf/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "risingwave" -version = "0.0.12" +version = "0.1.1" authors = [{ name = "RisingWave Labs" }] description = "RisingWave Python API" readme = "README.md" diff --git a/src/udf/python/risingwave/test_udf.py b/src/udf/python/risingwave/test_udf.py index e331e12f3a761..595695d6a07a6 100644 --- a/src/udf/python/risingwave/test_udf.py +++ b/src/udf/python/risingwave/test_udf.py @@ -199,7 +199,7 @@ def test_all_types(): pa.array([1], type=pa.int64()), pa.array([1], type=pa.float32()), pa.array([1], type=pa.float64()), - pa.array([10**37], type=pa.decimal128(38)), + pa.array(["12345678901234567890.1234567890"], type=pa.large_binary()), pa.array([datetime.date(2023, 6, 1)], type=pa.date32()), pa.array([datetime.time(1, 2, 3, 456789)], type=pa.time64("us")), pa.array( @@ -229,7 +229,7 @@ def test_all_types(): 1, 1.0, 1.0, - 10**37, + b"12345678901234567890.1234567890", datetime.date(2023, 6, 1), datetime.time(1, 2, 3, 456789), datetime.datetime(2023, 6, 1, 1, 2, 3, 456789), diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index 6443034c9147a..803ab1acbcbfb 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -21,6 +21,7 @@ import json from concurrent.futures import ThreadPoolExecutor import concurrent +from decimal import Decimal import signal @@ -96,6 +97,7 @@ def _process_func(type: pa.DataType, output: bool) -> Callable: if pa.types.is_list(type): func = _process_func(type.value_type, output) return lambda array: [(func(v) if v is not None else None) for v in array] + if pa.types.is_struct(type): funcs = [_process_func(field.type, output) for field in type] if output: @@ -109,11 +111,26 @@ def _process_func(type: pa.DataType, output: bool) -> Callable: (func(v) if v is not None else None) for v, func in zip(map.values(), funcs) ) - if pa.types.is_large_string(type): + + if type.equals(JSONB): if output: return lambda v: json.dumps(v) else: return lambda v: json.loads(v) + + if type.equals(UNCONSTRAINED_DECIMAL): + if output: + + def decimal_to_str(v): + if not isinstance(v, Decimal): + raise ValueError(f"Expected Decimal, got {v}") + # use `f` format to avoid scientific notation, e.g. `1e10` + return format(v, "f").encode("utf-8") + + return decimal_to_str + else: + return lambda v: Decimal(v.decode("utf-8")) + return lambda v: v @@ -416,6 +433,11 @@ def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: return t +# we use `large_binary` to represent unconstrained decimal type +UNCONSTRAINED_DECIMAL = pa.large_binary() +JSONB = pa.large_string() + + def _string_to_data_type(type_str: str): """ Convert a SQL data type string to `pyarrow.DataType`. @@ -437,7 +459,7 @@ def _string_to_data_type(type_str: str): return pa.float64() elif type_str.startswith("DECIMAL") or type_str.startswith("NUMERIC"): if type_str == "DECIMAL" or type_str == "NUMERIC": - return pa.decimal128(38) + return UNCONSTRAINED_DECIMAL rest = type_str[8:-1] # remove "DECIMAL(" and ")" if "," in rest: precision, scale = rest.split(",") @@ -455,7 +477,7 @@ def _string_to_data_type(type_str: str): elif type_str in ("VARCHAR"): return pa.string() elif type_str in ("JSONB"): - return pa.large_string() + return JSONB elif type_str in ("BYTEA"): return pa.binary() elif type_str.startswith("STRUCT"): @@ -499,8 +521,10 @@ def _data_type_to_string(t: pa.DataType) -> str: return "FLOAT4" elif t.equals(pa.float64()): return "FLOAT8" - elif t.equals(pa.decimal128(38)): + elif t.equals(UNCONSTRAINED_DECIMAL): return "DECIMAL" + elif pa.types.is_decimal(t): + return f"DECIMAL({t.precision},{t.scale})" elif t.equals(pa.date32()): return "DATE" elif t.equals(pa.time64("us")): @@ -511,7 +535,7 @@ def _data_type_to_string(t: pa.DataType) -> str: return "INTERVAL" elif t.equals(pa.string()): return "VARCHAR" - elif t.equals(pa.large_string()): + elif t.equals(JSONB): return "JSONB" elif t.equals(pa.binary()): return "BYTEA"