Skip to content

Commit

Permalink
fix(udf): fix memory leak in Java UDF (risingwavelabs#13789)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 4, 2023
1 parent 4a51bd7 commit 4297914
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 22 deletions.
1 change: 1 addition & 0 deletions java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fix index-out-of-bound error when string or string list is large.
- Fix memory leak.

## [0.1.1] - 2023-12-03

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
MethodHandle methodHandle;
Function<Object, Object>[] processInputs;

ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) {
ScalarFunctionBatch(ScalarFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -38,7 +37,7 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
var outputValues = new Object[batch.getRowCount()];
Expand All @@ -55,7 +54,7 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
}
var outputVector =
TypeUtils.createVector(
this.outputSchema.getFields().get(0), this.allocator, outputValues);
this.outputSchema.getFields().get(0), allocator, outputValues);
var outputBatch = VectorSchemaRoot.of(outputVector);
return Collections.singleton(outputBatch).iterator();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
Function<Object, Object>[] processInputs;
int chunkSize = 1024;

TableFunctionBatch(TableFunction function, BufferAllocator allocator) {
TableFunctionBatch(TableFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -39,7 +38,7 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var outputs = new ArrayList<VectorSchemaRoot>();
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
Expand All @@ -49,10 +48,9 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
() -> {
var fields = this.outputSchema.getFields();
var indexVector =
TypeUtils.createVector(
fields.get(0), this.allocator, indexes.toArray());
TypeUtils.createVector(fields.get(0), allocator, indexes.toArray());
var valueVector =
TypeUtils.createVector(fields.get(1), this.allocator, values.toArray());
TypeUtils.createVector(fields.get(1), allocator, values.toArray());
indexes.clear();
values.clear();
var outputBatch = VectorSchemaRoot.of(indexVector, valueVector);
Expand Down
25 changes: 15 additions & 10 deletions java/udf/src/main/java/com/risingwave/functions/UdfProducer.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class UdfProducer extends NoOpFlightProducer {
void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException {
UserDefinedFunctionBatch udf;
if (function instanceof ScalarFunction) {
udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator);
udf = new ScalarFunctionBatch((ScalarFunction) function);
} else if (function instanceof TableFunction) {
udf = new TableFunctionBatch((TableFunction) function, this.allocator);
udf = new TableFunctionBatch((TableFunction) function);
} else {
throw new IllegalArgumentException(
"Unknown function type: " + function.getClass().getName());
Expand Down Expand Up @@ -76,21 +76,26 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor

@Override
public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
try {
try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) {
var functionName = reader.getDescriptor().getPath().get(0);
logger.debug("call function: " + functionName);

var udf = this.functions.get(functionName);
try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) {
try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) {
var loader = new VectorLoader(root);
writer.start(root);
while (reader.next()) {
var outputBatches = udf.evalBatch(reader.getRoot());
while (outputBatches.hasNext()) {
var outputRoot = outputBatches.next();
var unloader = new VectorUnloader(outputRoot);
loader.load(unloader.getRecordBatch());
writer.putNext();
try (var input = reader.getRoot()) {
var outputBatches = udf.evalBatch(input, allocator);
while (outputBatches.hasNext()) {
try (var outputRoot = outputBatches.next()) {
var unloader = new VectorUnloader(outputRoot);
try (var outputBatch = unloader.getRecordBatch()) {
loader.load(outputBatch);
}
}
writer.putNext();
}
}
}
writer.completed();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
abstract class UserDefinedFunctionBatch {
protected Schema inputSchema;
protected Schema outputSchema;
protected BufferAllocator allocator;

/** Get the input schema of the function. */
Schema getInputSchema() {
Expand All @@ -44,9 +43,11 @@ Schema getOutputSchema() {
* Evaluate the function by processing a batch of input data.
*
* @param batch the input data batch to process
* @param allocator the allocator to use for allocating output data
* @return an iterator over the output data batches
*/
abstract Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch);
abstract Iterator<VectorSchemaRoot> evalBatch(
VectorSchemaRoot batch, BufferAllocator allocator);
}

/** Utility class for reflection. */
Expand Down

0 comments on commit 4297914

Please sign in to comment.