Skip to content

Commit

Permalink
add api for additional server interceptors, api for fernet token eval…
Browse files Browse the repository at this point in the history
…uation, E2E fernet token testing
  • Loading branch information
Hellblazer committed Feb 18, 2024
1 parent 2802148 commit d448a95
Show file tree
Hide file tree
Showing 15 changed files with 396 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@
*/
public class ChRbcGossip {

private static final Logger log = LoggerFactory.getLogger(
private static final Logger log = LoggerFactory.getLogger(
ChRbcGossip.class);
private final CommonCommunications<Gossiper, GossiperService> comm;
private final Context<Member> context;
private final SigningMember member;
private final EtherealMetrics metrics;
private final Processor processor;
private final RingCommunications<Member, Gossiper> ring;
private final AtomicBoolean started = new AtomicBoolean();
private final AtomicBoolean started = new AtomicBoolean();
private final Terminal terminal = new Terminal();
private volatile ScheduledFuture<?> scheduled;

public ChRbcGossip(Context<Member> context, SigningMember member, Processor processor, Router communications,
Expand All @@ -64,7 +65,7 @@ public ChRbcGossip(Context<Member> context, SigningMember member, Processor proc
this.context = context;
this.member = member;
this.metrics = m;
comm = communications.create((Member) member, context.getId(), new Terminal(), getClass().getCanonicalName(),
comm = communications.create(member, context.getId(), terminal, getClass().getCanonicalName(),
r -> new GossiperServer(communications.getClientIdentityProvider(), metrics, r),
getCreate(metrics), Gossiper.getLocalLoopback(member));
ring = new RingCommunications<>(context, member, this.comm);
Expand All @@ -83,7 +84,7 @@ public void start(Duration duration) {
}
Duration initialDelay = duration.plusMillis(Entropy.nextBitsStreamLong(duration.toMillis()));
log.trace("Starting GossipService[{}] on: {}", context.getId(), member.getId());
comm.register(context.getId(), new Terminal());
comm.register(context.getId(), terminal);
var scheduler = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory());
scheduler.schedule(() -> Thread.ofVirtual().start(Utils.wrapped(() -> {
try {
Expand Down Expand Up @@ -121,11 +122,11 @@ private Update gossipRound(Gossiper link, int ring) {
try {
return link.gossip(processor.gossip(context.getId(), ring));
} catch (StatusRuntimeException e) {
log.debug("gossiping[{}] failed with: {} with {} ring: {} on {}", context.getId(), e.getMessage(),
log.debug("gossiping[{}] failed: {} with: {} with {} ring: {} on {}", context.getId(), e.getMessage(),
member.getId(), ring, link.getMember().getId(), member.getId(), e);
return null;
} catch (Throwable e) {
log.warn("gossiping[{}] failed from {} with {} ring: {} on {}", context.getId(), member.getId(), ring,
log.warn("gossiping[{}] failed: {} from {} with {} ring: {} on {}", context.getId(), member.getId(), ring,
link.getMember().getId(), ring, member.getId(), e);
return null;
}
Expand Down Expand Up @@ -168,7 +169,7 @@ private void handle(Optional<Update> result, RingCommunications.Destination<Memb
.setUpdate(processor.update(update))
.build());
} catch (StatusRuntimeException e) {
log.debug("gossiping[{}] failed with: {} with {} ring: {} on {}", context.getId(), e.getMessage(),
log.debug("gossiping[{}] failed: {} with: {} with {} ring: {} on {}", context.getId(), e.getMessage(),
member.getId(), ring, destination.member().getId(), member.getId(), e);
}
} finally {
Expand All @@ -191,7 +192,7 @@ private void oneRound(Duration duration, ScheduledExecutorService scheduler) {
return;
}
var timer = metrics == null ? null : metrics.gossipRoundDuration().time();
ring.execute((link, ring) -> gossipRound(link, ring),
ring.execute(this::gossipRound,
(result, destination) -> handle(result, destination, duration, scheduler, timer));
}

Expand All @@ -204,7 +205,8 @@ public Update gossip(Gossip request, Digest from) {
Member predecessor = context.ring(request.getRing()).predecessor(member);
if (predecessor == null || !from.equals(predecessor.getId())) {
log.debug("Invalid inbound gossip on {}:{} from: {} on ring: {} - not predecessor: {}", context.getId(),
member, from, request.getRing(), predecessor.getId());
member.getId(), from, request.getRing(),
predecessor == null ? "<null>" : predecessor.getId());
return Update.getDefaultInstance();
}
final var update = processor.gossip(request);
Expand All @@ -218,10 +220,11 @@ public void update(ContextUpdate request, Digest from) {
Member predecessor = context.ring(request.getRing()).predecessor(member);
if (predecessor == null || !from.equals(predecessor.getId())) {
log.debug("Invalid inbound update on {}:{} from: {} on ring: {} - not predecessor: {}", context.getId(),
member.getId(), from, request.getRing(), predecessor.getId());
member.getId(), from, request.getRing(),
predecessor == null ? "<null>" : predecessor.getId());
return;
}
log.trace("gossip update with {} on: {}", from, member);
log.trace("gossip update with {} on: {}", from, member.getId());
processor.updateFrom(request.getUpdate());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
Expand Down Expand Up @@ -72,7 +73,7 @@ public DomainSocketAddress getEndpoint() {

@Override
public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Limit> serverLimit,
LimitsRegistry limitsRegistry) {
LimitsRegistry limitsRegistry, List<ServerInterceptor> interceptors) {
var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get());
if (limitsRegistry != null) {
limitsBuilder.metricRegistry(limitsRegistry);
Expand All @@ -91,6 +92,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Li
"Enclave server concurrency limit reached"))
.build())
.intercept(serverInterceptor());
interceptors.forEach(i -> {
serverBuilder.intercept(i);
});
return new RouterImpl(from, serverBuilder, cacheBuilder.setFactory(t -> connectTo(t)),
new RoutingClientIdentity() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.slf4j.LoggerFactory;

import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Supplier;
Expand Down Expand Up @@ -67,7 +68,7 @@ public Member getFrom() {

@Override
public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Limit> serverLimit,
LimitsRegistry limitsRegistry) {
LimitsRegistry limitsRegistry, List<ServerInterceptor> interceptors) {
String name = String.format(NAME_TEMPLATE, prefix, qb64(from.getId()));
var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get());
if (limitsRegistry != null) {
Expand All @@ -82,6 +83,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Li
"Server concurrency limit reached"))
.build())
.intercept(serverInterceptor());
interceptors.forEach(i -> {
serverBuilder.intercept(i);
});
return new RouterImpl(from, serverBuilder, cacheBuilder.setFactory(t -> connectTo(t)), new ClientIdentity() {
@Override
public Digest getFrom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.security.Provider;
import java.security.Security;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -134,7 +135,7 @@ public static SslContext forServer(ClientAuth clientAuth, String alias, X509Cert

@Override
public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Limit> serverLimit,
LimitsRegistry limitsRegistry) {
LimitsRegistry limitsRegistry, List<ServerInterceptor> interceptors) {
var limitsBuilder = new GrpcServerLimiterBuilder().limit(serverLimit.get());
if (limitsRegistry != null) {
limitsBuilder.metricRegistry(limitsRegistry);
Expand All @@ -149,6 +150,9 @@ public RouterImpl router(ServerConnectionCache.Builder cacheBuilder, Supplier<Li
.withChildOption(ChannelOption.TCP_NODELAY, true)
.intercept(new TlsInterceptor(sslSessionContext))
.intercept(EnableCompressionInterceptor.SINGLETON);
interceptors.forEach(i -> {
serverBuilder.intercept(i);
});
ClientIdentity identity = new ClientIdentity() {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
*/
package com.salesforce.apollo.archipelago;

import com.macasaet.fernet.Token;
import com.salesforce.apollo.archipelago.server.FernetServerInterceptor;
import com.salesforce.apollo.cryptography.Digest;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
Expand All @@ -15,6 +17,7 @@

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import static com.salesforce.apollo.archipelago.Constants.SERVER_CONTEXT_KEY;
Expand All @@ -25,8 +28,9 @@
* @author hal.hildebrand
*/
public class RoutableService<Service> {
private static final Logger log = LoggerFactory.getLogger(RoutableService.class);
private final Map<Digest, Service> services = new ConcurrentHashMap<>();
private static final Logger log = LoggerFactory.getLogger(RoutableService.class);

private final Map<Digest, Service> services = new ConcurrentHashMap<>();

public void bind(Digest context, Service service) {
services.put(context, service);
Expand All @@ -53,6 +57,27 @@ public void evaluate(StreamObserver<?> responseObserver, Consumer<Service> c) {
}
}

public void evaluate(StreamObserver<?> responseObserver, BiConsumer<Service, Token> c) {
var context = SERVER_CONTEXT_KEY.get();
if (context == null) {
responseObserver.onError(new StatusRuntimeException(Status.NOT_FOUND));
log.error("Null context");
} else {
Service service = services.get(context);
if (service == null) {
log.trace("No service for context {}", context);
responseObserver.onError(new StatusRuntimeException(Status.NOT_FOUND));
} else {
try {
c.accept(service, FernetServerInterceptor.AccessTokenContextKey.get());
} catch (Throwable t) {
log.error("Uncaught exception in service evaluation for context: {}", context, t);
responseObserver.onError(t);
}
}
}
}

public void unbind(Digest context) {
services.remove(context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

import com.netflix.concurrency.limits.Limit;
import com.salesforce.apollo.protocols.LimitsRegistry;
import io.grpc.ServerInterceptor;

import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;

/**
Expand All @@ -17,14 +20,19 @@
public interface RouterSupplier {

default Router router() {
return router(ServerConnectionCache.newBuilder(), () -> RouterImpl.defaultServerLimit(), null);
return router(ServerConnectionCache.newBuilder(), RouterImpl::defaultServerLimit, null);
}

default Router router(ServerConnectionCache.Builder cacheBuilder) {
return router(cacheBuilder, () -> RouterImpl.defaultServerLimit(), null);
return router(cacheBuilder, RouterImpl::defaultServerLimit, null);
}

default Router router(ServerConnectionCache.Builder cacheBuilder, Supplier<Limit> serverLimit,
LimitsRegistry limitsRegistry) {
return router(cacheBuilder, serverLimit, limitsRegistry, Collections.emptyList());
}

Router router(ServerConnectionCache.Builder cacheBuilder, Supplier<Limit> serverLimit,
LimitsRegistry limitsRegistry);
LimitsRegistry limitsRegistry, List<ServerInterceptor> interceptors);

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.macasaet.fernet.Token;
import com.salesforce.apollo.archipelago.Constants;
import io.grpc.CallCredentials;
import io.grpc.Metadata;
import io.grpc.Status;
import org.slf4j.Logger;
Expand All @@ -11,7 +12,7 @@

import static com.salesforce.apollo.archipelago.server.FernetServerInterceptor.AUTH_HEADER_PREFIX;

public abstract class FernetCallCredentials extends io.grpc.CallCredentials {
public abstract class FernetCallCredentials extends CallCredentials {
private static final Logger LOGGER = LoggerFactory.getLogger(FernetCallCredentials.class);

public static FernetCallCredentials synchronous(SynchronousTokenProvider tokenProvider) {
Expand Down
Loading

0 comments on commit d448a95

Please sign in to comment.