From 2ca8e1c0f0ddc9c6e96ac920e2748b555f616fc3 Mon Sep 17 00:00:00 2001 From: Hellblazer Date: Sat, 22 Jun 2024 15:17:19 -0700 Subject: [PATCH] fixins. Reimplement ReservoirSampler. 5x5 --- .../salesforce/apollo/fireflies/Binding.java | 200 ++++++++++++------ .../com/salesforce/apollo/fireflies/View.java | 25 ++- .../apollo/fireflies/ViewManagement.java | 4 +- .../fireflies/comm/entrance/Entrance.java | 5 +- .../comm/entrance/EntranceClient.java | 31 ++- .../apollo/fireflies/ChurnTest.java | 2 +- fireflies/src/test/resources/logback-test.xml | 27 ++- .../apollo/context/DelegatedContext.java | 3 +- .../apollo/context/DynamicContextImpl.java | 6 +- .../apollo/context/StaticContext.java | 9 +- .../apollo/membership/ReservoirSampler.java | 62 +++--- 11 files changed, 238 insertions(+), 136 deletions(-) diff --git a/fireflies/src/main/java/com/salesforce/apollo/fireflies/Binding.java b/fireflies/src/main/java/com/salesforce/apollo/fireflies/Binding.java index 00965a98f..2bbe4fe53 100644 --- a/fireflies/src/main/java/com/salesforce/apollo/fireflies/Binding.java +++ b/fireflies/src/main/java/com/salesforce/apollo/fireflies/Binding.java @@ -9,6 +9,7 @@ import com.codahale.metrics.Timer; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; +import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import com.salesforce.apollo.archipelago.RouterImpl.CommonCommunications; import com.salesforce.apollo.context.Context; @@ -35,9 +36,7 @@ import java.time.Duration; import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -75,6 +74,12 @@ public Binding(View view, List seeds, Duration duration, DynamicContext complete, AtomicInteger remaining) { + if (remaining.decrementAndGet() <= 0) { + complete.complete(false); + } + } + void seeding() { if (seeds.isEmpty()) {// This node is the bootstrap seed bootstrap(); @@ -135,33 +140,40 @@ private boolean complete(CompletableFuture redirect, Optional gateway, - Optional futureSailor, HashMultiset trusts, - Set initialSeedSet, Digest v, int majority) { - if (futureSailor.isEmpty()) { - log.warn("No gateway returned from: {} on: {}", member.getId(), node.getId()); - return true; + private void complete(Member member, CompletableFuture gateway, HashMultiset trusts, + Set initialSeedSet, Digest v, int majority, CompletableFuture complete, + AtomicInteger remaining, ListenableFuture futureSailor) { + if (complete.isDone()) { + return; } - if (gateway.isDone()) { - log.warn("gateway is complete, ignoring from: {} on: {}", member.getId(), node.getId()); - return false; + Gateway g = null; + try { + g = futureSailor.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + log.warn("Error retrieving Gateway from: {} on: {}", member.getId(), node.getId(), e.getCause()); + dec(complete, remaining); + return; } - Gateway g = futureSailor.get(); - if (g.equals(Gateway.getDefaultInstance())) { - return true; + log.warn("Empty gateway returned from: {} on: {}", member.getId(), node.getId()); + dec(complete, remaining); + return; } if (g.getInitialSeedSetCount() == 0) { log.warn("No seeds in gateway returned from: {} on: {}", member.getId(), node.getId()); - return true; + dec(complete, remaining); + return; } if (g.getTrust().equals(BootstrapTrust.getDefaultInstance()) || g.getTrust() .getDiadem() .equals(HexBloome.getDefaultInstance())) { log.trace("Empty bootstrap trust in join returned from: {} on: {}", member.getId(), node.getId()); - return true; + dec(complete, remaining); + return; } trusts.add(new Bootstrapping(g.getTrust())); initialSeedSet.addAll(g.getInitialSeedSetList()); @@ -175,7 +187,13 @@ private boolean completeGateway(Participant member, CompletableFuture gat .findFirst() .orElse(null); if (trust != null) { - validate(trust, gateway, initialSeedSet); + var bound = new Bound(trust.crown, + trust.successors.stream().map(sn -> new NoteWrapper(sn, digestAlgo)).toList(), + initialSeedSet.stream().map(sn -> new NoteWrapper(sn, digestAlgo)).toList()); + if (gateway.complete(bound)) { + log.info("Gateway acquired: {} context: {} on: {}", trust.diadem, this.context.getId(), node.getId()); + } + complete.complete(true); } else { log.debug("Gateway received, trust count: {} majority: {} from: {} trusts: {} view: {} context: {} on: {}", trusts.size(), majority, member.getId(), v, trusts.entrySet() @@ -185,8 +203,8 @@ private boolean completeGateway(Participant member, CompletableFuture gat e -> "%s x %s".formatted(e.getElement().diadem, e.getCount())) .toList(), this.context.getId(), node.getId()); + dec(complete, remaining); } - return true; } private void gatewaySRE(Digest v, Entrance link, StatusRuntimeException sre, AtomicInteger abandon) { @@ -216,6 +234,31 @@ private void gatewaySRE(Digest v, Entrance link, StatusRuntimeException sre, Ato } } + private boolean join(Member member, CompletableFuture gateway, Optional> fs, + HashMultiset trusts, Set initialSeedSet, Digest v, int majority, + CompletableFuture complete, AtomicInteger remaining) { + if (complete.isDone()) { + log.trace("join round already completed for: {} on: {}", member.getId(), node.getId()); + return false; + } + if (fs.isEmpty()) { + log.warn("No gateway returned from: {} on: {}", member.getId(), node.getId()); + dec(complete, remaining); + return true; + } + if (gateway.isDone()) { + log.warn("gateway is complete, ignoring from: {} on: {}", member.getId(), node.getId()); + complete.complete(true); + return false; + } + var futureSailor = fs.get(); + futureSailor.addListener( + () -> complete(member, gateway, trusts, initialSeedSet, v, majority, complete, remaining, futureSailor), + r -> Thread.ofVirtual().start(r)); + + return true; + } + private Join join(Digest v) { return Join.newBuilder().setView(v.toDigeste()).setNote(node.getNote().getWrapped()).build(); } @@ -276,63 +319,88 @@ private void join(Redirect redirect, Digest v, Duration duration) { final var redirecting = new SliceIterator<>("Gateways", node, sample, approaches); var majority = redirect.getBootstrap() ? 1 : Context.minimalQuorum(redirect.getRings(), this.context.getBias()); final var join = join(v); - final var abandon = new AtomicInteger(); var scheduler = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory()); regate.set(() -> { + log.info("Round: {} formally joining view: {} on: {}", retries.get(), v, node.getId()); if (!view.started.get()) { return; } - redirecting.iterate((link) -> { - if (gateway.isDone() || !view.started.get()) { - return null; - } - log.debug("Joining: {} contacting: {} on: {}", v, link.getMember().getId(), node.getId()); - try { - var g = link.join(join, params.seedingTimeout()); - if (g == null || g.equals(Gateway.getDefaultInstance())) { - log.debug("Gateway view: {} empty from: {} on: {}", v, link.getMember().getId(), node.getId()); - abandon.incrementAndGet(); - return null; - } - return g; - } catch (StatusRuntimeException sre) { - gatewaySRE(v, link, sre, abandon); - return null; - } catch (Throwable t) { - log.info("Gateway view: {} error: {} from: {} on: {}", v, t, link.getMember().getId(), - node.getId()); - abandon.incrementAndGet(); - return null; + var complete = new CompletableFuture(); + final var abandon = new AtomicInteger(); + complete.whenComplete((success, error) -> { + if (error != null) { + log.info("Failed Join on: {}", node.getId(), error); + return; } - }, (futureSailor, _, _, member) -> completeGateway((Participant) member, gateway, futureSailor, trusts, - initialSeedSet, v, majority), () -> { - if (!view.started.get() || gateway.isDone()) { + if (success) { return; } - if (abandon.get() >= majority) { - log.debug("Abandoning Gateway view: {} abandons: {} majority: {} reseeding on: {}", v, - abandon.get(), majority, node.getId()); - seeding(); + log.info("Join unsuccessful, abandoned: {} trusts: {} on: {}", abandon.get(), trusts.entrySet() + .stream() + .sorted() + .map( + e -> "%s x %s".formatted( + e.getElement().diadem, + e.getCount())) + .toList(), + node.getId()); + abandon.set(0); + if (retries.get() < params.joinRetries()) { + log.info("Failed to join view: {} retry: {} out of: {} on: {}", v, retries.incrementAndGet(), + params.joinRetries(), node.getId()); + trusts.clear(); + initialSeedSet.clear(); + scheduler.schedule(() -> Thread.ofVirtual().start(Utils.wrapped(regate.get(), log)), + Entropy.nextBitsStreamLong(params.retryDelay().toNanos()), TimeUnit.NANOSECONDS); } else { - abandon.set(0); - if (retries.get() < params.joinRetries()) { - log.info("Failed to join view: {} retry: {} out of: {} on: {}", v, retries.incrementAndGet(), - params.joinRetries(), node.getId()); - trusts.clear(); - initialSeedSet.clear(); - scheduler.schedule(() -> Thread.ofVirtual().start(Utils.wrapped(regate.get(), log)), - Entropy.nextBitsStreamLong(params.retryDelay().toNanos()), - TimeUnit.NANOSECONDS); - } else { - log.error("Failed to join view: {} cannot obtain majority Gateway on: {}", view, node.getId()); - view.stop(); - } + log.error("Failed to join view: {} cannot obtain majority Gateway on: {}", view, node.getId()); + view.stop(); } - }, params.retryDelay()); + }); + var remaining = new AtomicInteger(sample.size()); + redirecting.iterate((link) -> join(v, link, gateway, join, abandon, complete), + (futureSailor, _, _, member) -> join(member, gateway, futureSailor, trusts, + initialSeedSet, v, majority, complete, remaining), + () -> { + if (!view.started.get() || gateway.isDone()) { + return; + } + if (abandon.get() >= majority) { + log.debug( + "Abandoning Gateway view: {} abandons: {} majority: {} reseeding on: {}", v, + abandon.get(), majority, node.getId()); + complete.completeExceptionally(new TimeoutException("Failed Join")); + seeding(); + } + }, params.retryDelay()); }); regate.get().run(); } + private ListenableFuture join(Digest v, Entrance link, CompletableFuture gateway, Join join, + AtomicInteger abandon, CompletableFuture complete) { + if (!view.started.get() || complete.isDone() || gateway.isDone()) { + return null; + } + log.debug("Joining: {} contacting: {} on: {}", v, link.getMember().getId(), node.getId()); + try { + var g = link.join(join, params.seedingTimeout()); + if (g == null || g.equals(Gateway.getDefaultInstance())) { + log.debug("Gateway view: {} empty from: {} on: {}", v, link.getMember().getId(), node.getId()); + abandon.incrementAndGet(); + return null; + } + return g; + } catch (StatusRuntimeException sre) { + gatewaySRE(v, link, sre, abandon); + return null; + } catch (Throwable t) { + log.info("Gateway view: {} error: {} from: {} on: {}", v, t, link.getMember().getId(), node.getId()); + abandon.incrementAndGet(); + return null; + } + } + private Registration registration() { return Registration.newBuilder() .setView(view.currentView().toDigeste()) @@ -354,14 +422,6 @@ private NoteWrapper seedFor(Seed seed) { return new NoteWrapper(seedNote, digestAlgo); } - private void validate(Bootstrapping trust, CompletableFuture gateway, Set initialSeedSet) { - if (gateway.complete( - new Bound(trust.crown, trust.successors.stream().map(sn -> new NoteWrapper(sn, digestAlgo)).toList(), - initialSeedSet.stream().map(sn -> new NoteWrapper(sn, digestAlgo)).toList()))) { - log.info("Gateway acquired: {} context: {} on: {}", trust.diadem, this.context.getId(), node.getId()); - } - } - private record Bootstrapping(Digest diadem, HexBloom crown, Set successors) { public Bootstrapping(BootstrapTrust trust) { this(HexBloom.from(trust.getDiadem()), new HashSet<>(trust.getSuccessorsList())); diff --git a/fireflies/src/main/java/com/salesforce/apollo/fireflies/View.java b/fireflies/src/main/java/com/salesforce/apollo/fireflies/View.java index 47a7ef3b0..7f851207f 100644 --- a/fireflies/src/main/java/com/salesforce/apollo/fireflies/View.java +++ b/fireflies/src/main/java/com/salesforce/apollo/fireflies/View.java @@ -617,7 +617,6 @@ void viewChange(Runnable r) { * * @param ring - the index of the gossip ring the gossip is originating from in this view * @param link - the outbound communications to the paired member - * @param ring */ protected Gossip gossip(Fireflies link, int ring) { tick(); @@ -656,8 +655,8 @@ protected Gossip gossip(Fireflies link, int ring) { node.getId()); break; case UNAVAILABLE: - log.trace("Communication cancelled for gossip view: {} from: {} on: {}", currentView(), p.getId(), - node.getId(), sre); + log.trace("Communication unavailable for gossip view: {} from: {} on: {}", currentView(), p.getId(), + node.getId()); accuse(p, ring, sre); break; default: @@ -1017,8 +1016,8 @@ private void gc(Participant member) { * @return the bloom filter containing the digests of known accusations */ private BloomFilter getAccusationsBff(long seed, double p) { - BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, Math.max(params.minimumBiffCardinality(), - context.cardinality() * 2), p); + var n = Math.max(params.minimumBiffCardinality(), context.cardinality()); + BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, n, 1.0 / (double) n); context.allMembers() .flatMap(Participant::getAccusations) .filter(Objects::nonNull) @@ -1033,9 +1032,9 @@ private BloomFilter getAccusationsBff(long seed, double p) { * @return the bloom filter containing the digests of known notes */ private BloomFilter getNotesBff(long seed, double p) { - BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, Math.max(params.minimumBiffCardinality(), - context.cardinality() * 2), p); - context.allMembers().map(m -> m.getNote()).filter(e -> e != null).forEach(n -> bff.add(n.getHash())); + var n = Math.max(params.minimumBiffCardinality(), context.cardinality()); + BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, n, 1.0 / (double) n); + context.allMembers().map(m -> m.getNote()).filter(e -> e != null).forEach(note -> bff.add(note.getHash())); return bff; } @@ -1045,8 +1044,8 @@ private BloomFilter getNotesBff(long seed, double p) { * @return the bloom filter containing the digests of known observations */ private BloomFilter getObservationsBff(long seed, double p) { - BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, Math.max(params.minimumBiffCardinality(), - context.cardinality() * 2), p); + var n = Math.max(params.minimumBiffCardinality(), observations.size()); + BloomFilter bff = new BloomFilter.DigestBloomFilter(seed, n, 1.0 / (double) n); observations.keySet().stream().collect(Utils.toShuffledList()).forEach(bff::add); return bff; } @@ -1226,8 +1225,9 @@ private NoteGossip.Builder processNotes(BloomFilter bff) { .filter(m -> current.equals(m.getNote().currentView())) .filter(m -> !shunned.contains(m.getId())) .filter(m -> !bff.contains(m.getNote().getHash())) - .collect(new ReservoirSampler<>(params.maximumTxfr(), Entropy.bitsStream())) + .collect(new ReservoirSampler<>(params.maximumTxfr())) .stream() + .filter(sn -> sn != null) .map(Participant::getNote) .forEach(n -> builder.addUpdates(n.getWrapped())); return builder; @@ -1351,9 +1351,8 @@ private Update updatesForDigests(Gossip gossip) { .filter(m -> m.getNote() != null) .filter(m -> current.equals(m.getNote().currentView())) .filter(m -> !notesBff.contains(m.getNote().getHash())) - .collect(new ReservoirSampler<>(params.maximumTxfr(), Entropy.bitsStream())) - .stream() .map(m -> m.getNote().getWrapped()) + .limit(params.maximumTxfr()) .forEach(builder::addNotes); } diff --git a/fireflies/src/main/java/com/salesforce/apollo/fireflies/ViewManagement.java b/fireflies/src/main/java/com/salesforce/apollo/fireflies/ViewManagement.java index 3f2e50579..816bd22e9 100644 --- a/fireflies/src/main/java/com/salesforce/apollo/fireflies/ViewManagement.java +++ b/fireflies/src/main/java/com/salesforce/apollo/fireflies/ViewManagement.java @@ -19,7 +19,6 @@ import com.salesforce.apollo.fireflies.proto.*; import com.salesforce.apollo.fireflies.proto.Update.Builder; import com.salesforce.apollo.membership.Member; -import com.salesforce.apollo.membership.ReservoirSampler; import com.salesforce.apollo.ring.SliceIterator; import com.salesforce.apollo.stereotomy.identifier.SelfAddressingIdentifier; import com.salesforce.apollo.utils.Entropy; @@ -207,6 +206,7 @@ void install(Ballot ballot) { final var seedSet = context.sample(params.maximumTxfr(), Entropy.bitsStream(), node.getId()) .stream() + .filter(sn -> sn != null) .map(p -> p.note.getWrapped()) .collect(Collectors.toSet()); @@ -413,7 +413,6 @@ void joinUpdatesFor(BloomFilter joinBff, Builder builder) { joins.entrySet() .stream() .filter(e -> !joinBff.contains(e.getKey())) - .collect(new ReservoirSampler<>(params.maximumTxfr(), Entropy.bitsStream())) .forEach(e -> builder.addJoins(e.getValue().getWrapped())); } @@ -496,7 +495,6 @@ JoinGossip.Builder processJoins(BloomFilter bff) { .stream() .filter(m -> !bff.contains(m.getKey())) .map(Map.Entry::getValue) - .collect(new ReservoirSampler<>(params.maximumTxfr(), Entropy.bitsStream())) .forEach(n -> builder.addUpdates(n.getWrapped())); return builder; } diff --git a/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/Entrance.java b/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/Entrance.java index 98b63d053..6459a0491 100644 --- a/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/Entrance.java +++ b/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/Entrance.java @@ -6,6 +6,7 @@ */ package com.salesforce.apollo.fireflies.comm.entrance; +import com.google.common.util.concurrent.ListenableFuture; import com.salesforce.apollo.archipelago.Link; import com.salesforce.apollo.fireflies.View.Node; import com.salesforce.apollo.fireflies.proto.Gateway; @@ -35,7 +36,7 @@ public Member getMember() { } @Override - public Gateway join(Join join, Duration timeout) { + public ListenableFuture join(Join join, Duration timeout) { return null; } @@ -46,7 +47,7 @@ public Redirect seed(Registration registration) { }; } - Gateway join(Join join, Duration timeout); + ListenableFuture join(Join join, Duration timeout); Redirect seed(Registration registration); } diff --git a/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/EntranceClient.java b/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/EntranceClient.java index 1a1256437..c3a5c1d85 100644 --- a/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/EntranceClient.java +++ b/fireflies/src/main/java/com/salesforce/apollo/fireflies/comm/entrance/EntranceClient.java @@ -6,6 +6,7 @@ */ package com.salesforce.apollo.fireflies.comm.entrance; +import com.google.common.util.concurrent.ListenableFuture; import com.salesforce.apollo.archipelago.ManagedServerChannel; import com.salesforce.apollo.archipelago.ServerConnectionCache.CreateClientCommunications; import com.salesforce.apollo.fireflies.FireflyMetrics; @@ -13,6 +14,7 @@ import com.salesforce.apollo.membership.Member; import java.time.Duration; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; /** @@ -23,10 +25,12 @@ public class EntranceClient implements Entrance { private final ManagedServerChannel channel; private final EntranceGrpc.EntranceBlockingStub client; private final FireflyMetrics metrics; + private final EntranceGrpc.EntranceFutureStub ayncClient; public EntranceClient(ManagedServerChannel channel, FireflyMetrics metrics) { this.channel = channel; this.client = channel.wrap(EntranceGrpc.newBlockingStub(channel)); + ayncClient = channel.wrap(EntranceGrpc.newFutureStub(channel)); this.metrics = metrics; } @@ -46,23 +50,34 @@ public Member getMember() { } @Override - public Gateway join(Join join, Duration timeout) { + public ListenableFuture join(Join join, Duration timeout) { if (metrics != null) { var serializedSize = join.getSerializedSize(); metrics.outboundBandwidth().mark(serializedSize); metrics.outboundJoin().update(serializedSize); } - Gateway result = client.withDeadlineAfter(timeout.toNanos(), TimeUnit.NANOSECONDS).join(join); - if (metrics != null) { + ListenableFuture result = ayncClient.withDeadlineAfter(timeout.toNanos(), TimeUnit.NANOSECONDS) + .join(join); + result.addListener(() -> { + Gateway g = null; try { - var serializedSize = result.getSerializedSize(); - metrics.inboundBandwidth().mark(serializedSize); - metrics.inboundGateway().update(serializedSize); - } catch (Throwable e) { + g = result.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { // nothing } - } + if (metrics != null) { + try { + var serializedSize = g.getSerializedSize(); + metrics.inboundBandwidth().mark(serializedSize); + metrics.inboundGateway().update(serializedSize); + } catch (Throwable e) { + // nothing + } + } + }, Runnable::run); return result; } diff --git a/fireflies/src/test/java/com/salesforce/apollo/fireflies/ChurnTest.java b/fireflies/src/test/java/com/salesforce/apollo/fireflies/ChurnTest.java index d39f26c77..b689e159d 100644 --- a/fireflies/src/test/java/com/salesforce/apollo/fireflies/ChurnTest.java +++ b/fireflies/src/test/java/com/salesforce/apollo/fireflies/ChurnTest.java @@ -271,7 +271,7 @@ public void churn() throws Exception { private void initialize() { executor = UnsafeExecutors.newVirtualThreadPerTaskExecutor(); executor2 = UnsafeExecutors.newVirtualThreadPerTaskExecutor(); - var parameters = Parameters.newBuilder().setFpr(0.0000125).setMaximumTxfr(20).build(); + var parameters = Parameters.newBuilder().setMaximumTxfr(10).build(); registry = new MetricRegistry(); node0Registry = new MetricRegistry(); diff --git a/fireflies/src/test/resources/logback-test.xml b/fireflies/src/test/resources/logback-test.xml index b4afb147d..f86c9c940 100644 --- a/fireflies/src/test/resources/logback-test.xml +++ b/fireflies/src/test/resources/logback-test.xml @@ -9,7 +9,12 @@ %d{mm:ss.SSS} %logger{0} - %msg%n - + + + + + + @@ -17,19 +22,31 @@ - + + + + + + + + + + + + + - + - + - + diff --git a/memberships/src/main/java/com/salesforce/apollo/context/DelegatedContext.java b/memberships/src/main/java/com/salesforce/apollo/context/DelegatedContext.java index ecb159025..e319a2977 100644 --- a/memberships/src/main/java/com/salesforce/apollo/context/DelegatedContext.java +++ b/memberships/src/main/java/com/salesforce/apollo/context/DelegatedContext.java @@ -2,6 +2,7 @@ import com.salesforce.apollo.cryptography.Digest; import com.salesforce.apollo.membership.Member; +import com.salesforce.apollo.utils.Entropy; import org.apache.commons.math3.random.BitsStreamGenerator; import java.util.List; @@ -231,7 +232,7 @@ public int rank(int ring, T item, T dest) { @Override public List sample(int range, BitsStreamGenerator entropy, Digest exc) { - return delegate.sample(range, entropy, exc); + return delegate.sample(range, Entropy.bitsStream(), exc); } @Override diff --git a/memberships/src/main/java/com/salesforce/apollo/context/DynamicContextImpl.java b/memberships/src/main/java/com/salesforce/apollo/context/DynamicContextImpl.java index 18a22e1d1..58fc888b7 100644 --- a/memberships/src/main/java/com/salesforce/apollo/context/DynamicContextImpl.java +++ b/memberships/src/main/java/com/salesforce/apollo/context/DynamicContextImpl.java @@ -585,9 +585,7 @@ public Stream> rings() { */ @Override public List sample(int range, BitsStreamGenerator entropy, Predicate excluded) { - return rings.get(entropy.nextInt(rings.size())) - .stream() - .collect(new ReservoirSampler<>(excluded, range, entropy)); + return rings.get(entropy.nextInt(rings.size())).stream().collect(new ReservoirSampler<>(range, excluded)); } /** @@ -603,7 +601,7 @@ public List sample(int range, BitsStreamGenerator entropy, Dige Member excluded = exc == null ? null : getMember(exc); return rings.get(entropy.nextInt(rings.size())) .stream() - .collect(new ReservoirSampler(excluded, range, entropy)); + .collect(new ReservoirSampler(range, t -> t.equals(excluded))); } @Override diff --git a/memberships/src/main/java/com/salesforce/apollo/context/StaticContext.java b/memberships/src/main/java/com/salesforce/apollo/context/StaticContext.java index 42d6a1668..d91662176 100644 --- a/memberships/src/main/java/com/salesforce/apollo/context/StaticContext.java +++ b/memberships/src/main/java/com/salesforce/apollo/context/StaticContext.java @@ -313,21 +313,20 @@ public int rank(int ring, T item, T dest) { */ @Override public List sample(int range, BitsStreamGenerator entropy, Predicate excluded) { - return ring(entropy.nextInt(rings.length)).stream().collect(new ReservoirSampler<>(excluded, range, entropy)); + return ring(entropy.nextInt(rings.length)).stream().collect(new ReservoirSampler<>(range, excluded)); } /** * Answer a random sample of at least range size from the active members of the context * - * @param range - the desired range - * @param entropy - source o randomness - * @param exc - the member to exclude from sample + * @param range - the desired range + * @param exc - the member to exclude from sample * @return a random sample set of the view's live members. May be limited by the number of active members. */ @Override public List sample(int range, BitsStreamGenerator entropy, Digest exc) { Member excluded = exc == null ? null : getMember(exc); - return ring(entropy.nextInt(rings.length)).stream().collect(new ReservoirSampler(excluded, range, entropy)); + return ring(entropy.nextInt(rings.length)).stream().collect(new ReservoirSampler<>(range, (T) excluded)); } @Override diff --git a/memberships/src/main/java/com/salesforce/apollo/membership/ReservoirSampler.java b/memberships/src/main/java/com/salesforce/apollo/membership/ReservoirSampler.java index d62dde510..d5f770cad 100644 --- a/memberships/src/main/java/com/salesforce/apollo/membership/ReservoirSampler.java +++ b/memberships/src/main/java/com/salesforce/apollo/membership/ReservoirSampler.java @@ -6,36 +6,40 @@ */ package com.salesforce.apollo.membership; -import org.apache.commons.math3.random.BitsStreamGenerator; - import java.util.ArrayList; import java.util.EnumSet; import java.util.List; import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.*; import java.util.stream.Collector; -public class ReservoirSampler implements Collector, List> { +import static java.lang.Math.exp; +import static java.lang.Math.log; - private final Predicate exclude; - private final BitsStreamGenerator rand; - private final int sz; - private AtomicInteger c = new AtomicInteger(); +/** + * @author hal.hildebrand + **/ +public class ReservoirSampler implements Collector, List> { + private final int capacity; + private final Predicate ignore; + private volatile double w; + private volatile long counter; + private volatile long next; - public ReservoirSampler(int size, BitsStreamGenerator entropy) { - this(null, size, entropy); + public ReservoirSampler(int capacity, T ignore) { + this(capacity, t -> t.equals(ignore)); } - public ReservoirSampler(Object excluded, int size, BitsStreamGenerator entropy) { - this(t -> excluded == null ? false : excluded.equals(t), size, entropy); + public ReservoirSampler(int capacity, Predicate ignore) { + this.capacity = capacity; + w = exp(log(ThreadLocalRandom.current().nextDouble()) / capacity); + skip(); + this.ignore = ignore == null ? t -> false : ignore; } - public ReservoirSampler(Predicate excluded, int size, BitsStreamGenerator entropy) { - assert size >= 0; - this.exclude = excluded; - this.sz = size; - rand = entropy; + public ReservoirSampler(int capacity) { + this(capacity, (Predicate) null); } @Override @@ -63,21 +67,31 @@ public Function, List> finisher() { @Override public Supplier> supplier() { - return ArrayList::new; + var reservoir = new ArrayList(capacity); + for (int i = 0; i < capacity; i++) { + reservoir.add(null); + } + return () -> reservoir; } private void addIt(final List in, T s) { - if (exclude != null && exclude.test(s)) { + if (ignore.test(s)) { return; } - if (in.size() < sz) { - in.add(s); + + if (counter < in.size()) { + in.add((int) counter, s); } else { - int replaceInIndex = (int) (rand.nextLong(sz + (c.getAndIncrement()) + 1)); - if (replaceInIndex < sz) { - in.set(replaceInIndex, s); + if (counter == next) { + in.add(ThreadLocalRandom.current().nextInt(in.size()), s); + skip(); } } + ++counter; } + private void skip() { + next += (long) (log(ThreadLocalRandom.current().nextDouble()) / log(1 - w)) + 1; + w *= exp(log(ThreadLocalRandom.current().nextDouble()) / capacity); + } }