Skip to content

Commit

Permalink
fix: cherry-pick several udf fixes (#13852)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 7, 2023
1 parent 4ec02b6 commit bdb7aab
Show file tree
Hide file tree
Showing 19 changed files with 257 additions and 211 deletions.
1 change: 1 addition & 0 deletions ci/scripts/build-other.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cd java
mvn -B package -Dmaven.test.skip=true
mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am
mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test
mvn -B test --pl udf
cd ..

echo "--- Build rust binary for java binding integration test"
Expand Down
12 changes: 12 additions & 0 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]:
return dec


@udf(input_types=["FLOAT8"], result_type="DECIMAL")
def float_to_decimal(f: float) -> Decimal:
    """Convert a float to a ``Decimal`` holding the float's exact value.

    ``Decimal(float)`` does not round to the shortest decimal repr; it keeps
    the exact binary value of the float, so inputs such as ``-1e-10`` come
    back with trailing representation-error digits.
    """
    exact_value = Decimal(f)
    return exact_value


@udf(input_types=["DECIMAL", "DECIMAL"], result_type="DECIMAL")
def decimal_add(a: Decimal, b: Decimal) -> Decimal:
    """Return the sum of two decimals."""
    total = a + b
    return total


@udf(input_types=["VARCHAR[]", "INT"], result_type="VARCHAR")
def array_access(list: List[str], idx: int) -> Optional[str]:
if idx == 0 or idx > len(list):
Expand Down Expand Up @@ -212,6 +222,8 @@ def return_all_arrays(
server.add_function(split)
server.add_function(extract_tcp_info)
server.add_function(hex_to_dec)
server.add_function(float_to_decimal)
server.add_function(decimal_add)
server.add_function(array_access)
server.add_function(jsonb_access)
server.add_function(jsonb_concat)
Expand Down
54 changes: 50 additions & 4 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ create function series(int) returns table (x int) as series using link 'http://l
statement ok
create function split(varchar) returns table (word varchar, length int) as split using link 'http://localhost:8815';

statement ok
create function float_to_decimal(float8) returns decimal as float_to_decimal using link 'http://localhost:8815';

statement ok
create function hex_to_dec(varchar) returns decimal as hex_to_dec using link 'http://localhost:8815';

statement ok
create function decimal_add(decimal, decimal) returns decimal as decimal_add using link 'http://localhost:8815';

statement ok
create function array_access(varchar[], int) returns varchar as array_access using link 'http://localhost:8815';

Expand Down Expand Up @@ -72,7 +78,9 @@ query TTTTT rowsort
show functions
----
array_access character varying[], integer character varying (empty) http://localhost:8815
decimal_add numeric, numeric numeric (empty) http://localhost:8815
extract_tcp_info bytea struct<src_ip character varying, dst_ip character varying, src_port smallint, dst_port smallint> (empty) http://localhost:8815
float_to_decimal double precision numeric (empty) http://localhost:8815
gcd integer, integer integer (empty) http://localhost:8815
gcd integer, integer, integer integer (empty) http://localhost:8815
hex_to_dec character varying numeric (empty) http://localhost:8815
Expand Down Expand Up @@ -106,6 +114,16 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
----
233276425899864771438119478

# The float is converted to decimal exactly, so the expected value carries the
# binary representation error of 1e-10 in its trailing digits.
# The result column is a decimal, hence type marker R (as for decimal_add).
query R
select float_to_decimal('-1e-10'::float8);
----
-0.0000000001000000000000000036

query R
select decimal_add(1.11, 2.22);
----
3.33

query T
select array_access(ARRAY['a', 'b', 'c'], 2);
----
Expand Down Expand Up @@ -142,7 +160,7 @@ select (return_all(
1 ::bigint,
1 ::float4,
1 ::float8,
1234567890123456789012345678 ::decimal,
12345678901234567890.12345678 ::decimal,
date '2023-06-01',
time '01:02:03.456789',
timestamp '2023-06-01 01:02:03.456789',
Expand All @@ -153,7 +171,7 @@ select (return_all(
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 1234567890123456789012345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)
t 1 1 1 1 1 12345678901234567890.12345678 2023-06-01 01:02:03.456789 2023-06-01 01:02:03.456789 1 mon 2 days 00:00:03 string \x6279746573 {"key": 1} (1,2)

query T
select (return_all_arrays(
Expand All @@ -163,7 +181,7 @@ select (return_all_arrays(
array[null, 1 ::bigint],
array[null, 1 ::float4],
array[null, 1 ::float8],
array[null, 1234567890123456789012345678 ::decimal],
array[null, 12345678901234567890.12345678 ::decimal],
array[null, date '2023-06-01'],
array[null, time '01:02:03.456789'],
array[null, timestamp '2023-06-01 01:02:03.456789'],
Expand All @@ -174,7 +192,29 @@ select (return_all_arrays(
array[null, row(1, 2)::struct<f1 int, f2 int>]
)).*;
----
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1234567890123456789012345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,12345678901234567890.12345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}

# test large string output
query I
select length((return_all(
null::boolean,
null::smallint,
null::int,
null::bigint,
null::float4,
null::float8,
null::decimal,
null::date,
null::time,
null::timestamp,
null::interval,
repeat('a', 100000)::varchar,
repeat('a', 100000)::bytea,
null::jsonb,
null::struct<f1 int, f2 int>
)).varchar);
----
100000

query I
select series(5);
Expand Down Expand Up @@ -312,9 +352,15 @@ drop function series;
statement ok
drop function split;

statement ok
drop function float_to_decimal;

statement ok
drop function hex_to_dec;

statement ok
drop function decimal_add;

statement ok
drop function array_access;

Expand Down
2 changes: 1 addition & 1 deletion java/udf-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<dependency>
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<version>0.1.1-SNAPSHOT</version>
<version>0.1.3-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
Expand Down
14 changes: 14 additions & 0 deletions java/udf-example/src/main/java/com/example/UdfExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ public class UdfExample {
public static void main(String[] args) throws IOException {
try (var server = new UdfServer("0.0.0.0", 8815)) {
server.addFunction("int_42", new Int42());
server.addFunction("float_to_decimal", new FloatToDecimal());
server.addFunction("sleep", new Sleep());
server.addFunction("gcd", new Gcd());
server.addFunction("gcd3", new Gcd3());
server.addFunction("extract_tcp_info", new ExtractTcpInfo());
server.addFunction("hex_to_dec", new HexToDec());
server.addFunction("decimal_add", new DecimalAdd());
server.addFunction("array_access", new ArrayAccess());
server.addFunction("jsonb_access", new JsonbAccess());
server.addFunction("jsonb_concat", new JsonbConcat());
Expand All @@ -63,6 +65,12 @@ public int eval() {
}
}

/** Scalar UDF that converts a double to a decimal holding its exact binary value. */
public static class FloatToDecimal implements ScalarFunction {
    /**
     * Converts {@code f} to a {@link BigDecimal}.
     *
     * @param f the value to convert; may be null because the parameter is a boxed {@code Double}
     * @return the exact decimal expansion of {@code f}, or null when {@code f} is null
     */
    public BigDecimal eval(Double f) {
        // Without this guard a null argument would be auto-unboxed by the
        // BigDecimal(double) constructor call and throw NullPointerException.
        if (f == null) {
            return null;
        }
        // BigDecimal(double) is exact: it does not round to the shortest decimal repr.
        return new BigDecimal(f);
    }
}

public static class Sleep implements ScalarFunction {
public int eval(int x) {
try {
Expand Down Expand Up @@ -126,6 +134,12 @@ public BigDecimal eval(String hex) {
}
}

/** Scalar UDF that adds two decimals. */
public static class DecimalAdd implements ScalarFunction {
    /**
     * Adds two decimals.
     *
     * @param a left operand; may be null
     * @param b right operand; may be null
     * @return {@code a + b}, or null when either operand is null
     */
    public BigDecimal eval(BigDecimal a, BigDecimal b) {
        // A null operand would otherwise make a.add(b) throw NullPointerException.
        if (a == null || b == null) {
            return null;
        }
        return a.add(b);
    }
}

public static class ArrayAccess implements ScalarFunction {
public String eval(String[] array, int index) {
return array[index - 1];
Expand Down
19 changes: 18 additions & 1 deletion java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 2023-11-28
## [0.1.3] - 2023-12-06

### Fixed

- Fix decimal type output.

## [0.1.2] - 2023-12-04

### Fixed

- Fix index-out-of-bounds error when a string or string list is large.
- Fix memory leak.

## [0.1.1] - 2023-12-03

### Added

Expand All @@ -17,6 +30,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Bump Arrow version to 14.

### Fixed

- Fix unconstrained decimal type.

## [0.1.0] - 2023-09-01

- Initial release.
2 changes: 1 addition & 1 deletion java/udf/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<packaging>jar</packaging>
<version>0.1.1-SNAPSHOT</version>
<version>0.1.3-SNAPSHOT</version>

<parent>
<artifactId>risingwave-java-root</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
MethodHandle methodHandle;
Function<Object, Object>[] processInputs;

ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) {
ScalarFunctionBatch(ScalarFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -38,7 +37,7 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
var outputValues = new Object[batch.getRowCount()];
Expand All @@ -55,7 +54,7 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
}
var outputVector =
TypeUtils.createVector(
this.outputSchema.getFields().get(0), this.allocator, outputValues);
this.outputSchema.getFields().get(0), allocator, outputValues);
var outputBatch = VectorSchemaRoot.of(outputVector);
return Collections.singleton(outputBatch).iterator();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
Function<Object, Object>[] processInputs;
int chunkSize = 1024;

TableFunctionBatch(TableFunction function, BufferAllocator allocator) {
TableFunctionBatch(TableFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -39,7 +38,7 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var outputs = new ArrayList<VectorSchemaRoot>();
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
Expand All @@ -49,10 +48,9 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
() -> {
var fields = this.outputSchema.getFields();
var indexVector =
TypeUtils.createVector(
fields.get(0), this.allocator, indexes.toArray());
TypeUtils.createVector(fields.get(0), allocator, indexes.toArray());
var valueVector =
TypeUtils.createVector(fields.get(1), this.allocator, values.toArray());
TypeUtils.createVector(fields.get(1), allocator, values.toArray());
indexes.clear();
values.clear();
var outputBatch = VectorSchemaRoot.of(indexVector, valueVector);
Expand Down
Loading

0 comments on commit bdb7aab

Please sign in to comment.