diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 867e4cf..44930df 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.1" +__version__ = "0.1.2" from sbijax.abc.rejection_abc import RejectionABC diff --git a/sbijax/_sne_base.py b/sbijax/_sne_base.py index e5f10da..595cc96 100644 --- a/sbijax/_sne_base.py +++ b/sbijax/_sne_base.py @@ -158,6 +158,10 @@ def stack_data(data, also_data): returns the stack of the two data sets """ + if data is None: + return also_data + if also_data is None: + return data return named_dataset( *[jnp.vstack([a, b]) for a, b in zip(data, also_data)] ) diff --git a/sbijax/snl_test.py b/sbijax/snl_test.py index 24cdfcb..8f4c704 100644 --- a/sbijax/snl_test.py +++ b/sbijax/snl_test.py @@ -108,6 +108,19 @@ def test_stack_data(): chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:]) +def test_stack_data_with_none(): + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + snl = SNL(fns, make_model(2)) + n = 100 + data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n) + stacked_data = snl.stack_data(None, data) + + chex.assert_trees_all_equal(data[0], stacked_data[0]) + chex.assert_trees_all_equal(data[1], stacked_data[1]) + + def test_simulate_data_from_posterior_fail(): rng_seq = hk.PRNGSequence(0)