From 7f41bef6d23592450d81072b1b8e1040f709da92 Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Wed, 15 Nov 2023 11:06:02 -0600
Subject: [PATCH] Add more comments

---
 .../com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala
index 40d135f1107..6c9f59e8ece 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala
@@ -78,12 +78,16 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery
     withResource(buildBatch(getSampleValueData)) { valueBatch =>
       withResource(buildBatch(partCols)) { partBatch =>
         withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch =>
+          // we incRefCounts here because `addPartitionValuesToBatch` takes ownership of
+          // `valueBatch`, but we are keeping it alive since its columns are part of
+          // `expectedBatch`
           GpuColumnVector.incRefCounts(valueBatch)
           val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
             partRows, partValues, partSchema, maxGpuColumnSizeBytes)
           withResource(resultBatchIter) { _ =>
             RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId)
             // Assert that the final count of rows matches expected batch
+            // We also need to close each batch coming from `resultBatchIter`.
             val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum
             assert(rowCounts == expectedBatch.numRows())
           }
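
Note (not part of the patch): the comment added above documents an ownership convention that is easy to get wrong. `addPartitionValuesToBatch` closes the batch it is handed, so a caller that still needs that batch must bump its reference count first. Below is a minimal, self-contained Scala sketch of the pattern; `Batch`, `incRef`, `consume`, and this `withResource` are hypothetical stand-ins for illustration, not the plugin's actual API.

object RefCountSketch {
  // Toy ref-counted resource: close() releases only when the count reaches zero,
  // and flags a double close. This loosely mimics GPU column vector semantics.
  final class Batch(val name: String) extends AutoCloseable {
    private var refs = 1
    def incRef(): Batch = { refs += 1; this }
    override def close(): Unit = {
      refs -= 1
      if (refs == 0) println(s"$name released")
      else if (refs < 0) throw new IllegalStateException(s"$name closed too many times")
    }
  }

  // Loan pattern, in the spirit of the suite's withResource: close on exit, always.
  def withResource[R <: AutoCloseable, T](r: R)(f: R => T): T =
    try f(r) finally r.close()

  // Stand-in for addPartitionValuesToBatch: takes ownership of its input
  // and closes it when done.
  def consume(b: Batch): Int = withResource(b)(_ => 42)

  def main(args: Array[String]): Unit = {
    withResource(new Batch("valueBatch")) { valueBatch =>
      // Bump the count before handing the batch to a consumer that closes it,
      // which is the role GpuColumnVector.incRefCounts plays in the test.
      consume(valueBatch.incRef())
      // valueBatch is still alive here; the outer withResource releases it.
    }
  }
}

Dropping the `incRef()` call would let `consume` drive the count to zero while the outer `withResource` still holds the batch, so the final close would be a double free. That is exactly why the test calls `GpuColumnVector.incRefCounts(valueBatch)` before passing `valueBatch` to `addPartitionValuesToBatch` while `expectedBatch` still shares its columns.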