diff --git a/ethereal/src/main/java/com/salesforce/apollo/ethereal/memberships/ChRbcGossip.java b/ethereal/src/main/java/com/salesforce/apollo/ethereal/memberships/ChRbcGossip.java index effcb308f..2c1b6af62 100644 --- a/ethereal/src/main/java/com/salesforce/apollo/ethereal/memberships/ChRbcGossip.java +++ b/ethereal/src/main/java/com/salesforce/apollo/ethereal/memberships/ChRbcGossip.java @@ -47,7 +47,7 @@ */ public class ChRbcGossip { - private static final Logger log = LoggerFactory.getLogger( + private static final Logger log = LoggerFactory.getLogger( ChRbcGossip.class); private final CommonCommunications comm; private final Context context; @@ -55,7 +55,8 @@ public class ChRbcGossip { private final EtherealMetrics metrics; private final Processor processor; private final RingCommunications ring; - private final AtomicBoolean started = new AtomicBoolean(); + private final AtomicBoolean started = new AtomicBoolean(); + private final Terminal terminal = new Terminal(); private volatile ScheduledFuture scheduled; public ChRbcGossip(Context context, SigningMember member, Processor processor, Router communications, @@ -64,7 +65,7 @@ public ChRbcGossip(Context context, SigningMember member, Processor proc this.context = context; this.member = member; this.metrics = m; - comm = communications.create((Member) member, context.getId(), new Terminal(), getClass().getCanonicalName(), + comm = communications.create(member, context.getId(), terminal, getClass().getCanonicalName(), r -> new GossiperServer(communications.getClientIdentityProvider(), metrics, r), getCreate(metrics), Gossiper.getLocalLoopback(member)); ring = new RingCommunications<>(context, member, this.comm); @@ -83,7 +84,7 @@ public void start(Duration duration) { } Duration initialDelay = duration.plusMillis(Entropy.nextBitsStreamLong(duration.toMillis())); log.trace("Starting GossipService[{}] on: {}", context.getId(), member.getId()); - comm.register(context.getId(), new Terminal()); + comm.register(context.getId(), terminal); var scheduler = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()); scheduler.schedule(() -> Thread.ofVirtual().start(Utils.wrapped(() -> { try { @@ -121,11 +122,11 @@ private Update gossipRound(Gossiper link, int ring) { try { return link.gossip(processor.gossip(context.getId(), ring)); } catch (StatusRuntimeException e) { - log.debug("gossiping[{}] failed with: {} with {} ring: {} on {}", context.getId(), e.getMessage(), + log.debug("gossiping[{}] failed: {} with: {} with {} ring: {} on {}", context.getId(), e.getMessage(), member.getId(), ring, link.getMember().getId(), member.getId(), e); return null; } catch (Throwable e) { - log.warn("gossiping[{}] failed from {} with {} ring: {} on {}", context.getId(), member.getId(), ring, + log.warn("gossiping[{}] failed: {} from {} with {} ring: {} on {}", context.getId(), member.getId(), ring, link.getMember().getId(), ring, member.getId(), e); return null; } @@ -168,7 +169,7 @@ private void handle(Optional result, RingCommunications.Destination gossipRound(link, ring), + ring.execute(this::gossipRound, (result, destination) -> handle(result, destination, duration, scheduler, timer)); } @@ -204,7 +205,8 @@ public Update gossip(Gossip request, Digest from) { Member predecessor = context.ring(request.getRing()).predecessor(member); if (predecessor == null || !from.equals(predecessor.getId())) { log.debug("Invalid inbound gossip on {}:{} from: {} on ring: {} - not predecessor: {}", context.getId(), - member, from, request.getRing(), predecessor.getId()); + member.getId(), from, request.getRing(), + predecessor == null ? "" : predecessor.getId()); return Update.getDefaultInstance(); } final var update = processor.gossip(request); @@ -218,10 +220,11 @@ public void update(ContextUpdate request, Digest from) { Member predecessor = context.ring(request.getRing()).predecessor(member); if (predecessor == null || !from.equals(predecessor.getId())) { log.debug("Invalid inbound update on {}:{} from: {} on ring: {} - not predecessor: {}", context.getId(), - member.getId(), from, request.getRing(), predecessor.getId()); + member.getId(), from, request.getRing(), + predecessor == null ? "" : predecessor.getId()); return; } - log.trace("gossip update with {} on: {}", from, member); + log.trace("gossip update with {} on: {}", from, member.getId()); processor.updateFrom(request.getUpdate()); } } diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/Enclave.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/Enclave.java index 40cfb8938..3e735b878 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/Enclave.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/Enclave.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -72,7 +73,7 @@ public DomainSocketAddress getEndpoint() { @Override public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier serverLimit, - LimitsRegistry limitsRegistry) { + LimitsRegistry limitsRegistry, List interceptors) { var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get()); if (limitsRegistry != null) { limitsBuilder.metricRegistry(limitsRegistry); @@ -91,6 +92,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier
  • { + serverBuilder.intercept(i); + }); return new RouterImpl(from, serverBuilder, cacheBuilder.setFactory(t -> connectTo(t)), new RoutingClientIdentity() { @Override diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/LocalServer.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/LocalServer.java index 269aa684a..f62c01129 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/LocalServer.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/LocalServer.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import java.lang.reflect.Method; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.function.Supplier; @@ -67,7 +68,7 @@ public Member getFrom() { @Override public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier serverLimit, - LimitsRegistry limitsRegistry) { + LimitsRegistry limitsRegistry, List interceptors) { String name = String.format(NAME_TEMPLATE, prefix, qb64(from.getId())); var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get()); if (limitsRegistry != null) { @@ -82,6 +83,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier
  • { + serverBuilder.intercept(i); + }); return new RouterImpl(from, serverBuilder, cacheBuilder.setFactory(t -> connectTo(t)), new ClientIdentity() { @Override public Digest getFrom() { diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/MtlsServer.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/MtlsServer.java index 24b9a227b..98281d809 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/MtlsServer.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/MtlsServer.java @@ -38,6 +38,7 @@ import java.security.Provider; import java.security.Security; import java.security.cert.X509Certificate; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -134,7 +135,7 @@ public static SslContext forServer(ClientAuth clientAuth, String alias, X509Cert @Override public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier serverLimit, - LimitsRegistry limitsRegistry) { + LimitsRegistry limitsRegistry, List interceptors) { var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get()); if (limitsRegistry != null) { limitsBuilder.metricRegistry(limitsRegistry); @@ -149,6 +150,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier
  • { + serverBuilder.intercept(i); + }); ClientIdentity identity = new ClientIdentity() { @Override diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/RoutableService.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/RoutableService.java index bf1bf4af8..5ce691e4d 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/RoutableService.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/RoutableService.java @@ -6,6 +6,8 @@ */ package com.salesforce.apollo.archipelago; +import com.macasaet.fernet.Token; +import com.salesforce.apollo.archipelago.server.FernetServerInterceptor; import com.salesforce.apollo.cryptography.Digest; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -15,6 +17,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; import java.util.function.Consumer; import static com.salesforce.apollo.archipelago.Constants.SERVER_CONTEXT_KEY; @@ -25,8 +28,9 @@ * @author hal.hildebrand */ public class RoutableService { - private static final Logger log = LoggerFactory.getLogger(RoutableService.class); - private final Map services = new ConcurrentHashMap<>(); + private static final Logger log = LoggerFactory.getLogger(RoutableService.class); + + private final Map services = new ConcurrentHashMap<>(); public void bind(Digest context, Service service) { services.put(context, service); @@ -53,6 +57,27 @@ public void evaluate(StreamObserver responseObserver, Consumer c) { } } + public void evaluate(StreamObserver responseObserver, BiConsumer c) { + var context = SERVER_CONTEXT_KEY.get(); + if (context == null) { + responseObserver.onError(new StatusRuntimeException(Status.NOT_FOUND)); + log.error("Null context"); + } else { + Service service = services.get(context); + if (service == null) { + log.trace("No service for context {}", context); + responseObserver.onError(new StatusRuntimeException(Status.NOT_FOUND)); + } else { + try { + c.accept(service, FernetServerInterceptor.AccessTokenContextKey.get()); + } catch (Throwable t) { + log.error("Uncaught exception in service evaluation for context: {}", context, t); + responseObserver.onError(t); + } + } + } + } + public void unbind(Digest context) { services.remove(context); } diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/RouterSupplier.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/RouterSupplier.java index f69c55110..b63921bee 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/RouterSupplier.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/RouterSupplier.java @@ -8,7 +8,10 @@ import com.netflix.concurrency.limits.Limit; import com.salesforce.apollo.protocols.LimitsRegistry; +import io.grpc.ServerInterceptor; +import java.util.Collections; +import java.util.List; import java.util.function.Supplier; /** @@ -17,14 +20,19 @@ public interface RouterSupplier { default Router router() { - return router(ServerConnectionCache.newBuilder(), () -> RouterImpl.defaultServerLimit(), null); + return router(ServerConnectionCache.newBuilder(), RouterImpl::defaultServerLimit, null); } default Router router(ServerConnectionCache.Builder cacheBuilder) { - return router(cacheBuilder, () -> RouterImpl.defaultServerLimit(), null); + return router(cacheBuilder, RouterImpl::defaultServerLimit, null); + } + + default Router router(ServerConnectionCache.Builder cacheBuilder, Supplier serverLimit, + LimitsRegistry limitsRegistry) { + return router(cacheBuilder, serverLimit, limitsRegistry, Collections.emptyList()); } Router router(ServerConnectionCache.Builder cacheBuilder, Supplier serverLimit, - LimitsRegistry limitsRegistry); + LimitsRegistry limitsRegistry, List interceptors); } diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/client/FernetCallCredentials.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/client/FernetCallCredentials.java index 8f51b4f62..e6b7ffb42 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/client/FernetCallCredentials.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/client/FernetCallCredentials.java @@ -2,6 +2,7 @@ import com.macasaet.fernet.Token; import com.salesforce.apollo.archipelago.Constants; +import io.grpc.CallCredentials; import io.grpc.Metadata; import io.grpc.Status; import org.slf4j.Logger; @@ -11,7 +12,7 @@ import static com.salesforce.apollo.archipelago.server.FernetServerInterceptor.AUTH_HEADER_PREFIX; -public abstract class FernetCallCredentials extends io.grpc.CallCredentials { +public abstract class FernetCallCredentials extends CallCredentials { private static final Logger LOGGER = LoggerFactory.getLogger(FernetCallCredentials.class); public static FernetCallCredentials synchronous(SynchronousTokenProvider tokenProvider) { diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/DelayedServerCallListener.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/server/DelayedServerCallListener.java index 1d9d69516..5e73753ee 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/DelayedServerCallListener.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/server/DelayedServerCallListener.java @@ -4,62 +4,95 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; // https://stackoverflow.com/a/53656689/181796 class DelayedServerCallListener extends ServerCall.Listener { - private ServerCall.Listener delegate; - private List events = new ArrayList<>(); + private final Lock lock = new ReentrantLock(); + private ServerCall.Listener delegate; + private List events = new ArrayList<>(); @Override - public synchronized void onCancel() { - if (delegate == null) { - events.add(() -> delegate.onCancel()); - } else { - delegate.onCancel(); + public void onCancel() { + lock.lock(); + try { + if (delegate == null) { + events.add(() -> delegate.onCancel()); + } else { + delegate.onCancel(); + } + } finally { + lock.unlock(); } } @Override - public synchronized void onComplete() { - if (delegate == null) { - events.add(() -> delegate.onComplete()); - } else { - delegate.onComplete(); + public void onComplete() { + lock.lock(); + try { + if (delegate == null) { + events.add(() -> delegate.onComplete()); + } else { + delegate.onComplete(); + } + } finally { + lock.unlock(); } } @Override - public synchronized void onHalfClose() { - if (delegate == null) { - events.add(() -> delegate.onHalfClose()); - } else { - delegate.onHalfClose(); + public void onHalfClose() { + lock.lock(); + try { + if (delegate == null) { + events.add(() -> delegate.onHalfClose()); + } else { + delegate.onHalfClose(); + } + } finally { + lock.unlock(); } } @Override - public synchronized void onMessage(ReqT message) { - if (delegate == null) { - events.add(() -> delegate.onMessage(message)); - } else { - delegate.onMessage(message); + public void onMessage(ReqT message) { + lock.lock(); + try { + if (delegate == null) { + events.add(() -> delegate.onMessage(message)); + } else { + delegate.onMessage(message); + } + } finally { + lock.unlock(); } } @Override - public synchronized void onReady() { - if (delegate == null) { - events.add(() -> delegate.onReady()); - } else { - delegate.onReady(); + public void onReady() { + lock.lock(); + try { + if (delegate == null) { + events.add(() -> delegate.onReady()); + } else { + delegate.onReady(); + } + } finally { + lock.unlock(); } } - public synchronized void setDelegate(ServerCall.Listener delegate) { - this.delegate = delegate; - for (Runnable runnable : events) { - runnable.run(); + public void setDelegate(ServerCall.Listener delegate) { + lock.lock(); + try { + this.delegate = delegate; + for (Runnable runnable : events) { + runnable.run(); + } + events = null; + } finally { + lock.unlock(); } - events = null; } } diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetParser.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetParser.java deleted file mode 100644 index 04c8f1e88..000000000 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetParser.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.salesforce.apollo.archipelago.server; - -import com.macasaet.fernet.Token; - -import java.util.concurrent.CompletableFuture; - -@FunctionalInterface -public interface FernetParser { - - CompletableFuture parseToValid(String token); -} diff --git a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetServerInterceptor.java b/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetServerInterceptor.java index 9ac84e898..66fc630f3 100644 --- a/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetServerInterceptor.java +++ b/memberships/src/main/java/com/salesforce/apollo/archipelago/server/FernetServerInterceptor.java @@ -6,16 +6,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + public class FernetServerInterceptor implements ServerInterceptor { public static final String AUTH_HEADER_PREFIX = "Bearer "; + public static final Context.Key AccessTokenContextKey = Context.key("AccessToken"); private static final Logger LOGGER = LoggerFactory.getLogger( FernetServerInterceptor.class); - public final Context.Key AccessTokenContextKey = Context.key("AccessToken"); - private final FernetParser tokenParser; - - public FernetServerInterceptor(FernetParser tokenParser) { - this.tokenParser = tokenParser; - } @Override public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, @@ -40,7 +38,8 @@ public ServerCall.Listener interceptCall(ServerCall context.run(() -> { + + deserialize(serialized).whenComplete((token, e) -> context.run(() -> { if (e == null) { delayedListener.setDelegate( Contexts.interceptCall(Context.current().withValue(AccessTokenContextKey, token), call, headers, @@ -55,6 +54,19 @@ public ServerCall.Listener interceptCall(ServerCall deserialize(String serialized) { + if (serialized.equals("Invalid Token")) { + CompletableFuture res = new CompletableFuture<>(); + res.completeExceptionally(new RuntimeException("invalid token")); + return res; + } + try { + return CompletableFuture.completedFuture(Token.fromString(serialized)); + } catch (Throwable t) { + return CompletableFuture.failedFuture(t); + } + } + private ServerCall.Listener handleException(Throwable e, ServerCall call) { String msg = Constants.AuthorizationMetadataKey.name() + " header validation failed: " + e.getMessage(); LOGGER.warn(msg, e); diff --git a/memberships/src/test/java/com/salesforce/apollo/archipelago/FernetTest.java b/memberships/src/test/java/com/salesforce/apollo/archipelago/FernetTest.java new file mode 100644 index 000000000..9857baa06 --- /dev/null +++ b/memberships/src/test/java/com/salesforce/apollo/archipelago/FernetTest.java @@ -0,0 +1,183 @@ +package com.salesforce.apollo.archipelago; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.macasaet.fernet.Key; +import com.macasaet.fernet.Token; +import com.salesforce.apollo.archipelago.client.FernetCallCredentials; +import com.salesforce.apollo.archipelago.server.FernetServerInterceptor; +import com.salesforce.apollo.cryptography.DigestAlgorithm; +import com.salesforce.apollo.membership.Member; +import com.salesforce.apollo.membership.impl.SigningMemberImpl; +import com.salesforce.apollo.test.proto.ByteMessage; +import com.salesforce.apollo.test.proto.TestItGrpc; +import com.salesforce.apollo.utils.Utils; +import io.grpc.CallCredentials; +import io.grpc.stub.StreamObserver; +import org.joou.ULong; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.security.SecureRandom; +import java.time.Duration; +import java.util.Collections; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * Test Token Auth via Fernet Tokens + * + * @author hal.hildebrand + **/ +public class FernetTest { + private final TestItService local = new TestItService() { + + @Override + public void close() throws IOException { + } + + @Override + public Member getMember() { + return null; + } + + @Override + public Any ping(Any request) { + return null; + } + }; + private Token token; + + @BeforeEach + public void before() { + final SecureRandom deterministicRandom = new SecureRandom() { + private static final long serialVersionUID = 3075400891983079965L; + + public void nextBytes(final byte[] bytes) { + for (int i = bytes.length; --i >= 0; bytes[i] = 1) + ; + } + }; + final Key key = Key.generateKey(deterministicRandom); + + token = Token.generate(deterministicRandom, key, "Hello, world!"); + } + + @Test + public void smokin() throws Exception { + final var credentials = FernetCallCredentials.blocking(() -> token); + final var memberA = new SigningMemberImpl(Utils.getMember(0), ULong.MIN); + final var memberB = new SigningMemberImpl(Utils.getMember(1), ULong.MIN); + final var ctxA = DigestAlgorithm.DEFAULT.getOrigin().prefix(0x666); + final var prefix = UUID.randomUUID().toString(); + + RouterSupplier serverA = new LocalServer(prefix, memberA); + var routerA = serverA.router(ServerConnectionCache.newBuilder(), () -> RouterImpl.defaultServerLimit(), null, + Collections.singletonList(new FernetServerInterceptor())); + + RouterImpl.CommonCommunications commsA = routerA.create(memberA, ctxA, new ServerA(), + "A", r -> new Server(r), + c -> new TestItClient(c, + credentials), + local); + + RouterSupplier serverB = new LocalServer(prefix, memberB); + var routerB = serverB.router(ServerConnectionCache.newBuilder(), () -> RouterImpl.defaultServerLimit(), null, + Collections.singletonList(new FernetServerInterceptor())); + + RouterImpl.CommonCommunications commsA_B = routerB.create(memberB, ctxA, new ServerB(), + "B", r -> new Server(r), + c -> new TestItClient(c, + credentials), + local); + + routerA.start(); + routerB.start(); + + var clientA = commsA.connect(memberB); + + var resultA = clientA.ping(Any.getDefaultInstance()); + assertNotNull(resultA); + assertEquals("Hello Server B", resultA.unpack(ByteMessage.class).getContents().toStringUtf8()); + + var clientB = commsA_B.connect(memberA); + var resultB = clientB.ping(Any.getDefaultInstance()); + assertNotNull(resultB); + assertEquals("Hello Server A", resultB.unpack(ByteMessage.class).getContents().toStringUtf8()); + + routerA.close(Duration.ofSeconds(1)); + routerB.close(Duration.ofSeconds(1)); + } + + public interface TestIt { + void ping(Any request, StreamObserver responseObserver); + } + + public interface TestItService extends Link { + Any ping(Any request); + } + + public static class Server extends TestItGrpc.TestItImplBase { + private final RoutableService router; + + public Server(RoutableService router) { + this.router = router; + } + + @Override + public void ping(Any request, StreamObserver responseObserver) { + router.evaluate(responseObserver, (t, token) -> { + assertNotNull(token); + t.ping(request, responseObserver); + }); + } + } + + public static class TestItClient implements TestItService { + private final TestItGrpc.TestItBlockingStub client; + private final ManagedServerChannel connection; + private final CallCredentials credentials; + + public TestItClient(ManagedServerChannel c, CallCredentials credentials) { + this.connection = c; + this.credentials = credentials; + client = TestItGrpc.newBlockingStub(c).withCallCredentials(credentials); + } + + @Override + public void close() throws IOException { + connection.release(); + } + + @Override + public Member getMember() { + return connection.getMember(); + } + + @Override + public Any ping(Any request) { + return client.ping(request); + } + } + + public class ServerA implements TestIt { + @Override + public void ping(Any request, StreamObserver responseObserver) { + responseObserver.onNext( + Any.pack(ByteMessage.newBuilder().setContents(ByteString.copyFromUtf8("Hello Server A")).build())); + responseObserver.onCompleted(); + } + } + + public class ServerB implements TestIt { + @Override + public void ping(Any request, StreamObserver responseObserver) { + responseObserver.onNext( + Any.pack(ByteMessage.newBuilder().setContents(ByteString.copyFromUtf8("Hello Server B")).build())); + responseObserver.onCompleted(); + } + } +} diff --git a/memberships/src/test/java/com/salesforce/apollo/archipelago/LocalServerTest.java b/memberships/src/test/java/com/salesforce/apollo/archipelago/LocalServerTest.java index c0264bc47..3287b2522 100644 --- a/memberships/src/test/java/com/salesforce/apollo/archipelago/LocalServerTest.java +++ b/memberships/src/test/java/com/salesforce/apollo/archipelago/LocalServerTest.java @@ -105,7 +105,7 @@ public Server(RoutableService router) { @Override public void ping(Any request, StreamObserver responseObserver) { - router.evaluate(responseObserver, t -> t.ping(request, responseObserver)); + router.evaluate(responseObserver, (t, token) -> t.ping(request, responseObserver)); } } diff --git a/memberships/src/test/java/com/salesforce/apollo/archipelago/server/ServerInterceptorTest.java b/memberships/src/test/java/com/salesforce/apollo/archipelago/server/ServerInterceptorTest.java index edc1399ba..db7d7f73e 100644 --- a/memberships/src/test/java/com/salesforce/apollo/archipelago/server/ServerInterceptorTest.java +++ b/memberships/src/test/java/com/salesforce/apollo/archipelago/server/ServerInterceptorTest.java @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test; import java.security.SecureRandom; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import static com.salesforce.apollo.archipelago.server.FernetServerInterceptor.AUTH_HEADER_PREFIX; @@ -19,21 +18,9 @@ public class ServerInterceptorTest { - private final FernetParser tokenParser = serialized -> { - if (serialized.equals("Invalid Token")) { - CompletableFuture res = new CompletableFuture<>(); - res.completeExceptionally(new RuntimeException("invalid token")); - return res; - } - try { - return CompletableFuture.completedFuture(Token.fromString(serialized)); - } catch (Throwable t) { - return CompletableFuture.failedFuture(t); - } - }; - private final FernetServerInterceptor target = new FernetServerInterceptor(tokenParser); - private final ServerCall serverCall = (ServerCall) mock(ServerCall.class); - private final ServerCallHandler next = (ServerCallHandler) mock( + private final FernetServerInterceptor target = new FernetServerInterceptor(); + private final ServerCall serverCall = (ServerCall) mock(ServerCall.class); + private final ServerCallHandler next = (ServerCallHandler) mock( ServerCallHandler.class); private Token token; @@ -58,7 +45,7 @@ public void callNextStageWithContextKeyOnValidHeader() { metadata.put(Constants.AuthorizationMetadataKey, AUTH_HEADER_PREFIX + token.serialise()); final AtomicReference actualToken = new AtomicReference<>(); when(next.startCall(any(), any())).thenAnswer(i -> { - actualToken.set(target.AccessTokenContextKey.get()); + actualToken.set(FernetServerInterceptor.AccessTokenContextKey.get()); return null; }); target.interceptCall(serverCall, metadata, next); diff --git a/memberships/src/test/java/com/salesforce/apollo/ring/RingIteratorTest.java b/memberships/src/test/java/com/salesforce/apollo/ring/RingIteratorTest.java index 9ce8d6bbb..812ca4fb0 100644 --- a/memberships/src/test/java/com/salesforce/apollo/ring/RingIteratorTest.java +++ b/memberships/src/test/java/com/salesforce/apollo/ring/RingIteratorTest.java @@ -77,27 +77,34 @@ public Any ping(Any request) { var cacheBuilder = ServerConnectionCache.newBuilder() .setFactory(to -> InProcessChannelBuilder.forName(name).build()); var router = new RouterImpl(serverMember1, serverBuilder, cacheBuilder, null); - RouterImpl.CommonCommunications commsA = router.create(serverMember1, context.getId(), - new ServiceImpl(local1, "A"), "A", - ServerImpl::new, - TestItClient::new, local1); + try { + RouterImpl.CommonCommunications commsA = router.create(serverMember1, + context.getId(), + new ServiceImpl(local1, "A"), + "A", ServerImpl::new, + TestItClient::new, local1); - RouterImpl.CommonCommunications commsB = router.create(serverMember2, context.getId(), - new ServiceImpl(local2, "B"), "B", - ServerImpl::new, - TestItClient::new, local2); + RouterImpl.CommonCommunications commsB = router.create(serverMember2, + context.getId(), + new ServiceImpl(local2, "B"), + "B", ServerImpl::new, + TestItClient::new, local2); - router.start(); - var frequency = Duration.ofMillis(1); - var scheduler = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()); - var sync = new RingIterator(frequency, context, serverMember1, scheduler, commsA); - var countdown = new CountDownLatch(3); - sync.iterate(context.getId(), (link, round) -> link.ping(Any.getDefaultInstance()), (round, result, link) -> { - countdown.countDown(); - return true; - }); - assertTrue(countdown.await(1, TimeUnit.SECONDS)); - assertFalse(pinged1.get()); - assertTrue(pinged2.get()); + router.start(); + var frequency = Duration.ofMillis(1); + var scheduler = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()); + var sync = new RingIterator(frequency, context, serverMember1, scheduler, commsA); + var countdown = new CountDownLatch(3); + sync.iterate(context.getId(), (link, round) -> link.ping(Any.getDefaultInstance()), + (round, result, link) -> { + countdown.countDown(); + return true; + }); + assertTrue(countdown.await(1, TimeUnit.SECONDS)); + assertFalse(pinged1.get()); + assertTrue(pinged2.get()); + } finally { + router.close(Duration.ofSeconds(2)); + } } } diff --git a/memberships/src/test/java/com/salesforce/apollo/ring/SliceIteratorTest.java b/memberships/src/test/java/com/salesforce/apollo/ring/SliceIteratorTest.java index 7df890e83..b203b47e7 100644 --- a/memberships/src/test/java/com/salesforce/apollo/ring/SliceIteratorTest.java +++ b/memberships/src/test/java/com/salesforce/apollo/ring/SliceIteratorTest.java @@ -1,6 +1,7 @@ package com.salesforce.apollo.ring; import com.google.protobuf.Any; +import com.salesforce.apollo.archipelago.Router; import com.salesforce.apollo.archipelago.RouterImpl; import com.salesforce.apollo.archipelago.ServerConnectionCache; import com.salesforce.apollo.membership.Context; @@ -76,27 +77,34 @@ public Any ping(Any request) { var serverBuilder = InProcessServerBuilder.forName(name); var cacheBuilder = ServerConnectionCache.newBuilder() .setFactory(to -> InProcessChannelBuilder.forName(name).build()); - var router = new RouterImpl(serverMember1, serverBuilder, cacheBuilder, null); - RouterImpl.CommonCommunications commsA = router.create(serverMember1, context.getId(), - new ServiceImpl(local1, "A"), "A", - ServerImpl::new, - TestItClient::new, local1); + Router router = new RouterImpl(serverMember1, serverBuilder, cacheBuilder, null); + try { + RouterImpl.CommonCommunications commsA = router.create(serverMember1, + context.getId(), + new ServiceImpl(local1, "A"), + "A", ServerImpl::new, + TestItClient::new, local1); - RouterImpl.CommonCommunications commsB = router.create(serverMember2, context.getId(), - new ServiceImpl(local2, "B"), "B", - ServerImpl::new, - TestItClient::new, local2); + RouterImpl.CommonCommunications commsB = router.create(serverMember2, + context.getId(), + new ServiceImpl(local2, "B"), + "B", ServerImpl::new, + TestItClient::new, local2); - router.start(); - var slice = new SliceIterator("Test Me", serverMember1, - Arrays.asList(serverMember1, serverMember2), commsA); - var countdown = new CountDownLatch(1); - slice.iterate((link, member) -> link.ping(Any.getDefaultInstance()), (result, comms, member) -> true, () -> { - countdown.countDown(); - }, Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()), Duration.ofMillis(1)); - boolean finished = countdown.await(3, TimeUnit.SECONDS); - assertTrue(finished, "completed: " + countdown.getCount()); - assertTrue(pinged1.get()); - assertTrue(pinged2.get()); + router.start(); + var slice = new SliceIterator("Test Me", serverMember1, + Arrays.asList(serverMember1, serverMember2), commsA); + var countdown = new CountDownLatch(1); + slice.iterate((link, member) -> link.ping(Any.getDefaultInstance()), (result, comms, member) -> true, + () -> { + countdown.countDown(); + }, Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()), Duration.ofMillis(1)); + boolean finished = countdown.await(3, TimeUnit.SECONDS); + assertTrue(finished, "completed: " + countdown.getCount()); + assertTrue(pinged1.get()); + assertTrue(pinged2.get()); + } finally { + router.close(Duration.ofSeconds(2)); + } } }