Skip to content

Commit

Permalink
Fixes leaks in BatchWithPartitionData and its suite
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Bellina <[email protected]>
  • Loading branch information
abellina committed Nov 15, 2023
1 parent c1c4708 commit 4a0f489
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ object BatchWithPartitionDataUtils {
* Splits the input ColumnarBatch into smaller batches, wraps these batches with partition
* data, and returns them as a sequence of [[BatchWithPartitionData]].
*
 * This function does not take ownership of `batch`; callers are responsible for closing it.
*
* @note Partition values are merged with the columnar batches lazily by the resulting Iterator
* to save GPU memory.
* @param batch Input ColumnarBatch.
Expand Down Expand Up @@ -502,9 +504,10 @@ object BatchWithPartitionDataUtils {
throw new SplitAndRetryOOM("GPU OutOfMemory: cannot split input with one row")
}
// Split the batch into two halves
val cb = batchWithPartData.inputBatch.getColumnarBatch()
splitAndCombineBatchWithPartitionData(cb, splitPartitionData,
batchWithPartData.partitionSchema)
withResource(batchWithPartData.inputBatch.getColumnarBatch()) { cb =>
splitAndCombineBatchWithPartitionData(cb, splitPartitionData,
batchWithPartData.partitionSchema)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery
withResource(buildBatch(getSampleValueData)) { valueBatch =>
withResource(buildBatch(partCols)) { partBatch =>
withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch =>
GpuColumnVector.incRefCounts(valueBatch)
val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
partRows, partValues, partSchema, maxGpuColumnSizeBytes)
withResource(resultBatchIter) { _ =>
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId)
// Assert that the final count of rows matches expected batch
val rowCounts = resultBatchIter.map(_.numRows()).sum
val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum
assert(rowCounts == expectedBatch.numRows())
}
}
Expand Down

0 comments on commit 4a0f489

Please sign in to comment.