From 40c020f75428648b2bb9130a035c5abba02fa4ab Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 29 Nov 2023 17:47:31 +0800 Subject: [PATCH] fix(udf): support struct[] type in struct (#13689) Signed-off-by: Runji Wang Co-authored-by: xxchan --- e2e_test/udf/test.py | 8 +- e2e_test/udf/udf.slt | 22 ++-- java/udf-example/pom.xml | 4 +- .../src/main/java/com/example/UdfExample.java | 28 ++++- java/udf/CHANGELOG.md | 22 ++++ java/udf/pom.xml | 8 +- .../com/risingwave/functions/TypeUtils.java | 61 +++++++++-- .../com/risingwave/functions/UdfProducer.java | 103 ++++++++++++++++++ .../com/risingwave/functions/UdfServer.java | 84 -------------- .../risingwave/functions/TestUdfServer.java | 35 +++++- src/common/src/array/arrow.rs | 76 +++++++++---- src/common/src/array/list_array.rs | 11 ++ src/common/src/types/mod.rs | 1 + src/expr/core/src/expr/expr_udf.rs | 8 +- .../core/src/table_function/user_defined.rs | 8 +- src/udf/python/CHANGELOG.md | 11 ++ src/udf/python/pyproject.toml | 2 +- src/udf/python/risingwave/udf.py | 23 +++- 18 files changed, 365 insertions(+), 150 deletions(-) create mode 100644 java/udf/CHANGELOG.md create mode 100644 java/udf/src/main/java/com/risingwave/functions/UdfProducer.java diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 45db54a8113b..6f1aa115bf95 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -117,7 +117,9 @@ def jsonb_array_struct_identity(v: Tuple[List[Any], int]) -> Tuple[List[Any], in ALL_TYPES = "BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB".split( "," -) +) + [ + "STRUCT" +] @udf( @@ -139,6 +141,7 @@ def return_all( varchar, bytea, jsonb, + struct, ): return ( bool, @@ -155,6 +158,7 @@ def return_all( varchar, bytea, jsonb, + struct, ) @@ -177,6 +181,7 @@ def return_all_arrays( varchar, bytea, jsonb, + struct, ): return ( bool, @@ -193,6 +198,7 @@ def return_all_arrays( varchar, bytea, jsonb, + struct, ) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index b6d3161b7d3f..f5ab290a69a8 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -59,13 +59,13 @@ create function jsonb_array_struct_identity(struct) returns as jsonb_array_struct_identity using link 'http://localhost:8815'; statement ok -create function return_all(BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB) -returns struct +create function return_all(BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB,STRUCT) +returns struct> as return_all using link 'http://localhost:8815'; statement ok -create function return_all_arrays(BOOLEAN[],SMALLINT[],INT[],BIGINT[],FLOAT4[],FLOAT8[],DECIMAL[],DATE[],TIME[],TIMESTAMP[],INTERVAL[],VARCHAR[],BYTEA[],JSONB[]) -returns struct +create function return_all_arrays(BOOLEAN[],SMALLINT[],INT[],BIGINT[],FLOAT4[],FLOAT8[],DECIMAL[],DATE[],TIME[],TIMESTAMP[],INTERVAL[],VARCHAR[],BYTEA[],JSONB[],STRUCT[]) +returns struct[]> as return_all_arrays using link 'http://localhost:8815'; query TTTTT rowsort @@ -81,8 +81,8 @@ jsonb_access jsonb, integer jsonb (empty) http://localhost:8815 jsonb_array_identity jsonb[] jsonb[] (empty) http://localhost:8815 jsonb_array_struct_identity struct struct (empty) http://localhost:8815 jsonb_concat jsonb[] jsonb (empty) http://localhost:8815 -return_all boolean, smallint, integer, bigint, real, double precision, numeric, date, time without time zone, timestamp without time zone, interval, character varying, bytea, jsonb struct (empty) http://localhost:8815 -return_all_arrays boolean[], smallint[], integer[], bigint[], real[], double precision[], numeric[], date[], time without time zone[], timestamp without time zone[], interval[], character varying[], bytea[], jsonb[] struct (empty) http://localhost:8815 +return_all boolean, smallint, integer, bigint, real, double precision, numeric, date, time without time zone, timestamp without time zone, interval, character varying, bytea, jsonb, struct struct> (empty) http://localhost:8815 +return_all_arrays boolean[], smallint[], integer[], bigint[], real[], double precision[], numeric[], date[], time without time zone[], timestamp without time zone[], interval[], character varying[], bytea[], jsonb[], struct[] struct[]> (empty) http://localhost:8815 series integer integer (empty) http://localhost:8815 split character varying struct (empty) http://localhost:8815 @@ -149,10 +149,11 @@ select (return_all( interval '1 month 2 days 3 seconds', 'string', 'bytes'::bytea, - '{"key":1}'::jsonb + '{"key":1}'::jsonb, + row(1, 2)::struct )).*; ---- -t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} +t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2) query T select (return_all_arrays( @@ -169,10 +170,11 @@ select (return_all_arrays( array[null, interval '1 month 2 days 3 seconds'], array[null, 'string'], array[null, 'bytes'::bytea], - array[null, '{"key":1}'::jsonb] + array[null, '{"key":1}'::jsonb], + array[null, row(1, 2)::struct] )).*; ---- -{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} +{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"} query I select series(5); diff --git a/java/udf-example/pom.xml b/java/udf-example/pom.xml index c71022ccd298..e47ff0263e61 100644 --- a/java/udf-example/pom.xml +++ b/java/udf-example/pom.xml @@ -15,7 +15,7 @@ com.risingwave risingwave-udf-example - 0.1.0-SNAPSHOT + 0.1.1-SNAPSHOT udf-example https://docs.risingwave.com/docs/current/udf-java @@ -31,7 +31,7 @@ com.risingwave risingwave-udf - 0.1.0-SNAPSHOT + 0.1.1-SNAPSHOT com.google.code.gson diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/java/udf-example/src/main/java/com/example/UdfExample.java index 3f07f457fd81..ef673e61c1d9 100644 --- a/java/udf-example/src/main/java/com/example/UdfExample.java +++ b/java/udf-example/src/main/java/com/example/UdfExample.java @@ -190,6 +190,16 @@ public static class Row { public String str; public byte[] bytes; public @DataTypeHint("JSONB") String jsonb; + public Struct struct; + } + + public static class Struct { + public Integer f1; + public Integer f2; + + public String toString() { + return String.format("(%d, %d)", f1, f2); + } } public Row eval( @@ -206,7 +216,8 @@ public Row eval( PeriodDuration interval, String str, byte[] bytes, - @DataTypeHint("JSONB") String jsonb) { + @DataTypeHint("JSONB") String jsonb, + Struct struct) { var row = new Row(); row.bool = bool; row.i16 = i16; @@ -222,6 +233,7 @@ public Row eval( row.str = str; row.bytes = bytes; row.jsonb = jsonb; + row.struct = struct; return row; } } @@ -242,6 +254,16 @@ public static class Row { public String[] str; public byte[][] bytes; public @DataTypeHint("JSONB[]") String[] jsonb; + public Struct[] struct; + } + + public static class Struct { + public Integer f1; + public Integer f2; + + public String toString() { + return String.format("(%d, %d)", f1, f2); + } } public Row eval( @@ -258,7 +280,8 @@ public Row eval( PeriodDuration[] interval, String[] str, byte[][] bytes, - @DataTypeHint("JSONB[]") String[] jsonb) { + @DataTypeHint("JSONB[]") String[] jsonb, + Struct[] struct) { var row = new Row(); row.bool = bool; row.i16 = i16; @@ -274,6 +297,7 @@ public Row eval( row.str = str; row.bytes = bytes; row.jsonb = jsonb; + row.struct = struct; return row; } } diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md new file mode 100644 index 000000000000..5f46b7ae339b --- /dev/null +++ b/java/udf/CHANGELOG.md @@ -0,0 +1,22 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.1] - 2023-11-28 + +### Added + +- Support struct in struct and struct[] in struct. + +### Changed + +- Bump Arrow version to 14. + +## [0.1.0] - 2023-09-01 + +- Initial release. \ No newline at end of file diff --git a/java/udf/pom.xml b/java/udf/pom.xml index e9ae758f5a88..7e9814b4af41 100644 --- a/java/udf/pom.xml +++ b/java/udf/pom.xml @@ -6,7 +6,7 @@ com.risingwave risingwave-udf jar - 0.1.0-SNAPSHOT + 0.1.1-SNAPSHOT risingwave-java-root @@ -28,12 +28,12 @@ org.apache.arrow arrow-vector - 13.0.0 + 14.0.0 org.apache.arrow flight-core - 13.0.0 + 14.0.0 org.slf4j @@ -55,4 +55,4 @@ - + \ No newline at end of file diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java index 9485ace03e2d..aedba8abf982 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -15,6 +15,7 @@ package com.risingwave.functions; import java.lang.invoke.MethodHandles; +import java.lang.reflect.Array; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; @@ -133,7 +134,7 @@ static Field classToField(Class param, DataTypeHint hint, String name) { var subhint = field.getAnnotation(DataTypeHint.class); fields.add(classToField(field.getType(), subhint, field.getName())); } - return new Field("", FieldType.nullable(new ArrowType.Struct()), fields); + return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); // TODO: more types // throw new IllegalArgumentException("Unsupported type: " + param); } @@ -359,22 +360,57 @@ static void fillVector(FieldVector fieldVector, Object[] values) { } else if (vector.getDataVector() instanceof VarBinaryVector) { TypeUtils.fillListVector( vector, values, (vec, i, val) -> vec.set(i, val)); + } else if (vector.getDataVector() instanceof StructVector) { + // flatten the `values` + var flattenLength = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] == null) { + continue; + } + var len = Array.getLength(values[i]); + vector.startNewValue(i); + vector.endValue(i, len); + flattenLength += len; + } + var flattenValues = new Object[flattenLength]; + var ii = 0; + for (var list : values) { + if (list == null) { + continue; + } + var length = Array.getLength(list); + for (int i = 0; i < length; i++) { + flattenValues[ii++] = Array.get(list, i); + } + } + fillVector(vector.getDataVector(), flattenValues); } else { - throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass()); + throw new IllegalArgumentException( + "Unsupported type: " + vector.getDataVector().getClass()); } } else if (fieldVector instanceof StructVector) { var vector = (StructVector) fieldVector; vector.allocateNew(); var lookup = MethodHandles.lookup(); + // get class of the first non-null value + Class valueClass = null; + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + valueClass = values[i].getClass(); + break; + } + } for (var field : vector.getField().getChildren()) { // extract field from values var subvalues = new Object[values.length]; - if (values.length != 0) { + if (valueClass != null) { try { - var javaField = values[0].getClass().getDeclaredField(field.getName()); + var javaField = valueClass.getDeclaredField(field.getName()); var varHandle = lookup.unreflectVarHandle(javaField); for (int i = 0; i < values.length; i++) { - subvalues[i] = varHandle.get(values[i]); + if (values[i] != null) { + subvalues[i] = varHandle.get(values[i]); + } } } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); @@ -384,7 +420,9 @@ static void fillVector(FieldVector fieldVector, Object[] values) { fillVector(subvector, subvalues); } for (int i = 0; i < values.length; i++) { - vector.setIndexDefined(i); + if (values[i] != null) { + vector.setIndexDefined(i); + } } } else { throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass()); @@ -482,8 +520,17 @@ static Function processFunc0(Field field, Class targetClass) return obj -> ((List) obj).stream().map(subfunc).toArray(String[]::new); } else if (subfield.getType() instanceof ArrowType.Binary) { return obj -> ((List) obj).stream().map(subfunc).toArray(byte[][]::new); + } else if (subfield.getType() instanceof ArrowType.Struct) { + return obj -> { + var list = (List) obj; + Object array = Array.newInstance(targetClass.getComponentType(), list.size()); + for (int i = 0; i < list.size(); i++) { + Array.set(array, i, subfunc.apply(list.get(i))); + } + return array; + }; } - throw new IllegalArgumentException("Unsupported type: " + field.getType()); + throw new IllegalArgumentException("Unsupported type: " + subfield.getType()); } else if (field.getType() instanceof ArrowType.Struct) { // object is org.apache.arrow.vector.util.JsonStringHashMap var subfields = field.getChildren(); diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java new file mode 100644 index 000000000000..71b810726c67 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java @@ -0,0 +1,103 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import org.apache.arrow.flight.*; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class UdfProducer extends NoOpFlightProducer { + + private BufferAllocator allocator; + private HashMap functions = new HashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); + + UdfProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { + UserDefinedFunctionBatch udf; + if (function instanceof ScalarFunction) { + udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator); + } else if (function instanceof TableFunction) { + udf = new TableFunctionBatch((TableFunction) function, this.allocator); + } else { + throw new IllegalArgumentException( + "Unknown function type: " + function.getClass().getName()); + } + if (functions.containsKey(name)) { + throw new IllegalArgumentException("Function already exists: " + name); + } + functions.put(name, udf); + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + try { + var functionName = descriptor.getPath().get(0); + var udf = functions.get(functionName); + if (udf == null) { + throw new IllegalArgumentException("Unknown function: " + functionName); + } + var fields = new ArrayList(); + fields.addAll(udf.getInputSchema().getFields()); + fields.addAll(udf.getOutputSchema().getFields()); + var fullSchema = new Schema(fields); + var inputLen = udf.getInputSchema().getFields().size(); + + return new FlightInfo(fullSchema, descriptor, Collections.emptyList(), 0, inputLen); + } catch (Exception e) { + logger.error("Error occurred during getFlightInfo", e); + throw e; + } + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + try { + var functionName = reader.getDescriptor().getPath().get(0); + logger.debug("call function: " + functionName); + + var udf = this.functions.get(functionName); + try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) { + var loader = new VectorLoader(root); + writer.start(root); + while (reader.next()) { + var outputBatches = udf.evalBatch(reader.getRoot()); + while (outputBatches.hasNext()) { + var outputRoot = outputBatches.next(); + var unloader = new VectorUnloader(outputRoot); + loader.load(unloader.getRecordBatch()); + writer.putNext(); + } + } + writer.completed(); + } + } catch (Exception e) { + logger.error("Error occurred during UDF execution", e); + writer.error(e); + } + } +} diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java index 7d063a8d80d3..8a07ff4ef671 100644 --- a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java +++ b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java @@ -15,17 +15,8 @@ package com.risingwave.functions; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; import org.apache.arrow.flight.*; -import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorLoader; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,78 +79,3 @@ public void close() throws InterruptedException { this.server.close(); } } - -class UdfProducer extends NoOpFlightProducer { - - private BufferAllocator allocator; - private HashMap functions = new HashMap<>(); - private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); - - UdfProducer(BufferAllocator allocator) { - this.allocator = allocator; - } - - void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { - UserDefinedFunctionBatch udf; - if (function instanceof ScalarFunction) { - udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator); - } else if (function instanceof TableFunction) { - udf = new TableFunctionBatch((TableFunction) function, this.allocator); - } else { - throw new IllegalArgumentException( - "Unknown function type: " + function.getClass().getName()); - } - if (functions.containsKey(name)) { - throw new IllegalArgumentException("Function already exists: " + name); - } - functions.put(name, udf); - } - - @Override - public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { - try { - var functionName = descriptor.getPath().get(0); - var udf = functions.get(functionName); - if (udf == null) { - throw new IllegalArgumentException("Unknown function: " + functionName); - } - var fields = new ArrayList(); - fields.addAll(udf.getInputSchema().getFields()); - fields.addAll(udf.getOutputSchema().getFields()); - var fullSchema = new Schema(fields); - var inputLen = udf.getInputSchema().getFields().size(); - - return new FlightInfo(fullSchema, descriptor, Collections.emptyList(), 0, inputLen); - } catch (Exception e) { - logger.error("Error occurred during getFlightInfo", e); - throw e; - } - } - - @Override - public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - try { - var functionName = reader.getDescriptor().getPath().get(0); - logger.debug("call function: " + functionName); - - var udf = this.functions.get(functionName); - try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) { - var loader = new VectorLoader(root); - writer.start(root); - while (reader.next()) { - var outputBatches = udf.evalBatch(reader.getRoot()); - while (outputBatches.hasNext()) { - var outputRoot = outputBatches.next(); - var unloader = new VectorUnloader(outputRoot); - loader.load(unloader.getRecordBatch()); - writer.putNext(); - } - } - writer.completed(); - } - } catch (Exception e) { - logger.error("Error occurred during UDF execution", e); - writer.error(e); - } - } -} diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java index daba383ab561..c72774c641ae 100644 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java @@ -27,7 +27,9 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -104,6 +106,16 @@ public static class Row { public String str; public byte[] bytes; public @DataTypeHint("JSONB") String jsonb; + public Struct struct; + } + + public static class Struct { + public Integer f1; + public Integer f2; + + public String toString() { + return String.format("(%d, %d)", f1, f2); + } } public Row eval( @@ -120,7 +132,8 @@ public Row eval( PeriodDuration interval, String str, byte[] bytes, - @DataTypeHint("JSONB") String jsonb) { + @DataTypeHint("JSONB") String jsonb, + Struct struct) { var row = new Row(); row.bool = bool; row.i16 = i16; @@ -136,6 +149,7 @@ public Row eval( row.str = str; row.bytes = bytes; row.jsonb = jsonb; + row.struct = struct; return row; } } @@ -220,13 +234,28 @@ public void all_types() throws Exception { c13.set(0, "{ key: 1 }".getBytes()); c13.setValueCount(2); - var input = VectorSchemaRoot.of(c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13); + var c14 = + new StructVector( + "", allocator, FieldType.nullable(ArrowType.Struct.INSTANCE), null); + c14.allocateNew(); + var f1 = c14.addOrGet("f1", FieldType.nullable(MinorType.INT.getType()), IntVector.class); + var f2 = c14.addOrGet("f2", FieldType.nullable(MinorType.INT.getType()), IntVector.class); + f1.allocateNew(2); + f2.allocateNew(2); + f1.set(0, 1); + f2.set(0, 2); + c14.setIndexDefined(0); + c14.setValueCount(2); + + var input = + VectorSchemaRoot.of( + c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14); try (var stream = client.call("return_all", input)) { var output = stream.getRoot(); assertTrue(stream.next()); assertEquals( - "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\"}\n{}", + "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":10000000000000000000000000000000000000,\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", output.contentToTSVString().trim()); } } diff --git a/src/common/src/array/arrow.rs b/src/common/src/array/arrow.rs index b2b197532f20..a04584fbce00 100644 --- a/src/common/src/array/arrow.rs +++ b/src/common/src/array/arrow.rs @@ -24,7 +24,7 @@ use itertools::Itertools; use super::*; use crate::types::{Int256, StructType}; -use crate::util::iter_util::{ZipEqDebug, ZipEqFast}; +use crate::util::iter_util::ZipEqFast; /// Converts RisingWave array to Arrow array with the schema. /// This function will try to convert the array if the type is not same with the schema. @@ -197,6 +197,17 @@ impl From<&arrow_schema::Fields> for StructType { } } +impl TryFrom<&StructType> for arrow_schema::Fields { + type Error = ArrayError; + + fn try_from(struct_type: &StructType) -> Result { + struct_type + .iter() + .map(|(name, ty)| Ok(Field::new(name, ty.try_into()?, true))) + .try_collect() + } +} + impl From for DataType { fn from(value: arrow_schema::DataType) -> Self { (&value).into() @@ -204,7 +215,7 @@ impl From for DataType { } impl TryFrom<&DataType> for arrow_schema::DataType { - type Error = String; + type Error = ArrayError; fn try_from(value: &DataType) -> Result { match value { @@ -231,26 +242,34 @@ impl TryFrom<&DataType> for arrow_schema::DataType { struct_type .iter() .map(|(name, ty)| Ok(Field::new(name, ty.try_into()?, true))) - .try_collect::<_, _, String>()?, + .try_collect::<_, _, ArrayError>()?, )), DataType::List(datatype) => Ok(Self::List(Arc::new(Field::new( "item", datatype.as_ref().try_into()?, true, )))), - DataType::Serial => Err("Serial type is not supported to convert to arrow".to_string()), + DataType::Serial => Err(ArrayError::ToArrow( + "Serial type is not supported to convert to arrow".to_string(), + )), } } } impl TryFrom for arrow_schema::DataType { - type Error = String; + type Error = ArrayError; fn try_from(value: DataType) -> Result { (&value).try_into() } } +impl From<&Bitmap> for arrow_buffer::NullBuffer { + fn from(bitmap: &Bitmap) -> Self { + bitmap.iter().collect() + } +} + /// Implement bi-directional `From` between concrete array types. macro_rules! converts { ($ArrayType:ty, $ArrowType:ty) => { @@ -547,8 +566,10 @@ impl From<&arrow_array::Decimal256Array> for Int256Array { } } -impl From<&ListArray> for arrow_array::ListArray { - fn from(array: &ListArray) -> Self { +impl TryFrom<&ListArray> for arrow_array::ListArray { + type Error = ArrayError; + + fn try_from(array: &ListArray) -> Result { use arrow_array::builder::*; fn build( array: &ListArray, @@ -570,7 +591,7 @@ impl From<&ListArray> for arrow_array::ListArray { } builder.finish() } - match &*array.value { + Ok(match &*array.value { ArrayImpl::Int16(a) => build(array, a, Int16Builder::with_capacity(a.len()), |b, v| { b.append_option(v) }), @@ -658,7 +679,21 @@ impl From<&ListArray> for arrow_array::ListArray { |b, v| b.append_option(v.map(|j| j.to_string())), ), ArrayImpl::Serial(_) => todo!("list of serial"), - ArrayImpl::Struct(_) => todo!("list of struct"), + ArrayImpl::Struct(a) => { + let values = Arc::new(arrow_array::StructArray::try_from(a)?); + arrow_array::ListArray::new( + Arc::new(Field::new("item", a.data_type().try_into()?, true)), + arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from( + array + .offsets() + .iter() + .map(|o| *o as i32) + .collect::>(), + )), + values, + Some(array.null_bitmap().into()), + ) + } ArrayImpl::List(_) => todo!("list of list"), ArrayImpl::Bytea(a) => build( array, @@ -666,7 +701,7 @@ impl From<&ListArray> for arrow_array::ListArray { BinaryBuilder::with_capacity(a.len(), a.data().len()), |b, v| b.append_option(v), ), - } + }) } } @@ -689,17 +724,14 @@ impl TryFrom<&StructArray> for arrow_array::StructArray { type Error = ArrayError; fn try_from(array: &StructArray) -> Result { - let struct_data_vector: Vec<(arrow_schema::FieldRef, arrow_array::ArrayRef)> = array - .fields() - .zip_eq_debug(array.data_type().as_struct().iter()) - .map(|(arr, (name, ty))| { - Ok(( - Field::new(name, ty.try_into().map_err(ArrayError::ToArrow)?, true).into(), - arr.as_ref().try_into()?, - )) - }) - .try_collect::<_, _, ArrayError>()?; - Ok(arrow_array::StructArray::from(struct_data_vector)) + Ok(arrow_array::StructArray::new( + array.data_type().as_struct().try_into()?, + array + .fields() + .map(|arr| arr.as_ref().try_into()) + .try_collect::<_, _, ArrayError>()?, + Some(array.null_bitmap().into()), + )) } } @@ -908,7 +940,7 @@ mod tests { #[test] fn list() { let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); - let arrow = arrow_array::ListArray::from(&array); + let arrow = arrow_array::ListArray::try_from(&array).unwrap(); assert_eq!(ListArray::try_from(&arrow).unwrap(), array); } } diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index fcd1648f13b9..2ca83677d4d8 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -273,6 +273,17 @@ impl ListArray { value: Box::new(new_value), }) } + + /// Returns the offsets of this list. + /// + /// # Example + /// ```text + /// list = [[a, b, c], [], NULL, [d], [NULL, f]] + /// offsets = [0, 3, 3, 3, 4, 6] + /// ``` + pub fn offsets(&self) -> &[u32] { + &self.offsets + } } impl FromIterator> for ListArray diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 40e331227c8d..edb0a7d01045 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -457,6 +457,7 @@ impl DataType { pub fn equals_datatype(&self, other: &DataType) -> bool { match (self, other) { (Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2), + (Self::List(d1), Self::List(d2)) => d1.equals_datatype(d2), _ => self == other, } } diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 05cd58b69380..2770f39a682a 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -19,7 +19,7 @@ use std::sync::{Arc, LazyLock, Mutex, Weak}; use arrow_schema::{Field, Fields, Schema}; use await_tree::InstrumentAwait; use cfg_or_panic::cfg_or_panic; -use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::array::{ArrayError, ArrayRef, DataChunk}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; use risingwave_pb::expr::ExprNode; @@ -148,9 +148,9 @@ impl Build for UdfExpression { .map::, _>(|t| { Ok(Field::new( "", - DataType::from(t) - .try_into() - .map_err(risingwave_udf::Error::unsupported)?, + DataType::from(t).try_into().map_err(|e: ArrayError| { + risingwave_udf::Error::unsupported(e.to_string()) + })?, true, )) }) diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index cbb8682ba48a..385ee7a83435 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -18,7 +18,7 @@ use arrow_array::RecordBatch; use arrow_schema::{Field, Fields, Schema, SchemaRef}; use cfg_or_panic::cfg_or_panic; use futures_util::stream; -use risingwave_common::array::{DataChunk, I32Array}; +use risingwave_common::array::{ArrayError, DataChunk, I32Array}; use risingwave_common::bail; use risingwave_udf::ArrowFlightUdfClient; @@ -138,9 +138,9 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result, _>(|t| { Ok(Field::new( "", - DataType::from(t) - .try_into() - .map_err(risingwave_udf::Error::unsupported)?, + DataType::from(t).try_into().map_err(|e: ArrayError| { + risingwave_udf::Error::unsupported(e.to_string()) + })?, true, )) }) diff --git a/src/udf/python/CHANGELOG.md b/src/udf/python/CHANGELOG.md index 3c788857a395..e035aab4ebb9 100644 --- a/src/udf/python/CHANGELOG.md +++ b/src/udf/python/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.0.12] - 2023-11-28 + +### Changed + +- Change the default struct field name to `f{i}`. + +### Fixed + +- Fix parsing nested struct type. + + ## [0.0.11] - 2023-11-06 ### Fixed diff --git a/src/udf/python/pyproject.toml b/src/udf/python/pyproject.toml index de9b245175f9..67d17db55dad 100644 --- a/src/udf/python/pyproject.toml +++ b/src/udf/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "risingwave" -version = "0.0.11" +version = "0.0.12" authors = [{ name = "RisingWave Labs" }] description = "RisingWave Python API" readme = "README.md" diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index ea7476a756d6..6443034c9147 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -459,12 +459,23 @@ def _string_to_data_type(type_str: str): elif type_str in ("BYTEA"): return pa.binary() elif type_str.startswith("STRUCT"): - # extract 'STRUCT' - type_list = type_str[6:].strip("<>") + # extract 'STRUCT, ...>' + type_list = type_str[7:-1] # strip "STRUCT<>" fields = [] - for type_str in type_list.split(","): - type_str = type_str.strip() - fields.append(pa.field("", _string_to_data_type(type_str))) + elements = [] + start = 0 + depth = 0 + for i, c in enumerate(type_list): + if c == "<": + depth += 1 + elif c == ">": + depth -= 1 + elif c == "," and depth == 0: + type_str = type_list[start:i].strip() + fields.append(pa.field("", _string_to_data_type(type_str))) + start = i + 1 + type_str = type_list[start:].strip() + fields.append(pa.field("", _string_to_data_type(type_str))) return pa.struct(fields) raise ValueError(f"Unsupported type: {type_str}") @@ -508,7 +519,7 @@ def _data_type_to_string(t: pa.DataType) -> str: return ( "STRUCT<" + ",".join( - f"field_{i} {_data_type_to_string(field.type)}" + f"f{i+1} {_data_type_to_string(field.type)}" for i, field in enumerate(t) ) + ">"