Commit

fix tests
dirmeier committed Feb 29, 2024
1 parent 218a4b4 commit 0323cb2
Showing 4 changed files with 123 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sbijax/_src/nn/consistency_model.py
@@ -96,7 +96,7 @@ def vector_field(self, theta, time, context, **kwargs):
        return self._network(theta=theta, time=time, context=context, **kwargs)


-# pylint: disable=too-many-arguments
+# pylint: disable=too-many-arguments,too-many-instance-attributes
class _CMResnet(hk.Module):
    """A simplified 1-d residual network."""

4 changes: 2 additions & 2 deletions sbijax/_src/nn/continuous_normalizing_flow.py
@@ -47,7 +47,7 @@ def __call__(self, method, **kwargs):
        """
        return getattr(self, method)(**kwargs)

-    def sample(self, context):
+    def sample(self, context, **kwargs):
        """Sample from the pushforward.

        Args:
@@ -61,7 +61,7 @@ def ode_func(time, theta_t):
            theta_t = theta_t.reshape(-1, self._n_dimension)
            time = jnp.full((theta_t.shape[0], 1), time)
            ret = self.vector_field(
-                theta=theta_t, time=time, context=context, is_training=False
+                theta=theta_t, time=time, context=context, **kwargs
            )
            return ret.reshape(-1)

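With this change, keyword arguments given to sample are forwarded to the vector field instead of the previously hard-coded is_training=False, so the choice is left to the caller. Below is a minimal, self-contained sketch of that dispatch pattern, not sbijax's actual class: the _Pushforward module, its toy vector_field, and the hk.transform wrapper are illustrative stand-ins only.

import haiku as hk
import jax
from jax import numpy as jnp


class _Pushforward(hk.Module):
    """Toy module mirroring the __call__(method, **kwargs) dispatch above."""

    def __call__(self, method, **kwargs):
        return getattr(self, method)(**kwargs)

    def sample(self, context, **kwargs):
        # kwargs (e.g. is_training) are forwarded instead of being hard-coded
        theta_0 = jnp.zeros_like(context)
        return self.vector_field(theta=theta_0, time=0.0, context=context, **kwargs)

    def vector_field(self, theta, time, context, is_training=False):
        # placeholder drift; the real model would evaluate a network here
        return (0.0 if is_training else 1.0) * (context - theta)


def _fn(context):
    return _Pushforward()(method="sample", context=context, is_training=False)


model = hk.transform(_fn)
params = model.init(jax.random.PRNGKey(0), jnp.ones(2))
samples = model.apply(params, jax.random.PRNGKey(1), jnp.ones(2))
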
60 changes: 60 additions & 0 deletions sbijax/_src/scmpe_test.py
@@ -0,0 +1,60 @@
# pylint: skip-file

import distrax
import haiku as hk
from jax import numpy as jnp

from sbijax import SCMPE
from sbijax.nn import make_consistency_model


def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob


def simulator_fn(seed, theta):
    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
    y = p.sample(seed=seed)
    return y


def log_density_fn(theta, y):
    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
    likelihood = distrax.MultivariateNormalDiag(
        theta, 0.1 * jnp.ones_like(theta)
    )

    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
    return lp


def test_scmpe():
    rng_seq = hk.PRNGSequence(0)
    y_observed = jnp.array([-1.0, 1.0])

    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

    estim = SCMPE(fns, make_consistency_model(2))
    data, params = None, {}
    for i in range(2):
        data, _ = estim.simulate_data_and_possibly_append(
            next(rng_seq),
            params=params,
            observable=y_observed,
            data=data,
            n_simulations=100,
            n_chains=2,
            n_samples=200,
            n_warmup=100,
        )
        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
    _ = estim.sample_posterior(
        next(rng_seq),
        params,
        y_observed,
        n_chains=2,
        n_samples=200,
        n_warmup=100,
    )
60 changes: 60 additions & 0 deletions sbijax/_src/sfmpe_test.py
@@ -0,0 +1,60 @@
# pylint: skip-file

import distrax
import haiku as hk
from jax import numpy as jnp

from sbijax import SFMPE
from sbijax.nn import make_ccnf


def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob


def simulator_fn(seed, theta):
    p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
    y = p.sample(seed=seed)
    return y


def log_density_fn(theta, y):
    prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
    likelihood = distrax.MultivariateNormalDiag(
        theta, 0.1 * jnp.ones_like(theta)
    )

    lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
    return lp


def test_sfmpe():
    rng_seq = hk.PRNGSequence(0)
    y_observed = jnp.array([-1.0, 1.0])

    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

    estim = SFMPE(fns, make_ccnf(2))
    data, params = None, {}
    for i in range(2):
        data, _ = estim.simulate_data_and_possibly_append(
            next(rng_seq),
            params=params,
            observable=y_observed,
            data=data,
            n_simulations=100,
            n_chains=2,
            n_samples=200,
            n_warmup=100,
        )
        params, info = estim.fit(next(rng_seq), data=data, n_iter=2)
    _ = estim.sample_posterior(
        next(rng_seq),
        params,
        y_observed,
        n_chains=2,
        n_samples=200,
        n_warmup=100,
    )
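Both new test modules follow the same pattern: two rounds of simulate_data_and_possibly_append, each followed by a short fit, and a final sample_posterior at the observed data point. Assuming the package's test dependencies (jax, haiku, distrax) are installed, they can be run directly with pytest, e.g. pytest sbijax/_src/scmpe_test.py sbijax/_src/sfmpe_test.py.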
