diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CachedBatchWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CachedBatchWriterSuite.scala
index e7418b06916..038eabed3fc 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CachedBatchWriterSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CachedBatchWriterSuite.scala
@@ -20,7 +20,7 @@ import scala.collection.JavaConverters.asScalaIteratorConverter
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.collection.mutable
 
-import ai.rapids.cudf.{ColumnVector, CompressionType, DType, Table, TableWriter}
+import ai.rapids.cudf.{ColumnVector, CompressionType, DType, Rmm, Table, TableWriter}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableFromBatchColumns
 import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
@@ -47,6 +47,8 @@ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
 class CachedBatchWriterSuite extends SparkQueryCompareTestSuite {
 
   class TestResources extends AutoCloseable {
+    assert(Rmm.isInitialized, "Need to use this within Spark GPU session, or it may fail to " +
+      "release column vector.")
     val byteCv1 = ColumnVector.fromBytes(1)
     val byteCv3 = ColumnVector.fromBytes(3)
     val byteCv456 = ColumnVector.fromBytes(4, 5, 6)
@@ -59,68 +61,75 @@ class CachedBatchWriterSuite extends SparkQueryCompareTestSuite {
   }
 
   test("convert columnar batch to cached batch on single col table with 0 rows in a batch") {
-    withResource(new TestResources()) { resources =>
-      val (_, spyGpuCol0) = getCudfAndGpuVectors(resources)
-      val cb = new ColumnarBatch(Array(spyGpuCol0), 0)
-      val ser = new ParquetCachedBatchSerializer
-      val dummySchema = new StructType(
-        Array(StructField("empty", ByteType, false),
-          StructField("empty", ByteType, false),
-          StructField("empty", ByteType, false)))
-      val listOfPCB = ser.compressColumnarBatchWithParquet(cb, dummySchema, dummySchema,
-        BYTES_ALLOWED_PER_BATCH, false)
-      assert(listOfPCB.isEmpty)
-    }
+    withGpuSparkSession(_ =>
+      withResource(new TestResources()) { resources =>
+        val (_, spyGpuCol0) = getCudfAndGpuVectors(resources)
+        val cb = new ColumnarBatch(Array(spyGpuCol0), 0)
+        val ser = new ParquetCachedBatchSerializer
+        val dummySchema = new StructType(
+          Array(
+            StructField("empty", ByteType, false),
+            StructField("empty", ByteType, false),
+            StructField("empty", ByteType, false)))
+        val listOfPCB = ser.compressColumnarBatchWithParquet(
+          cb, dummySchema, dummySchema,
+          BYTES_ALLOWED_PER_BATCH, false)
+        assert(listOfPCB.isEmpty)
+      })
   }
 
   test("convert large columnar batch to cached batch on single col table") {
-    withResource(new TestResources()) { resources =>
-      val (spyCol0, spyGpuCol0) = getCudfAndGpuVectors(resources)
-      val splitAt = 2086912
-      testCompressColBatch(resources, Array(spyCol0), Array(spyGpuCol0), splitAt)
-      verify(spyCol0).split(splitAt)
-    }
+    withGpuSparkSession(_ =>
+      withResource(new TestResources()) { resources =>
+        val (spyCol0, spyGpuCol0) = getCudfAndGpuVectors(resources)
+        val splitAt = 2086912
+        testCompressColBatch(resources, Array(spyCol0), Array(spyGpuCol0), splitAt)
+        verify(spyCol0).split(splitAt)
+      })
   }
 
   test("convert large columnar batch to cached batch on multi-col table") {
-    withResource(new TestResources()) { resources =>
-      val (spyCol0, spyGpuCol0) = getCudfAndGpuVectors(resources)
-      val splitAt = Seq(695637, 1391274, 2086911, 2782548)
-      testCompressColBatch(resources, Array(spyCol0, spyCol0, spyCol0),
-        Array(spyGpuCol0, spyGpuCol0, spyGpuCol0), splitAt: _*)
-      verify(spyCol0, times(3)).split(splitAt: _*)
-    }
+    withGpuSparkSession(_ =>
+      withResource(new TestResources()) { resources =>
+        val (spyCol0, spyGpuCol0) = getCudfAndGpuVectors(resources)
+        val splitAt = Seq(695637, 1391274, 2086911, 2782548)
+        testCompressColBatch(resources, Array(spyCol0, spyCol0, spyCol0),
+          Array(spyGpuCol0, spyGpuCol0, spyGpuCol0), splitAt: _*)
+        verify(spyCol0, times(3)).split(splitAt: _*)
+      })
   }
 
   test("convert large InternalRow iterator to cached batch single col") {
-    withResource(new TestResources()) { resources =>
-      val (_, spyGpuCol0) = getCudfAndGpuVectors(resources)
-      val cb = new ColumnarBatch(Array(spyGpuCol0), ROWS)
-      val mockByteType = mock(classOf[ByteType])
-      when(mockByteType.defaultSize).thenReturn(1024)
-      val schema = Seq(AttributeReference("field0", mockByteType, true)())
-      testColumnarBatchToCachedBatchIterator(cb, schema)
-    }
+    withGpuSparkSession(_ =>
+      withResource(new TestResources()) { resources =>
+        val (_, spyGpuCol0) = getCudfAndGpuVectors(resources)
+        val cb = new ColumnarBatch(Array(spyGpuCol0), ROWS)
+        val mockByteType = mock(classOf[ByteType])
+        when(mockByteType.defaultSize).thenReturn(1024)
+        val schema = Seq(AttributeReference("field0", mockByteType, true)())
+        testColumnarBatchToCachedBatchIterator(cb, schema)
+      })
   }
 
   test("convert large InternalRow iterator to cached batch multi-col") {
-    withResource(new TestResources()) { resources1 =>
-      val (_, spyGpuCol0) = getCudfAndGpuVectors(resources1)
-      withResource(new TestResources()) { resources2 =>
-        val (_, spyGpuCol1) = getCudfAndGpuVectors(resources2)
-        withResource(new TestResources()) { resources3 =>
-          val (_, spyGpuCol2) = getCudfAndGpuVectors(resources3)
-          val cb = new ColumnarBatch(Array(spyGpuCol0, spyGpuCol1, spyGpuCol2), ROWS)
-          val mockByteType = mock(classOf[ByteType])
-          when(mockByteType.defaultSize).thenReturn(1024)
-          val schema = Seq(AttributeReference("field0", mockByteType, true)(),
-            AttributeReference("field1", mockByteType, true)(),
-            AttributeReference("field2", mockByteType, true)())
-
-          testColumnarBatchToCachedBatchIterator(cb, schema)
+    withGpuSparkSession(_ =>
+      withResource(new TestResources()) { resources1 =>
+        val (_, spyGpuCol0) = getCudfAndGpuVectors(resources1)
+        withResource(new TestResources()) { resources2 =>
+          val (_, spyGpuCol1) = getCudfAndGpuVectors(resources2)
+          withResource(new TestResources()) { resources3 =>
+            val (_, spyGpuCol2) = getCudfAndGpuVectors(resources3)
+            val cb = new ColumnarBatch(Array(spyGpuCol0, spyGpuCol1, spyGpuCol2), ROWS)
+            val mockByteType = mock(classOf[ByteType])
+            when(mockByteType.defaultSize).thenReturn(1024)
+            val schema = Seq(AttributeReference("field0", mockByteType, true)(),
+              AttributeReference("field1", mockByteType, true)(),
+              AttributeReference("field2", mockByteType, true)())
+
+            testColumnarBatchToCachedBatchIterator(cb, schema)
+          }
         }
-      }
-    }
+      })
   }
 
   test("test useCompression conf is honored") {