Skip to content

Commit

Permalink
Use separate ForkJoin pools for LocalServer, MTLS Server/Client
Browse files Browse the repository at this point in the history
It's using UnsafeExecutors, so we'll see how long this can last.  But at least for the LocalServer case for in process simulation, this seems to be a win
  • Loading branch information
Hellblazer committed Jun 3, 2024
1 parent 24e8036 commit 0dcd865
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 19 deletions.
4 changes: 2 additions & 2 deletions choam/src/main/java/com/salesforce/apollo/choam/CHOAM.java
Original file line number Diff line number Diff line change
Expand Up @@ -1114,8 +1114,8 @@ public record PendingView(Digest diadem, Context<Member> context) {
*/
public View getView(Digest hash) {
var builder = View.newBuilder().setDiadem(diadem.toDigeste()).setMajority(context.majority());
((Context<? super Member>) context).bftSubset(hash).forEach(
d -> builder.addCommittee(d.getId().toDigeste()));
((Context<? super Member>) context).bftSubset(hash)
.forEach(d -> builder.addCommittee(d.getId().toDigeste()));
return builder.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;
Expand All @@ -46,7 +45,7 @@ public class Enclave implements RouterSupplier {
private final static Class<? extends io.netty.channel.Channel> channelType = IMPL.getChannelType();
private static final Logger log = LoggerFactory.getLogger(Enclave.class);

private final Executor executor = Executors.newVirtualThreadPerTaskExecutor();
private final Executor executor = UnsafeExecutors.newVirtualThreadPerTaskExecutor();
private final DomainSocketAddress bridge;
private final Consumer<Digest> contextRegistration;
private final DomainSocketAddress endpoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Predicate;
import java.util.function.Supplier;

Expand All @@ -37,10 +36,10 @@
* @author hal.hildebrand
*/
public class LocalServer implements RouterSupplier {
private static final Logger log = LoggerFactory.getLogger(LocalServer.class);
private static final String NAME_TEMPLATE = "%s-%s";
private final Executor executor = Executors.newVirtualThreadPerTaskExecutor();
private static final Logger log = LoggerFactory.getLogger(LocalServer.class);
private static final String NAME_TEMPLATE = "%s-%s";

private final Executor executor = UnsafeExecutors.newVirtualThreadPerTaskExecutor();
private final ClientInterceptor clientInterceptor;
private final Member from;
private final String prefix;
Expand Down Expand Up @@ -78,7 +77,8 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Li
limitsBuilder.metricRegistry(limitsRegistry);
}
ServerBuilder<?> serverBuilder = InProcessServerBuilder.forName(name)
.executor(Executors.newVirtualThreadPerTaskExecutor())
.executor(
UnsafeExecutors.newVirtualThreadPerTaskExecutor())
.intercept(ConcurrencyLimitServerInterceptor.newBuilder(
limitsBuilder.build())
.statusSupplier(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
Expand All @@ -71,7 +70,7 @@ public MtlsServer(Member from, EndpointProvider epProvider, Function<Member, Cli
this.epProvider = epProvider;
this.contextSupplier = contextSupplier;
this.supplier = supplier;
this.executor = Executors.newVirtualThreadPerTaskExecutor();
this.executor = UnsafeExecutors.newVirtualThreadPerTaskExecutor();
cachedMembership = CacheBuilder.newBuilder().build(new CacheLoader<X509Certificate, Digest>() {
@Override
public Digest load(X509Certificate key) throws Exception {
Expand Down Expand Up @@ -148,7 +147,8 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Li
limitsBuilder.metricRegistry(limitsRegistry);
}
NettyServerBuilder serverBuilder = NettyServerBuilder.forAddress(epProvider.getBindAddress())
.executor(executor)
.executor(
UnsafeExecutors.newVirtualThreadPerTaskExecutor())
.withOption(ChannelOption.SO_REUSEADDR, true)
.sslContext(supplier.forServer(ClientAuth.REQUIRE,
epProvider.getAlias(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

Expand All @@ -38,7 +37,7 @@
public class Portal<To extends Member> {
private final static Class<? extends io.netty.channel.Channel> channelType = IMPL.getChannelType();

private final Executor executor = Executors.newVirtualThreadPerTaskExecutor();
private final Executor executor = UnsafeExecutors.newVirtualThreadPerTaskExecutor();
private final String agent;
private final EventLoopGroup eventLoopGroup = IMPL.getEventLoopGroup();
private final Demultiplexer inbound;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package com.salesforce.apollo.archipelago;

import java.lang.Thread.UncaughtExceptionHandler;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.lang.invoke.MethodHandles.insertArguments;
import static java.lang.invoke.MethodType.methodType;

@SuppressWarnings("unused")
public class UnsafeExecutors {
private static final MethodHandle SET_EXECUTOR;

static {
try {
var unsafeClass = Class.forName("sun.misc.Unsafe");
var unsafeField = unsafeClass.getDeclaredField("theUnsafe");
unsafeField.setAccessible(true);
var unsafe = unsafeField.get(null);
var objectFieldOffset = unsafeClass.getMethod("objectFieldOffset", Field.class);
var executorField = VTB.class.getDeclaredField("executor");
executorField.setAccessible(true);
var executorOffset = (long) objectFieldOffset.invoke(unsafe, executorField);
var putObject = MethodHandles.lookup()
.findVirtual(unsafeClass, "putObject",
methodType(void.class, Object.class, long.class, Object.class));
var setExecutor = insertArguments(insertArguments(putObject, 2, executorOffset), 0, unsafe);
SET_EXECUTOR = setExecutor;
} catch (ClassNotFoundException | NoSuchFieldException | NoSuchMethodException | IllegalAccessException |
InvocationTargetException e) {
throw new AssertionError(e);
}
}

public static ExecutorService newVirtualThreadPerTaskExecutor() {
return virtualThreadExecutor(new ForkJoinPool());
}

public static <B extends Thread.Builder> B configureBuilderExecutor(B builder, Executor executor) {
if (executor != null) {
setExecutor(builder, executor);
}
return builder;
}

public static ExecutorService virtualThreadExecutor(ExecutorService executor) {
Objects.requireNonNull(executor);
return new VirtualThreadExecutor(executor);
}

private static void setExecutor(Object builder, Object executor) {
try {
SET_EXECUTOR.invokeExact(builder, executor);
} catch (Throwable e) {
throw new AssertionError(e);
}
}

private static class BTB {
private int characteristics;
private long counter;
private String name;
private UncaughtExceptionHandler uhe;
}

private static class VirtualThreadExecutor extends AbstractExecutorService {
private final ExecutorService executor;
private final AtomicBoolean started = new AtomicBoolean(true);

public VirtualThreadExecutor(ExecutorService executor) {
this.executor = executor;
}

@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return true;
}

@Override
public void execute(Runnable command) {
if (!started.get()) {
throw new RejectedExecutionException("Executor shutdown");
}
var builder = Thread.ofVirtual();
setExecutor(builder, executor);
builder.start(command);
}

@Override
public boolean isShutdown() {
return executor.isShutdown();
}

@Override
public boolean isTerminated() {
return !executor.isTerminated();
}

@Override
public void shutdown() {
if (!started.compareAndSet(true, false)) {
return;
}
executor.shutdown();
}

@Override
public List<Runnable> shutdownNow() {
if (!started.compareAndSet(true, false)) {
return List.of();
}
return executor.shutdownNow();
}
}

private static class VTB extends BTB {
private Executor executor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@

import java.io.IOException;
import java.time.Duration;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -51,14 +48,15 @@ public SliceIterator(String label, SigningMember member, Collection<? extends Me
public SliceIterator(String label, SigningMember member, Collection<? extends Member> s,
CommonCommunications<Comm, ?> comm, ScheduledExecutorService scheduler) {
assert member != null && s != null && comm != null;
assert !s.stream().filter(Objects::nonNull).toList().isEmpty() : "All elements must be non-null: " + s;
this.label = label;
this.member = member;
this.slice = new CopyOnWriteArrayList<>(s);
this.comm = comm;
this.scheduler = scheduler;
Entropy.secureShuffle(this.slice);
this.currentIteration = slice.iterator();
log.debug("Slice for: <{}> is: {} on: {}", label, slice.stream().map(m -> m.getId()).toList(), member.getId());
log.debug("Slice for: <{}> is: {} on: {}", label, slice.stream().map(Member::getId).toList(), member.getId());
}

public <T> void iterate(BiFunction<Comm, Member, T> round, SlicePredicateHandler<T, Comm> handler,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.salesforce.apollo.archipelago;

import org.junit.jupiter.api.Test;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

public class UnsafeExecutorsTest {
private static String carrierThreadName() {
var name = Thread.currentThread().toString();
var index = name.lastIndexOf('@');
if (index == -1) {
throw new AssertionError();
}
return name.substring(index + 1);
}

@Test
public void virtualThreadExecutorSingleThreadExecutor() throws InterruptedException {
var executor = Executors.newSingleThreadExecutor();
var virtualExecutor = UnsafeExecutors.virtualThreadExecutor(executor);
var carrierThreadNames = new CopyOnWriteArraySet<String>();
for (var i = 0; i < 10; i++) {
virtualExecutor.execute(() -> carrierThreadNames.add(carrierThreadName()));
}
executor.shutdown();
executor.awaitTermination(1, TimeUnit.DAYS);
assertEquals(1, carrierThreadNames.size());
}

@Test
void testVirtualThread() {
Queue<Runnable> executor = new ArrayDeque<>();
var virtualExecutor = UnsafeExecutors.virtualThreadExecutor(wrap(executor::add));

Lock lock = new ReentrantLock();
lock.lock();
virtualExecutor.execute(lock::lock);
assertEquals(1, executor.size(), "runnable for vthread has not been submitted");
executor.poll().run();
assertEquals(0, executor.size(), "vthread has not blocked");
lock.unlock();
assertEquals(1, executor.size(), "vthread is not schedulable");
executor.poll().run();
assertFalse(lock.tryLock(), "the virtual thread does not hold the lock");
}

private ExecutorService wrap(Executor ex) {
return new AbstractExecutorService() {

@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return false;
}

@Override
public void execute(Runnable command) {
System.out.println("Yes!");
ex.execute(command);
}

@Override
public boolean isShutdown() {
return false;
}

@Override
public boolean isTerminated() {
return false;
}

@Override
public void shutdown() {

}

@Override
public List<Runnable> shutdownNow() {
return List.of();
}
};
}
}

0 comments on commit 0dcd865

Please sign in to comment.