Skip to content

Commit

Permalink
fix(udf): fix index-out-of-bound error when string or string list is …
Browse files Browse the repository at this point in the history
…large. (#13781)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 4, 2023
1 parent e41b62f commit 2461651
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 107 deletions.
22 changes: 22 additions & 0 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,28 @@ select (return_all_arrays(
----
{NULL,t} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,1} {NULL,12345678901234567890.12345678} {NULL,2023-06-01} {NULL,01:02:03.456789} {NULL,"2023-06-01 01:02:03.456789"} {NULL,"1 mon 2 days 00:00:03"} {NULL,string} {NULL,"\\x6279746573"} {NULL,"{\"key\": 1}"} {NULL,"(1,2)"}

# test large string output
query I
select length((return_all(
null::boolean,
null::smallint,
null::int,
null::bigint,
null::float4,
null::float8,
null::decimal,
null::date,
null::time,
null::timestamp,
null::interval,
repeat('a', 100000)::varchar,
repeat('a', 100000)::bytea,
null::jsonb,
null::struct<f1 int, f2 int>
)).varchar);
----
100000

query I
select series(5);
----
Expand Down
2 changes: 1 addition & 1 deletion java/udf-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<dependency>
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<version>0.1.1-SNAPSHOT</version>
<version>0.1.2-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
Expand Down
8 changes: 7 additions & 1 deletion java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 2023-12-01
## [0.1.2] - 2023-12-04

### Fixed

- Fix index-out-of-bound error when string or string list is large.

## [0.1.1] - 2023-12-03

### Added

Expand Down
2 changes: 1 addition & 1 deletion java/udf/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<groupId>com.risingwave</groupId>
<artifactId>risingwave-udf</artifactId>
<packaging>jar</packaging>
<version>0.1.1-SNAPSHOT</version>
<version>0.1.2-SNAPSHOT</version>

<parent>
<artifactId>risingwave-java-root</artifactId>
Expand Down
146 changes: 42 additions & 104 deletions java/udf/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,41 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
}
} else if (fieldVector instanceof VarCharVector) {
var vector = (VarCharVector) fieldVector;
vector.allocateNew(values.length);
int totalBytes = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
totalBytes += ((String) values[i]).length();
}
}
vector.allocateNew(totalBytes, values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, ((String) values[i]).getBytes());
}
}
} else if (fieldVector instanceof LargeVarCharVector) {
var vector = (LargeVarCharVector) fieldVector;
vector.allocateNew(values.length);
int totalBytes = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
totalBytes += ((String) values[i]).length();
}
}
vector.allocateNew(totalBytes, values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, ((String) values[i]).getBytes());
}
}
} else if (fieldVector instanceof VarBinaryVector) {
var vector = (VarBinaryVector) fieldVector;
vector.allocateNew(values.length);
int totalBytes = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
totalBytes += ((byte[]) values[i]).length;
}
}
vector.allocateNew(totalBytes, values.length);
for (int i = 0; i < values.length; i++) {
if (values[i] != null) {
vector.set(i, (byte[]) values[i]);
Expand All @@ -311,83 +329,30 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
} else if (fieldVector instanceof ListVector) {
var vector = (ListVector) fieldVector;
vector.allocateNew();
if (vector.getDataVector() instanceof BitVector) {
TypeUtils.<BitVector, Boolean>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val ? 1 : 0));
} else if (vector.getDataVector() instanceof SmallIntVector) {
TypeUtils.<SmallIntVector, Short>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof IntVector) {
TypeUtils.<IntVector, Integer>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof BigIntVector) {
TypeUtils.<BigIntVector, Long>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof Float4Vector) {
TypeUtils.<Float4Vector, Float>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof Float8Vector) {
TypeUtils.<Float8Vector, Double>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof LargeVarBinaryVector) {
TypeUtils.<LargeVarBinaryVector, BigDecimal>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.toString().getBytes()));
} else if (vector.getDataVector() instanceof DateDayVector) {
TypeUtils.<DateDayVector, LocalDate>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, (int) val.toEpochDay()));
} else if (vector.getDataVector() instanceof TimeMicroVector) {
TypeUtils.<TimeMicroVector, LocalTime>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.toNanoOfDay() / 1000));
} else if (vector.getDataVector() instanceof TimeStampMicroVector) {
TypeUtils.<TimeStampMicroVector, LocalDateTime>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, timestampToMicros(val)));
} else if (vector.getDataVector() instanceof IntervalMonthDayNanoVector) {
TypeUtils.<IntervalMonthDayNanoVector, PeriodDuration>fillListVector(
vector,
values,
(vec, i, val) -> {
var months = (int) val.getPeriod().toTotalMonths();
var days = val.getPeriod().getDays();
var nanos = val.getDuration().toNanos();
vec.set(i, months, days, nanos);
});
} else if (vector.getDataVector() instanceof VarCharVector) {
TypeUtils.<VarCharVector, String>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.getBytes()));
} else if (vector.getDataVector() instanceof LargeVarCharVector) {
TypeUtils.<LargeVarCharVector, String>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val.getBytes()));
} else if (vector.getDataVector() instanceof VarBinaryVector) {
TypeUtils.<VarBinaryVector, byte[]>fillListVector(
vector, values, (vec, i, val) -> vec.set(i, val));
} else if (vector.getDataVector() instanceof StructVector) {
// flatten the `values`
var flattenLength = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] == null) {
continue;
}
var len = Array.getLength(values[i]);
vector.startNewValue(i);
vector.endValue(i, len);
flattenLength += len;
// flatten the `values`
var flattenLength = 0;
for (int i = 0; i < values.length; i++) {
if (values[i] == null) {
continue;
}
var flattenValues = new Object[flattenLength];
var ii = 0;
for (var list : values) {
if (list == null) {
continue;
}
var length = Array.getLength(list);
for (int i = 0; i < length; i++) {
flattenValues[ii++] = Array.get(list, i);
}
var len = Array.getLength(values[i]);
vector.startNewValue(i);
vector.endValue(i, len);
flattenLength += len;
}
var flattenValues = new Object[flattenLength];
var ii = 0;
for (var list : values) {
if (list == null) {
continue;
}
var length = Array.getLength(list);
for (int i = 0; i < length; i++) {
flattenValues[ii++] = Array.get(list, i);
}
fillVector(vector.getDataVector(), flattenValues);
} else {
throw new IllegalArgumentException(
"Unsupported type: " + vector.getDataVector().getClass());
}
// fill the inner vector
fillVector(vector.getDataVector(), flattenValues);
} else if (fieldVector instanceof StructVector) {
var vector = (StructVector) fieldVector;
vector.allocateNew();
Expand Down Expand Up @@ -430,33 +395,6 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
fieldVector.setValueCount(values.length);
}

@FunctionalInterface
interface TriFunction<T, U, V> {
void apply(T t, U u, V v);
}

@SuppressWarnings("unchecked")
static <V extends FieldVector, T> void fillListVector(
ListVector vector, Object[] values, TriFunction<V, Integer, T> set) {
var innerVector = (V) vector.getDataVector();
int ii = 0;
for (int i = 0; i < values.length; i++) {
var array = (T[]) values[i];
if (array == null) {
continue;
}
vector.startNewValue(i);
for (T v : array) {
if (v == null) {
innerVector.setNull(ii++);
} else {
set.apply(innerVector, ii++, v);
}
}
vector.endValue(i, array.length);
}
}

static long timestampToMicros(LocalDateTime timestamp) {
var date = timestamp.toLocalDate().toEpochDay();
var time = timestamp.toLocalTime().toNanoOfDay();
Expand Down

0 comments on commit 2461651

Please sign in to comment.