Skip to content

Commit

Permalink
add test for when all tasks are big
Browse files Browse the repository at this point in the history
  • Loading branch information
jihoonson committed Dec 2, 2024
1 parent a5cf7f0 commit f229f71
Showing 1 changed file with 54 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.io.async

import java.util.concurrent.{ExecutionException, Executors, ExecutorService, TimeUnit}
import java.util.concurrent.{ExecutionException, Executors, ExecutorService, Future, TimeUnit}

import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.TimeLimitedTests
Expand All @@ -26,14 +26,24 @@ import org.scalatest.time.SpanSugar._

class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach with TimeLimitedTests {

class RecordingExecOrderHostMemoryThrottle(maxInFlightHostMemoryBytes: Long)
extends HostMemoryThrottle(maxInFlightHostMemoryBytes) {
var tasksScheduled = Seq.empty[TestTask]

override def taskScheduled[T](task: Task[T]): Unit = {
tasksScheduled = tasksScheduled :+ task.asInstanceOf[TestTask]
super.taskScheduled(task)
}
}

val timeLimit: Span = 10.seconds

private var throttle: HostMemoryThrottle = _
private var throttle: RecordingExecOrderHostMemoryThrottle = _
private var controller: TrafficController = _
private var executor: ExecutorService = _

override def beforeEach(): Unit = {
throttle = new HostMemoryThrottle(100)
throttle = new RecordingExecOrderHostMemoryThrottle(100)
controller = new TrafficController(throttle)
executor = Executors.newSingleThreadExecutor()
}
Expand Down Expand Up @@ -97,6 +107,47 @@ class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach with Ti
f.get(1, TimeUnit.SECONDS)
}

test("all tasks are bigger than the total memory limit") {
val bigTaskMemoryBytes = 130
val (tasks, futures) = (0 to 2).map { _ =>
val t = new TestTask(bigTaskMemoryBytes)
val f: Future[_] = executor.submit(new Runnable {
override def run(): Unit = controller.blockUntilRunnable(t)
})
(t, f.asInstanceOf[Future[Unit]])
}.unzip
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(0).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled.head)(tasks(0))

// The first task has been completed
controller.taskCompleted(tasks(0))
// Wait for the second task to be scheduled
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(1).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled(1))(tasks(1))

// The second task has been completed
controller.taskCompleted(tasks(1))
// Wait for the third task to be scheduled
while (controller.numScheduledTasks == 0) {
Thread.sleep(100)
}
assert(futures(2).isDone)
assertResult(1)(controller.numScheduledTasks)
assertResult(throttle.tasksScheduled(2))(tasks(2))

// The third task has been completed
controller.taskCompleted(tasks(2))
assertResult(0)(controller.numScheduledTasks)
}

test("shutdown while blocking") {
val t1 = new TestTask(10)
controller.blockUntilRunnable(t1)
Expand Down

0 comments on commit f229f71

Please sign in to comment.