Skip to content

Commit

Permalink
Update for new retry state machine JNI APIs (#9656)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Nov 29, 2023
1 parent 3936815 commit 33bd589
Show file tree
Hide file tree
Showing 37 changed files with 540 additions and 777 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
.setNameFormat("progress-thread-%d")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))
null,
() => RmmSpark.removeAllCurrentThreadAssociation()))

// The pending queues are used to enqueue [[PendingReceive]] or [[PendingSend]], from executor
// task threads and [[progressThread]] will hand them to the UcpWorker thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat("shuffle-transport-client-exec-%d")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()),
null,
() => RmmSpark.removeAllCurrentThreadAssociation()),
// if we can't hand off because we are too busy, block the caller (in UCX's case,
// the progress thread)
new CallerRunsAndLogs())
Expand All @@ -262,8 +262,8 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat("shuffle-client-copy-thread-%d")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))
null,
() => RmmSpark.removeAllCurrentThreadAssociation()))

override def makeClient(blockManagerId: BlockManagerId): RapidsShuffleClient = {
val peerExecutorId = blockManagerId.executorId.toLong
Expand All @@ -286,8 +286,8 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat(s"shuffle-server-conn-thread-${shuffleServerId.executorId}-%d")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))
null,
() => RmmSpark.removeAllCurrentThreadAssociation()))

// This is used to queue up on the server all the [[BufferSendState]] as the server waits for
// bounce buffers to become available (it is the equivalent of the transport's throttle, minus
Expand All @@ -297,8 +297,8 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat(s"shuffle-server-bss-thread-%d")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))
null,
() => RmmSpark.removeAllCurrentThreadAssociation()))

