Skip to content

Commit

Permalink
fix(udf): support struct[] type in struct (#13689)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: xxchan <[email protected]>
  • Loading branch information
wangrunji0408 and xxchan authored Nov 29, 2023
1 parent 037b9a2 commit 40c020f
Show file tree
Hide file tree
Showing 18 changed files with 365 additions and 150 deletions.
8 changes: 7 additions & 1 deletion e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def jsonb_array_struct_identity(v: Tuple[List[Any], int]) -> Tuple[List[Any], in

ALL_TYPES = "BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB".split(
","
)
) + [
"STRUCT<INT,INT>"
]


@udf(
Expand All @@ -139,6 +141,7 @@ def return_all(
varchar,
bytea,
jsonb,
struct,
):
return (
bool,
Expand All @@ -155,6 +158,7 @@ def return_all(
varchar,
bytea,
jsonb,
struct,
)


Expand All @@ -177,6 +181,7 @@ def return_all_arrays(
varchar,
bytea,
jsonb,
struct,
):
return (
bool,
Expand All @@ -193,6 +198,7 @@ def return_all_arrays(
varchar,
bytea,
jsonb,
struct,
)


Expand Down
22 changes: 12 additions & 10 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ create function jsonb_array_struct_identity(struct<v jsonb[], len int>) returns
as jsonb_array_struct_identity using link 'http://localhost:8815';

statement ok
create function return_all(BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB)
returns struct<bool BOOLEAN, i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT4, f64 FLOAT8, decimal DECIMAL, date DATE, time TIME, timestamp TIMESTAMP, interval INTERVAL, varchar VARCHAR, bytea BYTEA, jsonb JSONB>
create function return_all(BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB,STRUCT<f1 INT,f2 INT>)
returns struct<bool BOOLEAN, i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT4, f64 FLOAT8, decimal DECIMAL, date DATE, time TIME, timestamp TIMESTAMP, interval INTERVAL, varchar VARCHAR, bytea BYTEA, jsonb JSONB, struct STRUCT<f1 INT,f2 INT>>
as return_all using link 'http://localhost:8815';

statement ok
create function return_all_arrays(BOOLEAN[],SMALLINT[],INT[],BIGINT[],FLOAT4[],FLOAT8[],DECIMAL[],DATE[],TIME[],TIMESTAMP[],INTERVAL[],VARCHAR[],BYTEA[],JSONB[])
returns struct<bool BOOLEAN[], i16 SMALLINT[], i32 INT[], i64 BIGINT[], f32 FLOAT4[], f64 FLOAT8[], decimal DECIMAL[], date DATE[], time TIME[], timestamp TIMESTAMP[], interval INTERVAL[], varchar VARCHAR[], bytea BYTEA[], jsonb JSONB[]>
create function return_all_arrays(BOOLEAN[],SMALLINT[],INT[],BIGINT[],FLOAT4[],FLOAT8[],DECIMAL[],DATE[],TIME[],TIMESTAMP[],INTERVAL[],VARCHAR[],BYTEA[],JSONB[],STRUCT<f1 INT,f2 INT>[])
returns struct<bool BOOLEAN[], i16 SMALLINT[], i32 INT[], i64 BIGINT[], f32 FLOAT4[], f64 FLOAT8[], decimal DECIMAL[], date DATE[], time TIME[], timestamp TIMESTAMP[], interval INTERVAL[], varchar VARCHAR[], bytea BYTEA[], jsonb JSONB[], struct STRUCT<f1 INT,f2 INT>[]>
as return_all_arrays using link 'http://localhost:8815';

query TTTTT rowsort
Expand All @@ -81,8 +81,8 @@ jsonb_access jsonb, integer jsonb (empty) http://localhost:8815
jsonb_array_identity jsonb[] jsonb[] (empty) http://localhost:8815
jsonb_array_struct_identity struct<v jsonb[], len integer> struct<v jsonb[], len integer> (empty) http://localhost:8815
jsonb_concat jsonb[] jsonb (empty) http://localhost:8815
return_all boolean, smallint, integer, bigint, real, double precision, numeric, date, time without time zone, timestamp without time zone, interval, character varying, bytea, jsonb struct<bool boolean, i16 smallint, i32 integer, i64 bigint, f32 real, f64 double precision, decimal numeric, date date, time time without time zone, timestamp timestamp without time zone, interval interval, varchar character varying, bytea bytea, jsonb jsonb> (empty) http://localhost:8815
return_all_arrays boolean[], smallint[], integer[], bigint[], real[], double precision[], numeric[], date[], time without time zone[], timestamp without time zone[], interval[], character varying[], bytea[], jsonb[] struct<bool boolean[], i16 smallint[], i32 integer[], i64 bigint[], f32 real[], f64 double precision[], decimal numeric[], date date[], time time without time zone[], timestamp timestamp without time zone[], interval interval[], varchar character varying[], bytea bytea[], jsonb jsonb[]> (empty) http://localhost:8815
return_all boolean, smallint, integer, bigint, real, double precision, numeric, date, time without time zone, timestamp without time zone, interval, character varying, bytea, jsonb, struct<f1 integer, f2 integer> struct<bool boolean, i16 smallint, i32 integer, i64 bigint, f32 real, f64 double precision, decimal numeric, date date, time time without time zone, timestamp timestamp without time zone, interval interval, varchar character varying, bytea bytea, jsonb jsonb, struct struct<f1 integer, f2 integer>> (empty) http://localhost:8815
return_all_arrays boolean[], smallint[], integer[], bigint[], real[], double precision[], numeric[], date[], time without time zone[], timestamp without time zone[], interval[], character varying[], bytea[], jsonb[], struct<f1 integer, f2 integer>[] struct<bool boolean[], i16 smallint[], i32 integer[], i64 bigint[], f32 real[], f64 double precision[], decimal numeric[], date date[], time time without time zone[], timestamp timestamp without time zone[], interval interval[], varchar character varying[], bytea bytea[], jsonb jsonb[], struct struct<f1 integer, f2 integer>[]> (empty) http://localhost:8815
series integer integer (empty) http://localhost:8815
split character varying struct<word character varying, length integer> (empty) http://localhost:8815

Expand Down Expand Up @@ -149,10 +149,11 @@ select (return_all(
interval '1 month 2 days 3 seconds',
'string',
'bytes'::bytea,
'{"key":1}'::jsonb
'{"key":1}'::jsonb,
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1}
t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)

query T
select (return_all_arrays(
Expand All @@ -169,10 +170,11 @@ select (return_all_arrays(
array[null, interval '1 month 2 days 3 seconds'],
array[null, 'string'],
array[null, 'bytes'::bytea],
array[null, '{"key":1}'::jsonb]
array[null, '{"key":1}'::jsonb],
array[null, row(1, 2)::struct<f1 int, f2 int>]
)).*;
----
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"}
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}

query I
select series(5);
Expand Down
4 changes: 2 additions & 2 deletions java/udf-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf-example</artifactId>
<version>0.1.0-SNAPSHOT</version>
<version>0.1.1-SNAPSHOT</version>

<name>udf-example</name>
<url>https://docs.risingwave.com/docs/current/udf-java</url>
Expand All @@ -31,7 +31,7 @@
<dependency>
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<version>0.1.0-SNAPSHOT</version>
<version>0.1.1-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
Expand Down
28 changes: 26 additions & 2 deletions java/udf-example/src/main/java/com/example/UdfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ public static class Row {
public String str;
public byte[] bytes;
public @DataTypeHint("JSONB") String jsonb;
public Struct struct;
}

/**
 * Two-field holder used as the {@code struct} member of {@code Row} and as the
 * last parameter of {@code eval} in this example.
 *
 * <p>NOTE(review): the field names {@code f1}/{@code f2} presumably have to match
 * the SQL-side {@code STRUCT<f1 INT, f2 INT>} declaration — confirm against the
 * function's CREATE FUNCTION statement.
 */
public static class Struct {
    /** First integer field; boxed, so it can be null. */
    public Integer f1;
    /** Second integer field; boxed, so it can be null. */
    public Integer f2;

    /**
     * Renders the value as {@code (f1, f2)}.
     *
     * @return the two fields in parentheses separated by ", "; a null field is
     *     rendered as the text {@code null}
     */
    @Override
    public String toString() {
        // Plain concatenation goes through Integer.toString, so the digits stay
        // ASCII whatever the JVM default locale is. String.format("%d") would
        // localise them. Null fields still print as "null".
        return "(" + f1 + ", " + f2 + ")";
    }
}

public Row eval(
Expand All @@ -206,7 +216,8 @@ public Row eval(
PeriodDuration interval,
String str,
byte[] bytes,
@DataTypeHint("JSONB") String jsonb) {
@DataTypeHint("JSONB") String jsonb,
Struct struct) {
var row = new Row();
row.bool = bool;
row.i16 = i16;
Expand All @@ -222,6 +233,7 @@ public Row eval(
row.str = str;
row.bytes = bytes;
row.jsonb = jsonb;
row.struct = struct;
return row;
}
}
Expand All @@ -242,6 +254,16 @@ public static class Row {
public String[] str;
public byte[][] bytes;
public @DataTypeHint("JSONB[]") String[] jsonb;
public Struct[] struct;
}

/**
 * Two-field holder used as the element type of the {@code Struct[] struct}
 * member of {@code Row} and of the last parameter of {@code eval} in this
 * array example.
 *
 * <p>NOTE(review): the field names {@code f1}/{@code f2} presumably have to match
 * the SQL-side {@code STRUCT<f1 INT, f2 INT>[]} declaration — confirm against the
 * function's CREATE FUNCTION statement.
 */
public static class Struct {
    /** First integer field; boxed, so it can be null. */
    public Integer f1;
    /** Second integer field; boxed, so it can be null. */
    public Integer f2;

    /**
     * Renders the value as {@code (f1, f2)}.
     *
     * @return the two fields in parentheses separated by ", "; a null field is
     *     rendered as the text {@code null}
     */
    @Override
    public String toString() {
        // Plain concatenation goes through Integer.toString, so the digits stay
        // ASCII whatever the JVM default locale is. String.format("%d") would
        // localise them. Null fields still print as "null".
        return "(" + f1 + ", " + f2 + ")";
    }
}

public Row eval(
Expand All @@ -258,7 +280,8 @@ public Row eval(
PeriodDuration[] interval,
String[] str,
byte[][] bytes,
@DataTypeHint("JSONB[]") String[] jsonb) {
@DataTypeHint("JSONB[]") String[] jsonb,
Struct[] struct) {
var row = new Row();
row.bool = bool;
row.i16 = i16;
Expand All @@ -274,6 +297,7 @@ public Row eval(
row.str = str;
row.bytes = bytes;
row.jsonb = jsonb;
row.struct = struct;
return row;
}
}
Expand Down
22 changes: 22 additions & 0 deletions java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.1.1] - 2023-11-28

### Added

- Support nested struct fields (`struct` in `struct`) and struct array fields (`struct[]` in `struct`).

### Changed

- Bump the Apache Arrow dependencies (`arrow-vector`, `flight-core`) from 13.0.0 to 14.0.0.

## [0.1.0] - 2023-09-01

- Initial release.
8 changes: 4 additions & 4 deletions java/udf/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<packaging>jar</packaging>
<version>0.1.0-SNAPSHOT</version>
<version>0.1.1-SNAPSHOT</version>

<parent>
<artifactId>risingwave-java-root</artifactId>
Expand All @@ -28,12 +28,12 @@
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>13.0.0</version>
<version>14.0.0</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>flight-core</artifactId>
<version>13.0.0</version>
<version>14.0.0</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
Expand All @@ -55,4 +55,4 @@
</extension>
</extensions>
</build>
</project>
</project>
61 changes: 54 additions & 7 deletions java/udf/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.risingwave.functions;

import java.lang.invoke.MethodHandles;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
Expand Down Expand Up @@ -133,7 +134,7 @@ static Field classToField(Class<?> param, DataTypeHint hint, String name) {
var subhint = field.getAnnotation(DataTypeHint.class);
fields.add(classToField(field.getType(), subhint, field.getName()));
}
return new Field("", FieldType.nullable(new ArrowType.Struct()), fields);
return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields);
// TODO: more types
// throw new IllegalArgumentException("Unsupported type: " + param);
}
Expand Down Expand Up @@ -359,22 +360,57 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
} else if (vector.getDataVector() instanceof VarBinaryVector) {
TypeUtils.<VarBinaryVector, byte[]>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof StructVector) {
// flatten the `values`
var flattenLength = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] == null) {
continue;
}
var len = Array.getLength(values[i]);
vector.startNewValue(i);
vector.endValue(i, len);
flattenLength += len;
}
var flattenValues = new Object[flattenLength];
var ii = 0;
for (var list : values) {
if (list == null) {
continue;
}
var length = Array.getLength(list);
for (int i = 0; i < length; i++) {
flattenValues[ii++] = Array.get(list, i);
}
}
fillVector(vector.getDataVector(), flattenValues);
} else {
throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass());
throw new IllegalArgumentException(
"Unsupported type: " + vector.getDataVector().getClass());
}
} else if (fieldVector instanceof StructVector) {
var vector = (StructVector) fieldVector;
vector.allocateNew();
var lookup = MethodHandles.lookup();
// get class of the first non-null value
Class<?> valueClass = null;
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
valueClass = values[i].getClass();
break;
}
}
for (var field : vector.getField().getChildren()) {
// extract field from values
var subvalues = new Object[values.length];
if (values.length != 0) {
if (valueClass != null) {
try {
var javaField = values[0].getClass().getDeclaredField(field.getName());
var javaField = valueClass.getDeclaredField(field.getName());
var varHandle = lookup.unreflectVarHandle(javaField);
for (int i = 0; i < values.length; i++) {
subvalues[i] = varHandle.get(values[i]);
if (values[i] != null) {
subvalues[i] = varHandle.get(values[i]);
}
}
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
Expand All @@ -384,7 +420,9 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
fillVector(subvector, subvalues);
}
for (int i = 0; i < values.length; i++) {
vector.setIndexDefined(i);
if (values[i] != null) {
vector.setIndexDefined(i);
}
}
} else {
throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass());
Expand Down Expand Up @@ -482,8 +520,17 @@ static Function<Object, Object> processFunc0(Field field, Class<?> targetClass)
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
} else if (subfield.getType() instanceof ArrowType.Binary) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(byte[][]::new);
} else if (subfield.getType() instanceof ArrowType.Struct) {
return obj -> {
var list = (List<?>) obj;
Object array = Array.newInstance(targetClass.getComponentType(), list.size());
for (int i = 0; i < list.size(); i++) {
Array.set(array, i, subfunc.apply(list.get(i)));
}
return array;
};
}
throw new IllegalArgumentException("Unsupported type: " + field.getType());
throw new IllegalArgumentException("Unsupported type: " + subfield.getType());
} else if (field.getType() instanceof ArrowType.Struct) {
// object is org.apache.arrow.vector.util.JsonStringHashMap
var subfields = field.getChildren();
Expand Down
Loading

0 comments on commit 40c020f

Please sign in to comment.