From 0ea72527d4349207ffefa0a2b328df808f80088c Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 4 Dec 2023 21:55:33 +0800 Subject: [PATCH] fix(udf): fix memory leak in Java UDF (#13789) Signed-off-by: Runji Wang --- java/udf/CHANGELOG.md | 1 + .../functions/ScalarFunctionBatch.java | 7 +++--- .../functions/TableFunctionBatch.java | 10 +++----- .../com/risingwave/functions/UdfProducer.java | 25 +++++++++++-------- .../functions/UserDefinedFunctionBatch.java | 5 ++-- 5 files changed, 26 insertions(+), 22 deletions(-) diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md index 7766206a0c9fd..c64dbd1427737 100644 --- a/java/udf/CHANGELOG.md +++ b/java/udf/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fix index-out-of-bound error when string or string list is large. +- Fix memory leak. ## [0.1.1] - 2023-12-03 diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java index 95b0dede5d99b..47a43a8dc2132 100644 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java @@ -27,9 +27,8 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch { MethodHandle methodHandle; Function[] processInputs; - ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) { + ScalarFunctionBatch(ScalarFunction function) { this.function = function; - this.allocator = allocator; var method = Reflection.getEvalMethod(function); this.methodHandle = Reflection.getMethodHandle(method); this.inputSchema = TypeUtils.methodToInputSchema(method); @@ -38,7 +37,7 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch { } @Override - Iterator evalBatch(VectorSchemaRoot batch) { + Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { var row = new Object[batch.getSchema().getFields().size() + 1]; row[0] = this.function; var outputValues = new Object[batch.getRowCount()]; @@ -55,7 +54,7 @@ Iterator evalBatch(VectorSchemaRoot batch) { } var outputVector = TypeUtils.createVector( - this.outputSchema.getFields().get(0), this.allocator, outputValues); + this.outputSchema.getFields().get(0), allocator, outputValues); var outputBatch = VectorSchemaRoot.of(outputVector); return Collections.singleton(outputBatch).iterator(); } diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java index 5dbe5db470dfd..1480b8de2b9f2 100644 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java @@ -28,9 +28,8 @@ class TableFunctionBatch extends UserDefinedFunctionBatch { Function[] processInputs; int chunkSize = 1024; - TableFunctionBatch(TableFunction function, BufferAllocator allocator) { + TableFunctionBatch(TableFunction function) { this.function = function; - this.allocator = allocator; var method = Reflection.getEvalMethod(function); this.methodHandle = Reflection.getMethodHandle(method); this.inputSchema = TypeUtils.methodToInputSchema(method); @@ -39,7 +38,7 @@ class TableFunctionBatch extends UserDefinedFunctionBatch { } @Override - Iterator evalBatch(VectorSchemaRoot batch) { + Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { var outputs = new ArrayList(); var row = new Object[batch.getSchema().getFields().size() + 1]; row[0] = this.function; @@ -49,10 +48,9 @@ Iterator evalBatch(VectorSchemaRoot batch) { () -> { var fields = this.outputSchema.getFields(); var indexVector = - TypeUtils.createVector( - fields.get(0), this.allocator, indexes.toArray()); + TypeUtils.createVector(fields.get(0), allocator, indexes.toArray()); var valueVector = - TypeUtils.createVector(fields.get(1), this.allocator, values.toArray()); + TypeUtils.createVector(fields.get(1), allocator, values.toArray()); indexes.clear(); values.clear(); var outputBatch = VectorSchemaRoot.of(indexVector, valueVector); diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java index 71b810726c67c..9f73ab8f9b86c 100644 --- a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java +++ b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java @@ -40,9 +40,9 @@ class UdfProducer extends NoOpFlightProducer { void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { UserDefinedFunctionBatch udf; if (function instanceof ScalarFunction) { - udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator); + udf = new ScalarFunctionBatch((ScalarFunction) function); } else if (function instanceof TableFunction) { - udf = new TableFunctionBatch((TableFunction) function, this.allocator); + udf = new TableFunctionBatch((TableFunction) function); } else { throw new IllegalArgumentException( "Unknown function type: " + function.getClass().getName()); @@ -76,21 +76,26 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor @Override public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - try { + try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) { var functionName = reader.getDescriptor().getPath().get(0); logger.debug("call function: " + functionName); var udf = this.functions.get(functionName); - try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) { + try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) { var loader = new VectorLoader(root); writer.start(root); while (reader.next()) { - var outputBatches = udf.evalBatch(reader.getRoot()); - while (outputBatches.hasNext()) { - var outputRoot = outputBatches.next(); - var unloader = new VectorUnloader(outputRoot); - loader.load(unloader.getRecordBatch()); - writer.putNext(); + try (var input = reader.getRoot()) { + var outputBatches = udf.evalBatch(input, allocator); + while (outputBatches.hasNext()) { + try (var outputRoot = outputBatches.next()) { + var unloader = new VectorUnloader(outputRoot); + try (var outputBatch = unloader.getRecordBatch()) { + loader.load(outputBatch); + } + } + writer.putNext(); + } } } writer.completed(); diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java index d271ac1245ee4..a32aa85d509db 100644 --- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java +++ b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java @@ -28,7 +28,6 @@ abstract class UserDefinedFunctionBatch { protected Schema inputSchema; protected Schema outputSchema; - protected BufferAllocator allocator; /** Get the input schema of the function. */ Schema getInputSchema() { @@ -44,9 +43,11 @@ Schema getOutputSchema() { * Evaluate the function by processing a batch of input data. * * @param batch the input data batch to process + * @param allocator the allocator to use for allocating output data * @return an iterator over the output data batches */ - abstract Iterator evalBatch(VectorSchemaRoot batch); + abstract Iterator evalBatch( + VectorSchemaRoot batch, BufferAllocator allocator); } /** Utility class for reflection. */