diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BatchWithPartitionData.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BatchWithPartitionData.scala index 02e5ee118db..6d640683c07 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BatchWithPartitionData.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BatchWithPartitionData.scala @@ -340,6 +340,8 @@ object BatchWithPartitionDataUtils { * Splits the input ColumnarBatch into smaller batches, wraps these batches with partition * data, and returns them as a sequence of [[BatchWithPartitionData]]. * + * This function does not take ownership of `batch`, and callers should make sure to close. + * * @note Partition values are merged with the columnar batches lazily by the resulting Iterator * to save GPU memory. * @param batch Input ColumnarBatch. @@ -502,9 +504,10 @@ object BatchWithPartitionDataUtils { throw new SplitAndRetryOOM("GPU OutOfMemory: cannot split input with one row") } // Split the batch into two halves - val cb = batchWithPartData.inputBatch.getColumnarBatch() - splitAndCombineBatchWithPartitionData(cb, splitPartitionData, - batchWithPartData.partitionSchema) + withResource(batchWithPartData.inputBatch.getColumnarBatch()) { cb => + splitAndCombineBatchWithPartitionData(cb, splitPartitionData, + batchWithPartData.partitionSchema) + } } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala index ac4b5d89b47..40d135f1107 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala @@ -78,12 +78,13 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery withResource(buildBatch(getSampleValueData)) { valueBatch => withResource(buildBatch(partCols)) { partBatch => withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch => + GpuColumnVector.incRefCounts(valueBatch) val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch, partRows, partValues, partSchema, maxGpuColumnSizeBytes) withResource(resultBatchIter) { _ => RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId) // Assert that the final count of rows matches expected batch - val rowCounts = resultBatchIter.map(_.numRows()).sum + val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum assert(rowCounts == expectedBatch.numRows()) } }