From c99fbe4621b5235e8c9aaad16add9e93b065f115 Mon Sep 17 00:00:00 2001 From: starry-shivam Date: Sat, 21 Sep 2024 14:48:24 +0530 Subject: [PATCH] fix: synchronize access to runningJobs map in CoroutineExecutor also improve & fix coroutine executor test cases Signed-off-by: starry-shivam --- .../ktscheduler/executor/CoroutineExecutor.kt | 20 +-- .../ktscheduler/CoroutineExecutorTest.kt | 132 +++++++++++------- 2 files changed, 90 insertions(+), 62 deletions(-) diff --git a/src/main/kotlin/dev/starry/ktscheduler/executor/CoroutineExecutor.kt b/src/main/kotlin/dev/starry/ktscheduler/executor/CoroutineExecutor.kt index ae9a88f..ab9206d 100644 --- a/src/main/kotlin/dev/starry/ktscheduler/executor/CoroutineExecutor.kt +++ b/src/main/kotlin/dev/starry/ktscheduler/executor/CoroutineExecutor.kt @@ -30,6 +30,7 @@ class CoroutineExecutor : Executor { // A map of currently running jobs. private val runningJobs = ConcurrentHashMap() + private val lock = Any() // Lock to synchronize access to runningJobs. /** * Executes the given job. @@ -41,24 +42,25 @@ class CoroutineExecutor : Executor { override suspend fun execute( job: Job, onSuccess: () -> Unit, onError: (Exception) -> Unit ) { - // If the job is not allowed to run concurrently and a job with the - // same ID is already running, return. - if (!job.runConcurrently && runningJobs.containsKey(job.jobId)) { - return + synchronized(lock) { + // If the job is not allowed to run concurrently and another job + // with the same ID is running, return. + if (!job.runConcurrently && runningJobs.containsKey(job.jobId)) { + return + } + runningJobs[job.jobId] = job } CoroutineScope(job.dispatcher).launch { - // Add the job to the running jobs map. - runningJobs[job.jobId] = job try { job.callback() withContext(Dispatchers.Default) { onSuccess() } } catch (exc: Exception) { withContext(Dispatchers.Default) { onError(exc) } } finally { - // Remove the job from the running jobs map. - runningJobs.remove(job.jobId) + // Remove the job from the runningJobs map after execution. + synchronized(lock) { runningJobs.remove(job.jobId) } } } } -} +} \ No newline at end of file diff --git a/src/test/kotlin/dev/starry/ktscheduler/CoroutineExecutorTest.kt b/src/test/kotlin/dev/starry/ktscheduler/CoroutineExecutorTest.kt index e9bd15e..fdcd828 100644 --- a/src/test/kotlin/dev/starry/ktscheduler/CoroutineExecutorTest.kt +++ b/src/test/kotlin/dev/starry/ktscheduler/CoroutineExecutorTest.kt @@ -20,109 +20,135 @@ package dev.starry.ktscheduler.test import dev.starry.ktscheduler.executor.CoroutineExecutor import dev.starry.ktscheduler.job.Job import dev.starry.ktscheduler.triggers.OneTimeTrigger -import junit.framework.TestCase.assertTrue import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.delay -import kotlinx.coroutines.test.TestCoroutineScheduler -import kotlinx.coroutines.test.UnconfinedTestDispatcher -import kotlinx.coroutines.test.resetMain -import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.test.* import org.junit.After +import org.junit.Before import org.junit.Test import java.time.ZoneId import java.time.ZonedDateTime -import kotlin.test.DefaultAsserter.fail -import kotlin.test.assertEquals -import kotlin.test.assertNotNull +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlin.test.* @OptIn(ExperimentalCoroutinesApi::class) class CoroutineExecutorTest { + private lateinit var testDispatcher: TestDispatcher + private lateinit var testScheduler: TestCoroutineScheduler + + @Before + fun setUp() { + testScheduler = TestCoroutineScheduler() + testDispatcher = StandardTestDispatcher(testScheduler) + Dispatchers.setMain(testDispatcher) + } + @After fun tearDown() { Dispatchers.resetMain() } @Test - fun testExecuteSuccess(): Unit = runTest { + fun testExecuteSuccess() = runTest { val executor = CoroutineExecutor() - val job = createTestJob(scheduler = testScheduler) { } + val job = createTestJob { } + val latch = CountDownLatch(1) var onSuccessCalled = false - val onSuccess: () -> Unit = { onSuccessCalled = true } + val onSuccess: () -> Unit = { + onSuccessCalled = true + latch.countDown() + } val onError: (Throwable) -> Unit = { fail("onError should not be called") } executor.execute(job, onSuccess, onError) - delay(50) - assertTrue(onSuccessCalled) + // Advance time to ensure the job is executed + advanceUntilIdle() + assertTrue(latch.await(5, TimeUnit.SECONDS), "Timeout waiting for job execution") + assertTrue(onSuccessCalled, "onSuccess should have been called") } @Test - fun testExecuteError(): Unit = runTest { + fun testExecuteError() = runTest { val executor = CoroutineExecutor() - val job = createTestJob(scheduler = testScheduler) { throw IllegalArgumentException("Error") } - + val job = createTestJob { throw IllegalArgumentException("Error") } + val latch = CountDownLatch(1) + var exceptionCaught: Throwable? = null val onSuccess: () -> Unit = { fail("onSuccess should not be called") } - var exception: Throwable? = null - val onError: (Throwable) -> Unit = { exception = it } + val onError: (Throwable) -> Unit = { + exceptionCaught = it + latch.countDown() + } executor.execute(job, onSuccess, onError) - delay(50) - assertNotNull(exception) - assertTrue(exception is IllegalArgumentException) - assertEquals("Error", exception.message) + // Advance time to ensure the job is executed + advanceUntilIdle() + assertTrue(latch.await(5, TimeUnit.SECONDS), "Timeout waiting for job execution") + assertNotNull(exceptionCaught, "Exception should have been caught") + assertTrue(exceptionCaught is IllegalArgumentException, "Exception should be IllegalArgumentException") + assertEquals("Error", (exceptionCaught as IllegalArgumentException).message) } @Test - fun testConcurrentExecution(): Unit = runTest { + fun testConcurrentExecution() = runTest { val executor = CoroutineExecutor() - // Create a job that takes 100ms to execute. - val job = createTestJob( - scheduler = testScheduler, runConcurrently = true - ) { delay(100) } + val job = createTestJob(runConcurrently = true) { + // Simulate some work + delay(100) + } - var onSuccessCalled = 0 - val onSuccess: () -> Unit = { onSuccessCalled += 1 } + val latch = CountDownLatch(3) + var executionCount = 0 + val onSuccess: () -> Unit = { + executionCount++ + latch.countDown() + } val onError: (Throwable) -> Unit = { fail("onError should not be called") } - // Execute the job 3 times concurrently. - executor.execute(job, onSuccess, onError) - executor.execute(job, onSuccess, onError) - executor.execute(job, onSuccess, onError) - // Wait for the jobs to complete. - delay(110) - assertEquals(3, onSuccessCalled) + + // Execute the job 3 times, all should run concurrently + repeat(3) { executor.execute(job, onSuccess, onError) } + // Advance the virtual time to let all jobs finish + advanceTimeBy(110) + assertTrue(latch.await(5, TimeUnit.SECONDS), "Timeout waiting for job executions") + assertEquals(3, executionCount, "All three jobs should have executed") } @Test - fun testNonConcurrentExecution(): Unit = runTest { + fun testNonConcurrentExecution() = runTest { val executor = CoroutineExecutor() - // Create a job that takes 100ms to execute. - val job = createTestJob( - scheduler = testScheduler, runConcurrently = false - ) { delay(100) } + val job = createTestJob(runConcurrently = false) { + // Simulate some work + delay(100) + } - var onSuccessCalled = 0 - val onSuccess: () -> Unit = { onSuccessCalled += 1 } + val latch = CountDownLatch(1) + var executionCount = 0 + val onSuccess: () -> Unit = { + executionCount++ + latch.countDown() + } val onError: (Throwable) -> Unit = { fail("onError should not be called") } - // Execute the job 3 times concurrently. - executor.execute(job, onSuccess, onError) - executor.execute(job, onSuccess, onError) - executor.execute(job, onSuccess, onError) - // Wait for the jobs to complete. - delay(110) - assertEquals(1, onSuccessCalled) + + // Execute the job 3 times, but only one should actually run since it's non-concurrent + repeat(3) { executor.execute(job, onSuccess, onError) } + // Advance the virtual time to let the first job finish + advanceTimeBy(110) + assertTrue(latch.await(5, TimeUnit.SECONDS), "Timeout waiting for job execution") + assertEquals(1, executionCount, "Only one job should have executed") } + private fun createTestJob( jobId: String = "job1", runConcurrently: Boolean = true, - scheduler: TestCoroutineScheduler, - callback: suspend () -> Unit, + callback: suspend () -> Unit ): Job = Job( jobId = jobId, trigger = OneTimeTrigger(ZonedDateTime.now(ZoneId.of("UTC")).plusSeconds(1)), nextRunTime = ZonedDateTime.now(), - dispatcher = UnconfinedTestDispatcher(scheduler), + dispatcher = testDispatcher, runConcurrently = runConcurrently, callback = callback )