diff --git a/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala b/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
index 1a9936ea8085..eb897e6a1519 100644
--- a/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
+++ b/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * This file was derived from OptimizeWriteExchange.scala
  * in the Delta Lake project at https://github.com/delta-io/delta
@@ -97,8 +97,10 @@ case class GpuOptimizeWriteExchangeExec(
     ) ++ additionalMetrics
   }
 
-  private lazy val serializer: Serializer =
-    new GpuColumnarBatchSerializer(gpuLongMetric("dataSize"))
+  private lazy val sparkTypes: Array[DataType] = child.output.map(_.dataType).toArray
+
+  private lazy val serializer: Serializer = new GpuColumnarBatchSerializer(
+    gpuLongMetric("dataSize"), partitioning.serializingOnGPU, sparkTypes)
 
   @transient lazy val inputRDD: RDD[ColumnarBatch] = child.executeColumnar()
 
@@ -116,7 +118,7 @@ case class GpuOptimizeWriteExchangeExec(
       inputRDD,
       child.output,
       partitioning,
-      child.output.map(_.dataType).toArray,
+      sparkTypes,
       serializer,
       useGPUShuffle=partitioning.usesGPUShuffle,
       useMultiThreadedShuffle=partitioning.usesMultiThreadedShuffle,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
index be19cb1bcf80..67762f72ebd8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
@@ -124,16 +124,16 @@ class SerializedBatchIterator(dIn: DataInputStream) extends Iterator[(Int, Colum
  *
  * @note The RAPIDS shuffle does not use this code.
 */
-class GpuColumnarBatchSerializer(dataSize: GpuMetric, serializingOnGpu: Boolean = false,
+class GpuColumnarBatchSerializer(dataSize: GpuMetric, isSerializedTable: Boolean = false,
     sparkTypes: Array[DataType] = Array.empty) extends Serializer with Serializable {
   override def newInstance(): SerializerInstance =
-    new GpuColumnarBatchSerializerInstance(dataSize, serializingOnGpu, sparkTypes)
+    new GpuColumnarBatchSerializerInstance(dataSize, isSerializedTable, sparkTypes)
   override def supportsRelocationOfSerializedObjects: Boolean = true
 }
 
 private class GpuColumnarBatchSerializerInstance(
     dataSize: GpuMetric,
-    serializingOnGpu: Boolean,
+    isSerializedTable: Boolean,
     sparkTypes: Array[DataType]) extends SerializerInstance {
 
   private lazy val tableSerializer = new SimpleTableSerializer(sparkTypes)
@@ -141,7 +141,7 @@ private class GpuColumnarBatchSerializerInstance(
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] val dOut = new DataOutputStream(new BufferedOutputStream(out))
 
-    private def serializeBatchOnCPU(batch: ColumnarBatch): Unit = {
+    private def serializeCpuBatch(batch: ColumnarBatch): Unit = {
       val numCols = batch.numCols()
       if (numCols > 0) {
         withResource(new ArrayBuffer[AutoCloseable]()) { toClose =>
@@ -173,7 +173,7 @@ private class GpuColumnarBatchSerializerInstance(
       }
     }
 
-    private def serializeBatchOnGPU(batch: ColumnarBatch): Unit = {
+    private def serializeGpuBatch(batch: ColumnarBatch): Unit = {
       if (batch.numCols() > 0) {
         batch.column(0) match {
           case packTable: GpuPackedTableColumn =>
@@ -190,10 +190,10 @@ private class GpuColumnarBatchSerializerInstance(
       }
     }
 
-    private lazy val serializeBatch: ColumnarBatch => Unit = if (serializingOnGpu) {
-      serializeBatchOnGPU
+    private lazy val serializeBatch: ColumnarBatch => Unit = if (isSerializedTable) {
+      serializeGpuBatch
     } else {
-      serializeBatchOnCPU
+      serializeCpuBatch
     }
 
     override def writeValue[T: ClassTag](value: T): SerializationStream = {
@@ -278,7 +278,8 @@ private class GpuColumnarBatchSerializerInstance(
 
 private[rapids] class SimpleTableSerializer(sparkTypes: Array[DataType]) {
   private val P_MAGIC_CUDF: Int = 0x43554446
-  private val headerLen = 4 // the size in bytes of an Int
+  private val P_VERSION: Int = 0
+  private val headerLen = 8 // the size in bytes of two Ints for a header
   private val tmpBuf = new Array[Byte](1024 * 64) // 64k
 
   private def writeByteBufferToStream(bBuf: ByteBuffer, dOut: DataOutputStream): Unit = {
@@ -315,6 +316,7 @@ private[rapids] class SimpleTableSerializer(sparkTypes: Array[DataType]) {
 
   private def writeProtocolHeader(dOut: DataOutputStream): Unit = {
     dOut.writeInt(P_MAGIC_CUDF)
+    dOut.writeInt(P_VERSION)
   }
 
   def writeRowsOnlyToStream(numRows: Int, dOut: DataOutputStream): Long = {
@@ -328,7 +330,7 @@ private[rapids] class SimpleTableSerializer(sparkTypes: Array[DataType]) {
   }
 
   def writeToStream(table: ContiguousTable, dOut: DataOutputStream): Long = {
-    // 1) header, now only a magic number, may add more as needed
+    // 1) header
     writeProtocolHeader(dOut)
     // 2) table metadata,
    val tableMetaBuf = MetaUtils.buildTableMeta(0, table).getByteBuffer
@@ -343,10 +345,15 @@ private[rapids] class SimpleTableSerializer(sparkTypes: Array[DataType]) {
   }
 
   private def readProtocolHeader(dIn: DataInputStream): Unit = {
-    val num = dIn.readInt()
-    if (num != P_MAGIC_CUDF) {
+    val magicNum = dIn.readInt()
+    if (magicNum != P_MAGIC_CUDF) {
       throw new IllegalStateException(s"Expected magic number $P_MAGIC_CUDF for " +
-        s"table serializer, but got $num")
s"table serializer, but got $magicNum") + } + val version = dIn.readInt() + if (version != P_VERSION) { + throw new IllegalStateException(s"Version mismatch: expected $P_VERSION for " + + s"table serializer, but got $version") } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index d913b0db7cb7..4f352f0b5401 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -144,7 +144,8 @@ object GpuShuffleEnv extends Logging { def serializingOnGpu(conf: RapidsConf): Boolean = { // Serializing on GPU for CPU shuffle does not support compression yet. conf.isSerializingOnGpu && - conf.shuffleCompressionCodec.toLowerCase(Locale.ROOT) == "none" + conf.shuffleCompressionCodec.toLowerCase(Locale.ROOT) == "none" && + (!useGPUShuffle(conf)) } def getCatalog: ShuffleBufferCatalog = if (env == null) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala index c47b1f27fe4a..3e936c1c77f1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala @@ -183,13 +183,6 @@ abstract class GpuShuffleExchangeExecBase( } } - private lazy val serializingOnGPU = { - gpuOutputPartitioning match { - case gpuPartitioning: GpuPartitioning => gpuPartitioning.serializingOnGPU - case _ => false - } - } - // Shuffle produces a lot of small output batches that should be coalesced together. // This coalesce occurs on the GPU and should always be done when using RAPIDS shuffle, // when it is under UCX or CACHE_ONLY modes. @@ -238,7 +231,7 @@ abstract class GpuShuffleExchangeExecBase( // This value must be lazy because the child's output may not have been resolved // yet in all cases. private lazy val serializer: Serializer = new GpuColumnarBatchSerializer( - gpuLongMetric("dataSize"), serializingOnGPU, sparkTypes) + gpuLongMetric("dataSize"), gpuOutputPartitioning.serializingOnGPU, sparkTypes) @transient lazy val inputBatchRDD: RDD[ColumnarBatch] = child.executeColumnar()