Fix #15

Merged: 2 commits, Oct 4, 2023

2 changes: 1 addition & 1 deletion sbijax/__init__.py
@@ -2,7 +2,7 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.0"
__version__ = "0.1.1"


from sbijax.abc.rejection_abc import RejectionABC
89 changes: 78 additions & 11 deletions sbijax/_sne_base.py
@@ -59,21 +59,73 @@ def simulate_data_and_possibly_append(
"""

observable = jnp.atleast_2d(observable)
sample_key, rng_key = jr.split(rng_key)
new_data, diagnostics = self.simulate_data(
rng_key,
params=params,
observable=observable,
n_simulations=n_simulations,
**kwargs,
)
if data is None:
d_new = new_data
else:
d_new = self.stack_data(data, new_data)
return d_new, diagnostics

def simulate_data(
self,
rng_key,
*,
params=None,
observable=None,
n_simulations=1000,
**kwargs,
):
"""
Simulate data from the prior or the amortized posterior.

Parameters
----------
rng_key: jax.PRNGKey
a random key
params: Optional[pytree]
a dictionary of neural network parameters. If None, draws from the
prior. If parameters are given, draws from the amortized posterior
using 'observable'
observable: Optional[jnp.ndarray]
an observation. Needs to be given if posterior draws are desired
n_simulations: int
number of data points to simulate
kwargs: keyword arguments
dictionary of key-value pairs passed to `sample_posterior`

Returns
-------
NamedTuple:
returns a NamedTuple with two fields, y and theta
"""

sample_key, rng_key = jr.split(rng_key)
if params is None or len(params) == 0:
diagnostics = None
self.n_total_simulations += n_simulations
new_thetas = self.prior_sampler_fn(
seed=sample_key,
sample_shape=(n_simulations,),
)
else:
if observable is None:
raise ValueError(
"need to have access to 'observable' "
"when sampling from posterior"
)
if "n_samples" not in kwargs:
kwargs["n_samples"] = n_simulations
new_thetas, diagnostics = self.sample_posterior(
rng_key=sample_key,
params=params,
observable=observable,
observable=jnp.atleast_2d(observable),
**kwargs,
)
perm_key, rng_key = jr.split(rng_key)
@@ -82,18 +134,33 @@ def simulate_data_and_possibly_append(

simulate_key, rng_key = jr.split(rng_key)
new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas)
chex.assert_shape(new_thetas, [n_simulations, None])
chex.assert_shape(new_obs, [n_simulations, None])

new_data = named_dataset(new_obs, new_thetas)

chex.assert_shape(new_thetas, [n_simulations, None])
chex.assert_shape(new_data, [n_simulations, None])
return new_data, diagnostics

if data is None:
d_new = new_data
else:
d_new = named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, new_data)]
)
return d_new, diagnostics
@staticmethod
def stack_data(data, also_data):
"""
Stack two data sets.

Parameters
----------
data: NamedTuple
one data set
also_data: NamedTuple
another data set

Returns
-------
NamedTuple:
returns the two data sets stacked together
"""

return named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, also_data)]
)

def as_iterators(
self, rng_key, data, batch_size, percentage_data_as_validation_set
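For reference, a minimal usage sketch of the two-step API this file now exposes: simulate_data draws a fresh batch and stack_data concatenates batches. It assumes the prior_model_fns, simulator_fn and make_model helpers defined in sbijax/snl_test.py (shown below) and is illustrative only, not part of the diff.

from jax import random as jr

from sbijax import SNL
# hypothetical imports for illustration: these helpers live in sbijax/snl_test.py
from sbijax.snl_test import make_model, prior_model_fns, simulator_fn

fns = prior_model_fns(), simulator_fn
snl = SNL(fns, make_model(2))

# round 1: params is None, so thetas are drawn from the prior
data, _ = snl.simulate_data(jr.PRNGKey(0), n_simulations=100)
# round 2 with a fresh key, then concatenate the two rounds row-wise
more_data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=100)
full_data = snl.stack_data(data, more_data)  # y and theta each have 200 rows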
35 changes: 34 additions & 1 deletion sbijax/snl_test.py
@@ -1,8 +1,10 @@
# pylint: skip-file

import chex
import distrax
import haiku as hk
import pytest
from jax import numpy as jnp
from jax import random as jr
from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.conditioners import mlp_conditioner
from surjectors.util import make_alternating_binary_mask
@@ -88,3 +90,34 @@ def test_snl():
n_samples=200,
n_warmup=100,
)


def test_stack_data():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snl = SNL(fns, make_model(2))
n = 100
data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
also_data, _ = snl.simulate_data(jr.PRNGKey(2), n_simulations=n)
stacked_data = snl.stack_data(data, also_data)

chex.assert_trees_all_equal(data[0], stacked_data[0][:n])
chex.assert_trees_all_equal(data[1], stacked_data[1][:n])
chex.assert_trees_all_equal(also_data[0], stacked_data[0][n:])
chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:])


def test_simulate_data_from_posterior_fail():
rng_seq = hk.PRNGSequence(0)

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

snl = SNL(fns, make_model(2))
n = 100

data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n)
params, _ = snl.fit(next(rng_seq), data=data, n_iter=10)
with pytest.raises(ValueError):
snl.simulate_data(jr.PRNGKey(2), n_simulations=n, params=params)
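A companion sketch of the posterior branch exercised by the new tests: after fitting, simulate_data needs both params and an observable, otherwise it raises the ValueError checked above. The observable and the n_warmup sampler setting are placeholders for illustration; n_warmup is simply forwarded to sample_posterior via **kwargs.

from jax import numpy as jnp
from jax import random as jr

from sbijax import SNL
# hypothetical imports, as in the sketch after the _sne_base.py diff
from sbijax.snl_test import make_model, prior_model_fns, simulator_fn

snl = SNL((prior_model_fns(), simulator_fn), make_model(2))
data, _ = snl.simulate_data(jr.PRNGKey(0), n_simulations=100)
params, _ = snl.fit(jr.PRNGKey(1), data=data, n_iter=10)

y_observed = jnp.zeros((1, 2))  # dummy observation, for illustration only
# draws 100 thetas from the amortized posterior conditioned on y_observed
# and pushes them through the simulator
new_data, diagnostics = snl.simulate_data(
    jr.PRNGKey(2),
    params=params,
    observable=y_observed,
    n_simulations=100,
    n_warmup=100,  # forwarded to sample_posterior via **kwargs
)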