Skip to content

Commit

Permalink
Support count down latch.
Browse files Browse the repository at this point in the history
  • Loading branch information
aoli-al committed Mar 26, 2024
1 parent 2e1f9b3 commit b229cc5
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 0 deletions.
30 changes: 30 additions & 0 deletions core/src/main/kotlin/cmu/pasta/sfuzz/core/GlobalContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmu.pasta.sfuzz.core
import cmu.pasta.sfuzz.core.concurrency.locks.LockManager
import cmu.pasta.sfuzz.core.concurrency.SFuzzThread
import cmu.pasta.sfuzz.core.concurrency.SynchronizationManager
import cmu.pasta.sfuzz.core.concurrency.locks.CountDownLatchManager
import cmu.pasta.sfuzz.core.concurrency.locks.SemaphoreManager
import cmu.pasta.sfuzz.core.concurrency.operations.*
import cmu.pasta.sfuzz.core.logger.LoggerBase
Expand All @@ -15,6 +16,7 @@ import cmu.pasta.sfuzz.runtime.Delegate
import cmu.pasta.sfuzz.runtime.MemoryOpType
import cmu.pasta.sfuzz.runtime.Runtime
import cmu.pasta.sfuzz.runtime.TargetTerminateException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.Semaphore
import java.util.concurrent.locks.Condition
Expand All @@ -33,6 +35,7 @@ object GlobalContext {
private val lockManager = LockManager()
private val semaphoreManager = SemaphoreManager()
private val volatileManager = VolatileManager()
private val latchManager = CountDownLatchManager()
val syncManager = SynchronizationManager()
val loggers = mutableListOf<LoggerBase>()
var executor = Executors.newSingleThreadExecutor { r ->
Expand Down Expand Up @@ -496,6 +499,33 @@ object GlobalContext {
scheduleNextOperation(true)
}

fun latchAwait(latch: CountDownLatch) {
val t = Thread.currentThread().id
if (latchManager.await(latch)) {
val t = Thread.currentThread().id
registeredThreads[t]?.pendingOperation = PausedOperation()
registeredThreads[t]?.state = ThreadState.Paused
}
executor.submit {
while (registeredThreads[t]!!.thread.state == Thread.State.RUNNABLE) {
Thread.yield()
}
scheduleNextOperation(false)
}
}

fun latchAwaitDone(latch: CountDownLatch) {
val t = Thread.currentThread().id
val context = registeredThreads[t]!!
if (context.state != ThreadState.Running) {
context.block()
}
}

fun latchCountDown(latch: CountDownLatch) {
latchManager.countDown(latch)
}

fun scheduleNextOperation(shouldBlockCurrentThread: Boolean) {
// Our current design makes sure that reschedule is only called
// by scheduled thread.
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/kotlin/cmu/pasta/sfuzz/core/RuntimeDelegate.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cmu.pasta.sfuzz.core.concurrency.SFuzzThread
import cmu.pasta.sfuzz.runtime.Delegate
import cmu.pasta.sfuzz.runtime.MemoryOpType
import cmu.pasta.sfuzz.runtime.TargetTerminateException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Semaphore
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
Expand Down Expand Up @@ -341,6 +342,37 @@ class RuntimeDelegate: Delegate() {
skipFunctionEntered.set(skipFunctionEntered.get() - 1)
}

override fun onLatchAwait(latch: CountDownLatch) {
if (checkEntered()) {
skipFunctionEntered.set(1 + skipFunctionEntered.get())
return
}
GlobalContext.latchAwait(latch)
entered.set(false)
skipFunctionEntered.set(skipFunctionEntered.get() + 1)
}

override fun onLatchAwaitDone(latch: CountDownLatch) {
skipFunctionEntered.set(skipFunctionEntered.get() - 1)
if (checkEntered()) return
GlobalContext.latchAwaitDone(latch)
entered.set(false)
}

override fun onLatchCountDown(latch: CountDownLatch) {
if (checkEntered()) {
skipFunctionEntered.set(1 + skipFunctionEntered.get())
return
}
GlobalContext.latchCountDown(latch)
entered.set(false)
skipFunctionEntered.set(skipFunctionEntered.get() + 1)
}

override fun onLatchCountDownDone(latch: CountDownLatch?) {
skipFunctionEntered.set(skipFunctionEntered.get() - 1)
}

override fun start() {
// For the first thread, it is not registered.
// Therefor we cannot call `checkEntered` here.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package cmu.pasta.sfuzz.core.concurrency.locks

import cmu.pasta.sfuzz.core.GlobalContext
import cmu.pasta.sfuzz.core.ThreadState
import cmu.pasta.sfuzz.core.concurrency.operations.ThreadResumeOperation

class CountDownLatchContext(var count: Long) {
val latchWaiters = mutableSetOf<Long>()
fun await(): Boolean {
if (count > 0) {
latchWaiters.add(Thread.currentThread().id)
return true
}
assert(count == 0L)
return false
}

fun countDown() {
// If count is already zero we do not need to
// do anything
if (count == 0L) {
return
}
count -= 1
if (count == 0L) {
for (tid in latchWaiters) {
GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation()
GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled
}
}
return
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package cmu.pasta.sfuzz.core.concurrency.locks

import java.util.concurrent.CountDownLatch

class CountDownLatchManager {
val latchStore = ReferencedContextManager {it ->
if (it is CountDownLatch) {
CountDownLatchContext(it.count)
} else {
throw IllegalArgumentException("CountDownLatchManager can only manage CountDownLatch objects")
}
}

fun await(latch: CountDownLatch): Boolean {
return latchStore.getLockContext(latch).await()
}

fun countDown(latch: CountDownLatch) {
latchStore.getLockContext(latch).countDown()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fun instrumentClass(path:String, inputStream: InputStream): ByteArray {
cv = ClassloaderInstrumenter(cv)
cv = ObjectInstrumenter(cv)
cv = SemaphoreInstrumenter(cv)
cv = CountDownLatchInstrumenter(cv)
// MonitorInstrumenter should come second because ObjectInstrumenter will insert more
// monitors.
cv = MonitorInstrumenter(cv)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package cmu.pasta.sfuzz.instrumentation.visitors

import cmu.pasta.sfuzz.runtime.Runtime
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor
import java.util.concurrent.CountDownLatch

class CountDownLatchInstrumenter(cv: ClassVisitor): ClassVisitorBase(cv, CountDownLatch::class.java.name) {
override fun instrumentMethod(
mv: MethodVisitor,
access: Int,
name: String,
descriptor: String,
signature: String?,
exceptions: Array<out String>?
): MethodVisitor {
if (name == "await") {
val eMv = MethodEnterVisitor(mv, Runtime::onLatchAwait, access, name, descriptor, true, false)
return MethodExitVisitor(eMv, Runtime::onLatchAwaitDone, access, name, descriptor, true, false)
}
if (name == "countDown") {
val eMv = MethodEnterVisitor(mv, Runtime::onLatchCountDown, access, name, descriptor, true, false)
return MethodExitVisitor(eMv, Runtime::onLatchCountDownDone, access, name, descriptor, true, false)
}
return super.instrumentMethod(mv, access, name, descriptor, signature, exceptions)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package cmu.pasta.sfuzz.it.core;

import cmu.pasta.sfuzz.it.IntegrationTestRunner;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class CountDownLatchTest extends IntegrationTestRunner {

@Test
public void testCountDown() {
String expected = "[1]: WORKER-1 finished\n" +
"[2]: WORKER-2 finished\n" +
"[3]: WORKER-3 finished\n" +
"[4]: WORKER-4 finished\n" +
"[0]: Test worker has finished\n";
assertEquals(expected, runTest("testCountDown"));
}

@Test
public void testAwaitAfterCountDown() {
String expected = "[0]: Test worker has finished\n";
assertEquals(expected, runTest("testAwaitAfterCountDown"));
}
}
58 changes: 58 additions & 0 deletions integration-tests/src/test/java/example/CountDownLatchTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package example;

import jdk.jshell.execution.Util;

import java.util.concurrent.CountDownLatch;

public class CountDownLatchTest {
public static void testCountDown() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(4);
Worker first = new Worker(1000, latch,
"WORKER-1");
Worker second = new Worker(2000, latch,
"WORKER-2");
Worker third = new Worker(3000, latch,
"WORKER-3");
Worker fourth = new Worker(4000, latch,
"WORKER-4");
first.start();
second.start();
third.start();
fourth.start();
latch.await();
Utils.log(Thread.currentThread().getName() +
" has finished");
}

public static void testAwaitAfterCountDown() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
latch.countDown();
latch.await();
Utils.log(Thread.currentThread().getName() +
" has finished");
}
}

// A class to represent threads for which
// the main thread waits.
class Worker extends Thread
{
private int delay;
private CountDownLatch latch;

public Worker(int delay, CountDownLatch latch,
String name)
{
super(name);
this.delay = delay;
this.latch = latch;
}

@Override
public void run()
{
latch.countDown();
Utils.log(Thread.currentThread().getName()
+ " finished");
}
}
13 changes: 13 additions & 0 deletions runtime/src/main/java/cmu/pasta/sfuzz/runtime/Delegate.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package cmu.pasta.sfuzz.runtime;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
Expand Down Expand Up @@ -153,6 +154,18 @@ public void onSemaphoreReducePermits(Semaphore sem, int permits) {
public void onSemaphoreReducePermitsDone() {
}

public void onLatchAwait(CountDownLatch latch) {
}

public void onLatchAwaitDone(CountDownLatch latch) {
}

public void onLatchCountDown(CountDownLatch latch) {
}

public void onLatchCountDownDone(CountDownLatch latch) {
}

public boolean onThreadClearInterrupt(Boolean originValue, Thread t) {
return originValue;
}
Expand Down
17 changes: 17 additions & 0 deletions runtime/src/main/java/cmu/pasta/sfuzz/runtime/Runtime.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package cmu.pasta.sfuzz.runtime;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
Expand Down Expand Up @@ -218,4 +219,20 @@ public static void onSemaphoreReducePermitsDone() {
public static void onSemaphoreDrainPermits(Semaphore sem) {
DELEGATE.onSemaphoreDrainPermits(sem);
}

public static void onLatchAwait(CountDownLatch latch) {
DELEGATE.onLatchAwait(latch);
}

public static void onLatchAwaitDone(CountDownLatch latch) {
DELEGATE.onLatchAwaitDone(latch);
}

public static void onLatchCountDown(CountDownLatch latch) {
DELEGATE.onLatchCountDown(latch);
}

public static void onLatchCountDownDone(CountDownLatch latch) {
DELEGATE.onLatchCountDownDone(latch);
}
}

0 comments on commit b229cc5

Please sign in to comment.