diff --git a/sbijax/_src/nn/consistency_model.py b/sbijax/_src/nn/consistency_model.py
index a871d2b..24fc43d 100644
--- a/sbijax/_src/nn/consistency_model.py
+++ b/sbijax/_src/nn/consistency_model.py
@@ -96,7 +96,7 @@ def vector_field(self, theta, time, context, **kwargs):
         return self._network(theta=theta, time=time, context=context, **kwargs)
 
 
-# pylint: disable=too-many-arguments
+# pylint: disable=too-many-arguments,too-many-instance-attributes
 class _CMResnet(hk.Module):
     """A simplified 1-d residual network."""
 
diff --git a/sbijax/_src/nn/continuous_normalizing_flow.py b/sbijax/_src/nn/continuous_normalizing_flow.py
index 1c405e0..741d865 100644
--- a/sbijax/_src/nn/continuous_normalizing_flow.py
+++ b/sbijax/_src/nn/continuous_normalizing_flow.py
@@ -47,7 +47,7 @@ def __call__(self, method, **kwargs):
         """
         return getattr(self, method)(**kwargs)
 
-    def sample(self, context):
+    def sample(self, context, **kwargs):
         """Sample from the pushforward.
 
         Args:
@@ -61,7 +61,7 @@ def ode_func(time, theta_t):
             theta_t = theta_t.reshape(-1, self._n_dimension)
             time = jnp.full((theta_t.shape[0], 1), time)
             ret = self.vector_field(
-                theta=theta_t, time=time, context=context, is_training=False
+                theta=theta_t, time=time, context=context, **kwargs
             )
             return ret.reshape(-1)
 
diff --git a/sbijax/_src/scmpe_test.py b/sbijax/_src/scmpe_test.py
new file mode 100644
index 0000000..8f6a38e
--- /dev/null
+++ b/sbijax/_src/scmpe_test.py
@@ -0,0 +1,60 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SCMPE
+from sbijax.nn import make_consistency_model
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_scmpe():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SCMPE(fns, make_consistency_model(2))
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )
diff --git a/sbijax/_src/sfmpe_test.py b/sbijax/_src/sfmpe_test.py
new file mode 100644
index 0000000..102ca07
--- /dev/null
+++ b/sbijax/_src/sfmpe_test.py
@@ -0,0 +1,60 @@
+# pylint: skip-file
+
+import distrax
+import haiku as hk
+from jax import numpy as jnp
+
+from sbijax import SFMPE
+from sbijax.nn import make_ccnf
+
+
+def prior_model_fns():
+    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
+    return p.sample, p.log_prob
+
+
+def simulator_fn(seed, theta):
+    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
+    y = p.sample(seed=seed)
+    return y
+
+
+def log_density_fn(theta, y):
+    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
+    likelihood = distrax.MultivariateNormalDiag(
+        theta, 0.1 * jnp.ones_like(theta)
+    )
+
+    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
+    return lp
+
+
+def test_sfmpe():
+    rng_seq = hk.PRNGSequence(0)
+    y_observed = jnp.array([-1.0, 1.0])
+
+    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
+
+    estim = SFMPE(fns, make_ccnf(2))
+    data, params = None, {}
+    for i in range(2):
+        data, _ = estim.simulate_data_and_possibly_append(
+            next(rng_seq),
+            params=params,
+            observable=y_observed,
+            data=data,
+            n_simulations=100,
+            n_chains=2,
+            n_samples=200,
+            n_warmup=100,
+        )
+        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
+    _ = estim.sample_posterior(
+        next(rng_seq),
+        params,
+        y_observed,
+        n_chains=2,
+        n_samples=200,
+        n_warmup=100,
+    )