Skip to content

Commit

Permalink
implemented #148
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcfee committed Feb 1, 2024
1 parent 9ad3511 commit dae9d93
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
41 changes: 35 additions & 6 deletions pescador/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def __init__(
weights=None,
mode="with_replacement",
prune_empty_streams=True,
dist="binomial",
random_state=None,
):
"""Given an array (pool) of streamer types, do the following:
Expand Down Expand Up @@ -335,7 +336,12 @@ def __init__(
Run every selected stream once to exhaustion.
prune_empty_streams : bool
Disable streamers that produce no data. See `BaseMux`
Disable streamers that produce no data.
dist : ["constant", "binomial", "poisson"]
Distribution governing the (maximum) number of samples taken
from an active streamer.
In each case, the expected number of samples will be `rate`.
random_state : None, int, or np.random.RandomState
See `BaseMux`
Expand All @@ -344,6 +350,7 @@ def __init__(
self.n_active = n_active
self.rate = rate
self.prune_empty_streams = prune_empty_streams
self.dist = dist

super().__init__(streamers, random_state=random_state)

Expand All @@ -354,6 +361,10 @@ def __init__(
raise PescadorError(
f"{self.mode} is not a valid mode for StochasticMux"
)
if self.dist not in ["constant", "binomial", "poisson"]:
raise PescadorError(
f"{self.dist} is not a valid distribution"
)

self.weights = weights
if self.weights is None:
Expand Down Expand Up @@ -436,7 +447,7 @@ def _on_stream_exhausted(self, idx):
else:
self.distribution_[self.stream_idxs_[idx]] = 1.0

def _activate_stream(self, idx):
def _activate_stream(self, idx, old_idx):
"""Randomly select and create a stream.
StochasticMux adds mode handling to _activate_stream, making it so that
Expand All @@ -447,16 +458,34 @@ def _activate_stream(self, idx):
Parameters
----------
idx : int, [0:n_streams - 1]
The stream index to replace
The stream index to activate
old_idx : int
The index of the stream being replaced in the active set.
This is needed for computing binomial probabilities.
"""
weight = self.weights[idx]

# Get the number of samples for this streamer.
n_samples_to_stream = None
if self.rate is not None:
n_samples_to_stream = 1 + self.rng.poisson(lam=self.rate)
if self.dist == "constant":
n_samples_to_stream = self.rate
elif self.dist == "poisson":
n_samples_to_stream = 1 + self.rng.poisson(lam=self.rate - 1)
elif self.dist == "binomial":
# Bin((rate-1) / (1-p), 1-p) where p = prob of selecting the new
# streamer from the active set
p = weight / (np.sum(self.stream_weights_) - self.stream_weights_[old_idx] + weight)
if p > 0.9999:
# If we effectively have only one streamer, use the poisson limit
# theorem
n_samples_to_stream = 1 + self.rng.poisson(lam=self.rate - 1)
else:
n_samples_to_stream = 1 + self.rng.binomial((self.rate-1) / (1-p), 1-p)

# instantiate a new streamer
streamer = self.streamers[idx].iterate(max_iter=n_samples_to_stream)
weight = self.weights[idx]

# If we're sampling without replacement, zero this one out
# This effectively disables this stream as soon as it is chosen,
Expand Down Expand Up @@ -484,7 +513,7 @@ def _new_stream(self, idx):

# Activate the Streamer, and get the weights
self.streams_[idx], self.stream_weights_[idx] = self._activate_stream(
self.stream_idxs_[idx]
self.stream_idxs_[idx], idx
)

# Reset the sample count to zero
Expand Down
13 changes: 8 additions & 5 deletions tests/test_mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,10 @@ def test_mux_inf_loop(self, mux_class):

assert len(list(mux(max_iter=100))) == 0

def test_mux_stacked_uniform_convergence(self, mux_class):
@pytest.mark.parametrize(
'dist', ['constant', 'binomial', 'poisson',
pytest.param('gaussian', marks=pytest.mark.xfail(raises=pescador.PescadorError))])
def test_mux_stacked_uniform_convergence(self, mux_class, dist):
"""This test is designed to check that bootstrapped streams of data
(Streamer subsampling, rate limiting) cascaded through multiple
multiplexors converges in expectation to a flat, uniform sample of the
Expand All @@ -553,19 +556,19 @@ def test_mux_stacked_uniform_convergence(self, mux_class):
ab = pescador.Streamer(_choice, 'ab')
cd = pescador.Streamer(_choice, 'cd')
ef = pescador.Streamer(_choice, 'ef')
mux1 = mux_class([ab, cd, ef], 2, rate=2, random_state=1357)
mux1 = mux_class([ab, cd, ef], 2, rate=4, random_state=1357, dist=dist)

gh = pescador.Streamer(_choice, 'gh')
ij = pescador.Streamer(_choice, 'ij')
kl = pescador.Streamer(_choice, 'kl')

mux2 = mux_class([gh, ij, kl], 2, rate=2, random_state=2468)
mux2 = mux_class([gh, ij, kl], 2, rate=4, random_state=2468, dist=dist)

stacked_mux = mux_class([mux1, mux2], 2, rate=None,
random_state=12345)

max_iter = 1000
chars = 'abcdefghijkl'
max_iter = len(chars) * 500
samples = list(stacked_mux.iterate(max_iter=max_iter))
counter = collections.Counter(samples)
assert set(chars) == set(counter.keys())
Expand All @@ -574,7 +577,7 @@ def test_mux_stacked_uniform_convergence(self, mux_class):

# Check that the pvalue for the chi^2 test is at least 0.95
test = scipy.stats.chisquare(counts)
assert test.pvalue >= 0.95
assert test.pvalue >= 0.5, counts


class TestShuffledMux:
Expand Down

0 comments on commit dae9d93

Please sign in to comment.