Skip to content

Commit

Permalink
Fix query hang when using rapids multithread shuffle manager with kudo (#11771)
Browse files Browse the repository at this point in the history

* Fix query hang when using kudo and multi thread shuffle manager

Signed-off-by: liurenjie1024 <[email protected]>

* Fix NPE

---------

Signed-off-by: liurenjie1024 <[email protected]>
  • Loading branch information
liurenjie1024 authored Nov 26, 2024
1 parent 2b6ac11 commit e3dce9e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,12 @@ object SerializedTableColumn {
if (batch.numCols == 1) {
val cv = batch.column(0)
cv match {
case serializedTableColumn: SerializedTableColumn
if serializedTableColumn.hostBuffer != null =>
sum += serializedTableColumn.hostBuffer.getLength
case serializedTableColumn: SerializedTableColumn =>
sum += Option(serializedTableColumn.hostBuffer).map(_.getLength).getOrElse(0L)
case kudo: KudoSerializedTableColumn =>
sum += Option(kudo.kudoTable.getBuffer).map(_.getLength).getOrElse(0L)
case _ =>
throw new IllegalStateException(s"Unexpected column type: ${cv.getClass}" )
}
}
sum
Expand Down Expand Up @@ -496,65 +498,75 @@ object KudoSerializedTableColumn {
class KudoSerializedBatchIterator(dIn: DataInputStream)
extends BaseSerializedTableIterator {
private[this] var nextHeader: Option[KudoTableHeader] = None
private[this] var toBeReturned: Option[ColumnarBatch] = None
private[this] var streamClosed: Boolean = false

// Don't install the callback if in a unit test
Option(TaskContext.get()).foreach { tc =>
onTaskCompletion(tc) {
toBeReturned.foreach(_.close())
toBeReturned = None
dIn.close()
}
}

private def tryReadNextHeader(): Unit = {
if (!streamClosed) {
withResource(new NvtxRange("Read Kudo Header", NvtxColor.YELLOW)) { _ =>
require(nextHeader.isEmpty)
nextHeader = Option(KudoTableHeader.readFrom(dIn).orElse(null))
if (nextHeader.isEmpty) {
dIn.close()
streamClosed = true
override def peekNextBatchSize(): Option[Long] = {
if (streamClosed) {
None
} else {
if (nextHeader.isEmpty) {
withResource(new NvtxRange("Read Header", NvtxColor.YELLOW)) { _ =>
val header = Option(KudoTableHeader.readFrom(dIn).orElse(null))
if (header.isDefined) {
nextHeader = header
} else {
dIn.close()
streamClosed = true
nextHeader = None
}
}
}
nextHeader.map(_.getTotalDataLen)
}
}

override def hasNext: Boolean = {
private def tryReadNext(): Option[ColumnarBatch] = {
if (nextHeader.isEmpty) {
tryReadNextHeader()
}
nextHeader.isDefined
}

override def next(): (Int, ColumnarBatch) = {
if (hasNext) {
val header = nextHeader.get
nextHeader = None
val buffer = if (header.getNumColumns == 0) {
null
} else {
withResource(new NvtxRange("Read Kudo Body", NvtxColor.YELLOW)) { _ =>
val buffer = HostMemoryBuffer.allocate(header.getTotalDataLen, false)
closeOnExcept(buffer) { _ =>
buffer.copyFromStream(0, dIn, header.getTotalDataLen)
None
} else {
withResource(new NvtxRange("Read Batch", NvtxColor.YELLOW)) { _ =>
val header = nextHeader.get
if (header.getNumColumns > 0) {
// This buffer will later be concatenated into another host buffer before being
// sent to the GPU, so no need to use pinned memory for these buffers.
closeOnExcept(HostMemoryBuffer.allocate(header.getTotalDataLen, false)) { hostBuffer =>
hostBuffer.copyFromStream(0, dIn, header.getTotalDataLen)
val kudoTable = new KudoTable(header, hostBuffer)
Some(KudoSerializedTableColumn.from(kudoTable))
}
buffer
} else {
Some(KudoSerializedTableColumn.from(new KudoTable(header, null)))
}
}
(0, KudoSerializedTableColumn.from(new KudoTable(header, buffer)))
} else {
throw new NoSuchElementException("Walked off of the end...")
}
}

/**
* Attempt to read the next header from the stream.
*
* @return the length of the data to read, or None if the stream is closed or ended
*/
override def peekNextBatchSize(): Option[Long] = {
if (nextHeader.isEmpty) {
tryReadNextHeader()
override def hasNext: Boolean = {
peekNextBatchSize()
nextHeader.isDefined
}

override def next(): (Int, ColumnarBatch) = {
if (toBeReturned.isEmpty) {
peekNextBatchSize()
toBeReturned = tryReadNext()
if (nextHeader.isEmpty || toBeReturned.isEmpty) {
throw new NoSuchElementException("Walked off of the end...")
}
}
nextHeader.map(_.getTotalDataLen)
val ret = toBeReturned.get
toBeReturned = None
nextHeader = None
(0, ret)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.Optional
import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

import scala.collection
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

Expand Down

0 comments on commit e3dce9e

Please sign in to comment.