Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the task count check in TrafficController #11783

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import java.util.concurrent.{Callable, ExecutorService, Future, TimeUnit}

/**
* Thin wrapper around an ExecutorService that adds throttling.
*
* The given executor is owned by this class and will be shutdown when this class is shutdown.
*/
class ThrottlingExecutor(
val executor: ExecutorService, throttler: TrafficController) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids.io.async

import java.util.concurrent.Callable
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy

import com.nvidia.spark.rapids.RapidsConf
Expand Down Expand Up @@ -85,38 +86,55 @@ class HostMemoryThrottle(val maxInFlightHostMemoryBytes: Long) extends Throttle
*
* This class is thread-safe as it is used by multiple tasks.
*/
class TrafficController protected[rapids] (throttle: Throttle) {
class TrafficController protected[rapids] (@GuardedBy("lock") throttle: Throttle) {

@GuardedBy("this")
@GuardedBy("lock")
private var numTasks: Int = 0

private val lock = new ReentrantLock()
private val canBeScheduled = lock.newCondition()
jihoonson marked this conversation as resolved.
Show resolved Hide resolved

/**
* Blocks the task from being scheduled until the throttle allows it. If there is no task
* currently scheduled, the task is scheduled immediately even if the throttle is exceeded.
*/
def blockUntilRunnable[T](task: Task[T]): Unit = synchronized {
if (numTasks > 0) {
while (!throttle.canAccept(task)) {
wait(100)
def blockUntilRunnable[T](task: Task[T]): Unit = {
lock.lockInterruptibly()
try {
while (numTasks > 0 && !throttle.canAccept(task)) {
jihoonson marked this conversation as resolved.
Show resolved Hide resolved
canBeScheduled.await()
}
numTasks += 1
throttle.taskScheduled(task)
} finally {
lock.unlock()
}
numTasks += 1
throttle.taskScheduled(task)
}

def taskCompleted[T](task: Task[T]): Unit = synchronized {
numTasks -= 1
throttle.taskCompleted(task)
notify()
def taskCompleted[T](task: Task[T]): Unit = {
lock.lockInterruptibly()
try {
numTasks -= 1
throttle.taskCompleted(task)
canBeScheduled.signal()
} finally {
lock.unlock()
}
}

def numScheduledTasks: Int = synchronized {
numTasks
def numScheduledTasks: Int = {
lock.lockInterruptibly()
try {
numTasks
} finally {
lock.unlock()
}
}
}

object TrafficController {

@GuardedBy("this")
private var instance: TrafficController = _

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
assertResult(0)(throttle.getTotalHostMemoryBytes)
}

test("tasks submission fails if total weight exceeds maxWeight") {
test("tasks submission fails if totalHostMemoryBytes exceeds maxHostMemoryBytes") {
val task1 = new TestTask
val future1 = executor.submit(task1, 10)
assertResult(1)(trafficController.numScheduledTasks)
assertResult(10)(throttle.getTotalHostMemoryBytes)

val task2 = new TestTask
val task2Weight = 100
val task2HostMemory = 100
val exec = Executors.newSingleThreadExecutor()
val future2 = exec.submit(new Runnable {
override def run(): Unit = executor.submit(task2, task2Weight)
override def run(): Unit = executor.submit(task2, task2HostMemory)
})
Thread.sleep(100)
assert(!future2.isDone)
Expand All @@ -94,18 +94,18 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
future1.get(longTimeoutSec, TimeUnit.SECONDS)
future2.get(longTimeoutSec, TimeUnit.SECONDS)
assertResult(1)(trafficController.numScheduledTasks)
assertResult(task2Weight)(throttle.getTotalHostMemoryBytes)
assertResult(task2HostMemory)(throttle.getTotalHostMemoryBytes)
}

test("submit one task heavier than maxWeight") {
test("submit one task heavier than maxHostMemoryBytes") {
val future = executor.submit(() => Thread.sleep(10), throttle.maxInFlightHostMemoryBytes + 1)
future.get(longTimeoutSec, TimeUnit.SECONDS)
assert(future.isDone)
assertResult(0)(trafficController.numScheduledTasks)
assertResult(0)(throttle.getTotalHostMemoryBytes)
}

test("submit multiple tasks such that total weight does not exceed maxWeight") {
test("submit multiple tasks such that totalHostMemoryBytes does not exceed maxHostMemoryBytes") {
val numTasks = 10
val taskRunTime = 10
var future: Future[Unit] = null
Expand All @@ -125,10 +125,10 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
assertResult(10)(throttle.getTotalHostMemoryBytes)

val task2 = new TestTask
val task2Weight = 100
val task2HostMemory = 100
val exec = Executors.newSingleThreadExecutor()
val future2 = exec.submit(new Runnable {
override def run(): Unit = executor.submit(task2, task2Weight)
override def run(): Unit = executor.submit(task2, task2HostMemory)
})
executor.shutdownNow(longTimeoutSec, TimeUnit.SECONDS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,34 @@

package com.nvidia.spark.rapids.io.async

import java.util.concurrent.{ExecutionException, Executors, ExecutorService, TimeUnit}
import java.util.concurrent.{ExecutionException, Executors, ExecutorService, Future, TimeUnit}

import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._

class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach {
class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach with TimeLimitedTests {

private var throttle: HostMemoryThrottle = _
class RecordingExecOrderHostMemoryThrottle(maxInFlightHostMemoryBytes: Long)
extends HostMemoryThrottle(maxInFlightHostMemoryBytes) {
var tasksScheduled = Seq.empty[TestTask]

override def taskScheduled[T](task: Task[T]): Unit = {
tasksScheduled = tasksScheduled :+ task.asInstanceOf[TestTask]
super.taskScheduled(task)
}
}

val timeLimit: Span = 10.seconds

private var throttle: RecordingExecOrderHostMemoryThrottle = _
private var controller: TrafficController = _
private var executor: ExecutorService = _

override def beforeEach(): Unit = {
throttle = new HostMemoryThrottle(100)
throttle = new RecordingExecOrderHostMemoryThrottle(100)
controller = new TrafficController(throttle)
executor = Executors.newSingleThreadExecutor()
}
Expand Down Expand Up @@ -76,6 +91,63 @@ class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach {
f.get(1, TimeUnit.SECONDS)
}

test("big task should be scheduled after all running tasks are completed") {
val taskMemoryBytes = 50
val t1 = new TestTask(taskMemoryBytes)
controller.blockUntilRunnable(t1)

val t2 = new TestTask(150)
val f = executor.submit(new Runnable {
override def run(): Unit = controller.blockUntilRunnable(t2)
})
Thread.sleep(100)
assert(!f.isDone)

controller.taskCompleted(t1)
f.get(1, TimeUnit.SECONDS)
}

test("all tasks are bigger than the total memory limit") {
val bigTaskMemoryBytes = 130
val (tasks, futures) = (0 to 2).map { _ =>
val t = new TestTask(bigTaskMemoryBytes)
val f: Future[_] = executor.submit(new Runnable {
override def run(): Unit = controller.blockUntilRunnable(t)
})
(t, f.asInstanceOf[Future[Unit]])
}.unzip
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(0).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled.head)(tasks(0))

// The first task has been completed
controller.taskCompleted(tasks(0))
// Wait for the second task to be scheduled
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(1).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled(1))(tasks(1))

// The second task has been completed
controller.taskCompleted(tasks(1))
// Wait for the third task to be scheduled
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(2).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled(2))(tasks(2))

// The third task has been completed
controller.taskCompleted(tasks(2))
assertResult(0)(controller.numScheduledTasks)
}

test("shutdown while blocking") {
val t1 = new TestTask(10)
controller.blockUntilRunnable(t1)
Expand Down
Loading