/**
* Construct a server instance
Expand Down Expand Up @@ -359,8 +359,8 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat(s"shuffle-transport-throttle-monitor")
.setDaemon(true)
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))
null,
() => RmmSpark.removeAllCurrentThreadAssociation()))

// helper class to hold transfer requests that have a bounce buffer
// and should be ready to be handled by a `BufferReceiveState`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ abstract class AbstractGpuJoinIterator(
// This withRetry block will always return an iterator with one ColumnarBatch.
// The gatherer tracks how many rows we have used already. The withRestoreOnRetry
// ensures that we restart at the same place in the gatherer. In the case of a
// SplitAndRetryOOM, we retry with a smaller (halved) targetSize, so we are taking
// GpuSplitAndRetryOOM, we retry with a smaller (halved) targetSize, so we are taking
// less from the gatherer, but because the gatherer tracks how much is used, the
// next call to this function will start in the right place.
gather.checkpoint()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetry
import com.nvidia.spark.rapids.jni.SplitAndRetryOOM
import com.nvidia.spark.rapids.jni.GpuSplitAndRetryOOM

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{StringType, StructType}
Expand Down Expand Up @@ -500,8 +500,8 @@ object BatchWithPartitionDataUtils {
withResource(batchWithPartData) { _ =>
// Split partition rows data into two halves
val splitPartitionData = splitPartitionDataInHalf(batchWithPartData.partitionedRowsData)
if(splitPartitionData.length < 2) {
throw new SplitAndRetryOOM("GPU OutOfMemory: cannot split input with one row")
if (splitPartitionData.length < 2) {
throw new GpuSplitAndRetryOOM("GPU OutOfMemory: cannot split input with one row")
}
// Split the batch into two halves
withResource(batchWithPartData.inputBatch.getColumnarBatch()) { cb =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRetry, withRetryNoSplit}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.SplitAndRetryOOM
import com.nvidia.spark.rapids.jni.GpuSplitAndRetryOOM
import com.nvidia.spark.rapids.shims.{ShimExpression, ShimUnaryExecNode}

import org.apache.spark.TaskContext
Expand Down Expand Up @@ -671,7 +671,7 @@ abstract class AbstractGpuCoalesceIterator(
val it = batchesToCoalesce.batches
val numBatches = it.length
if (numBatches <= 1) {
throw new SplitAndRetryOOM(s"Cannot split a sequence of $numBatches batches")
throw new GpuSplitAndRetryOOM(s"Cannot split a sequence of $numBatches batches")
}
val res = it.splitAt(numBatches / 2)
Seq(BatchesToCoalesce(res._1), BatchesToCoalesce(res._2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ object GpuDeviceManager extends Logging {

private def initializeOffHeapLimits(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = {
val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf))
val pinnedSize = if (conf.offHeapLimitEnabled) {
val (pinnedSize, nonPinnedLimit) = if (conf.offHeapLimitEnabled) {
logWarning("OFF HEAP MEMORY LIMITS IS ENABLED. " +
"THIS IS EXPERIMENTAL FOR NOW USE WITH CAUTION")
val perTaskOverhead = conf.perTaskOverhead
Expand Down Expand Up @@ -425,9 +425,9 @@ object GpuDeviceManager extends Logging {
} else {
memoryLimit
}
// TODO need to configure the limits when we have those APIs available, and log what those
// limits are
if (confPinnedSize + totalOverhead <= finalMemoryLimit) {

// Now we need to know the pinned vs non-pinned limits
val pinnedLimit = if (confPinnedSize + totalOverhead <= finalMemoryLimit) {
confPinnedSize
} else {
val ret = finalMemoryLimit - totalOverhead
Expand All @@ -437,13 +437,23 @@ object GpuDeviceManager extends Logging {
s"dropping pinned memory to ${ret / 1024 / 1024.0} MiB")
ret
}
val nonPinnedLimit = finalMemoryLimit - totalOverhead - pinnedLimit
logWarning(s"Off Heap Host Memory configured to be " +
s"${pinnedLimit / 1024 / 1024.0} MiB pinned, " +
s"${nonPinnedLimit / 1024 / 1024.0} MiB non-pinned, and " +
s"${totalOverhead / 1024 / 1024.0} MiB of untracked overhead.")
(pinnedLimit, nonPinnedLimit)
} else {
conf.pinnedPoolSize
(conf.pinnedPoolSize, -1L)
}
if (!PinnedMemoryPool.isInitialized && pinnedSize > 0) {
logInfo(s"Initializing pinned memory pool (${pinnedSize / 1024 / 1024.0} MiB)")
PinnedMemoryPool.initialize(pinnedSize, gpuId)
}
if (nonPinnedLimit >= 0) {
// Host memory limits must be set after the pinned memory pool is initialized
HostAlloc.initialize(nonPinnedLimit)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import ai.rapids.cudf.{HostMemoryBuffer, NvtxColor, NvtxRange, Table}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuMetric.{BUFFER_TIME, FILTER_TIME}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import com.nvidia.spark.rapids.jni.SplitAndRetryOOM
import com.nvidia.spark.rapids.jni.GpuSplitAndRetryOOM
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
Expand Down Expand Up @@ -1029,7 +1029,7 @@ abstract class MultiFileCoalescingPartitionReaderBase(
* Set this to a splitter instance when chunked reading is supported
*/
def chunkedSplit(buffer: HostMemoryBuffer): Seq[HostMemoryBuffer] = {
throw new SplitAndRetryOOM("Split is not currently supported")
throw new GpuSplitAndRetryOOM("Split is not currently supported")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import com.nvidia.spark.rapids.ParquetPartitionReader.{CopyRange, LocalCopy}
import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.filecache.FileCache
import com.nvidia.spark.rapids.jni.{DateTimeRebase, ParquetFooter, SplitAndRetryOOM}
import com.nvidia.spark.rapids.jni.{DateTimeRebase, GpuSplitAndRetryOOM, ParquetFooter}
import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ReaderUtils, ShimFilePartitionReaderFactory, SparkShimImpl}
import org.apache.commons.io.IOUtils
import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
Expand Down Expand Up @@ -248,7 +248,7 @@ object GpuParquetScan {

/**
* Check that we can split the targetBatchSize and then return a split targetBatchSize. This
* is intended to be called from the SplitAndRetryOOM handler for all implementations of
* is intended to be called from the GpuSplitAndRetryOOM handler for all implementations of
* the parquet reader
* @param targetBatchSize the current target batch size.
* @param useChunkedReader if chunked reading is enabled. This only works if chunked reading is
Expand All @@ -257,13 +257,13 @@ object GpuParquetScan {
*/
def splitTargetBatchSize(targetBatchSize: Long, useChunkedReader: Boolean): Long = {
if (!useChunkedReader) {
throw new SplitAndRetryOOM("GPU OutOfMemory: could not split inputs " +
"chunked parquet reader is configured off")
throw new GpuSplitAndRetryOOM("GPU OutOfMemory: could not split inputs " +
"chunked parquet reader is configured off")
}
val ret = targetBatchSize / 2
if (targetBatchSize < minTargetBatchSizeMiB * 1024 * 1024) {
throw new SplitAndRetryOOM("GPU OutOfMemory: could not split input " +
s"target batch size to less than $minTargetBatchSizeMiB MiB")
throw new GpuSplitAndRetryOOM("GPU OutOfMemory: could not split input " +
s"target batch size to less than $minTargetBatchSizeMiB MiB")
}
ret
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -269,14 +268,15 @@ private final class GpuSemaphore() extends Logging {
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

def acquireIfNecessary(context: TaskContext): Unit = {
// Make sure that the thread/task is registered before we try and block
TaskRegistryTracker.registerThreadForRetry()
GpuTaskMetrics.get.semWaitTime {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
})
taskInfo.blockUntilReady(semaphore)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
GpuDeviceManager.initializeFromTask()
}
}
Expand All @@ -286,7 +286,6 @@ private final class GpuSemaphore() extends Logging {
try {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.removeCurrentThreadAssociation()
val taskInfo = tasks.get(taskAttemptId)
if (taskInfo != null) {
taskInfo.releaseSemaphore(semaphore)
Expand All @@ -299,7 +298,6 @@ private final class GpuSemaphore() extends Logging {
def completeTask(context: TaskContext): Unit = {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.taskDone(taskAttemptId)
val refs = tasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
Expand Down
Loading

0 comments on commit 33bd589

Please sign in to comment.