Skip to content

Commit

Permalink
Add lower-level simulate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Oct 4, 2023
1 parent 5b6f78b commit 594b937
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 12 deletions.
89 changes: 78 additions & 11 deletions sbijax/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,73 @@ def simulate_data_and_possibly_append(
"""

observable = jnp.atleast_2d(observable)
sample_key, rng_key = jr.split(rng_key)
new_data, diagnostics = self.simulate_data(
rng_key,
params=params,
observable=observable,
n_simulations=n_simulations,
**kwargs,
)
if data is None:
d_new = new_data
else:
d_new = self.stack_data(data, new_data)
return d_new, diagnostics

def simulate_data(
self,
rng_key,
*,
params=None,
observable=None,
n_simulations=1000,
**kwargs,
):
"""
Simulate data from the posterior or prior and append it to an
existing data set (if provided)
Parameters
----------
rng_key: jax.PRNGKey
a random key
params: Optional[pytree]
a dictionary of neural network parameters. If None, will draw from
prior. If parameters given, will draw from amortized posterior
using 'observable;
observable: Optional[jnp.ndarray]
an observation. Needs to be gfiven if posterior draws are desired
n_simulations: int
number of newly simulated data
kwargs: keyword arguments
dictionary of ey value pairs passed to `sample_posterior`
Returns
-------
NamedTuple:
returns a NamedTuple of two axis, y and theta
"""

sample_key, rng_key = jr.split(rng_key)
if params is None or len(params) == 0:
diagnostics = None
self.n_total_simulations += n_simulations
new_thetas = self.prior_sampler_fn(
seed=sample_key,
sample_shape=(n_simulations,),
)
else:
if observable is None:
raise ValueError(
"need to have access to 'observable' "
"when sampling from posterior"
)
if "n_samples" not in kwargs:
kwargs["n_samples"] = n_simulations
new_thetas, diagnostics = self.sample_posterior(
rng_key=sample_key,
params=params,
observable=observable,
observable=jnp.atleast_2d(observable),
**kwargs,
)
perm_key, rng_key = jr.split(rng_key)
Expand All @@ -82,18 +134,33 @@ def simulate_data_and_possibly_append(

simulate_key, rng_key = jr.split(rng_key)
new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas)
chex.assert_shape(new_thetas, [n_simulations, None])
chex.assert_shape(new_obs, [n_simulations, None])

new_data = named_dataset(new_obs, new_thetas)

chex.assert_shape(new_thetas, [n_simulations, None])
chex.assert_shape(new_data, [n_simulations, None])
return new_data, diagnostics

if data is None:
d_new = new_data
else:
d_new = named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, new_data)]
)
return d_new, diagnostics
@staticmethod
def stack_data(data, also_data):
"""
Stack two data sets.
Parameters
----------
data: NamedTuple
one data set
also_data: : NamedTuple
Returns
-------
NamedTuple:
returns the stack of the two data sets
"""

return named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, also_data)]
)

def as_iterators(
self, rng_key, data, batch_size, percentage_data_as_validation_set
Expand Down
35 changes: 34 additions & 1 deletion sbijax/snl_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# pylint: skip-file

import chex
import distrax
import haiku as hk
import pytest
from jax import numpy as jnp
from jax import random as jr
from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.conditioners import mlp_conditioner
from surjectors.util import make_alternating_binary_mask
Expand Down Expand Up @@ -88,3 +90,34 @@ def test_snl():
n_samples=200,
n_warmup=100,
)


def test_stack_data():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snl = SNL(fns, make_model(2))
n = 100
data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
also_data, _ = snl.simulate_data(jr.PRNGKey(2), n_simulations=n)
stacked_data = snl.stack_data(data, also_data)

chex.assert_trees_all_equal(data[0], stacked_data[0][:n])
chex.assert_trees_all_equal(data[1], stacked_data[1][:n])
chex.assert_trees_all_equal(also_data[0], stacked_data[0][n:])
chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:])


def test_simulate_data_from_posterior_fail():
rng_seq = hk.PRNGSequence(0)

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snl = SNL(fns, make_model(2))
n = 100

data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
params, _ = snl.fit(next(rng_seq), data=data, n_iter=10)
with pytest.raises(ValueError):
snl.simulate_data(jr.PRNGKey(2), n_simulations=n, params=params)

0 comments on commit 594b937

Please sign in to comment.