DRAFT: Move to TF iterators (#23)
* Move all classes to use TF iterators
* Update unit tests
dirmeier authored Feb 26, 2024
1 parent f4f43d1 commit db2a588
Showing 19 changed files with 304 additions and 382 deletions.
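This change replaces the hand-rolled batch generator in sbijax/_src/generator.py with iterators built on tf.data (hence the new tensorflow and tensorflow-datasets dependencies added to pyproject.toml below). The following is an illustrative sketch only, not the code added in this commit: it assumes the as_batch_iterators(rng_key, data, batch_size, percentage_data_as_validation_set) signature that appears further down and a named_dataset namedtuple of arrays, and shows one way such a TF-backed loader could look.

# Illustrative sketch (assumptions, not the code from this commit) of a
# tf.data-backed dataloader matching the as_batch_iterators signature used below.
import tensorflow as tf
import tensorflow_datasets as tfds
from jax import random as jr


def as_batch_iterators(rng_key, data, batch_size, percentage_data_as_validation_set):
    # 'data' is assumed to be a namedtuple of arrays, e.g. named_dataset(y=..., theta=...)
    n = data[0].shape[0]
    n_val = int(n * percentage_data_as_validation_set)
    idxs = jr.permutation(rng_key, n)
    val_idxs, train_idxs = idxs[:n_val], idxs[n_val:]

    def make_iterator(idx):
        ds = tf.data.Dataset.from_tensor_slices(
            {k: v[idx] for k, v in data._asdict().items()}
        )
        ds = ds.shuffle(max(int(len(idx)), 1)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        # tfds.as_numpy yields plain numpy batches that JAX code can consume directly
        return tfds.as_numpy(ds)

    return make_iterator(train_idxs), make_iterator(val_idxs)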
9 changes: 4 additions & 5 deletions examples/bivariate_gaussian_snl.py
@@ -22,7 +22,7 @@
from surjectors.util import unstack

from sbijax import SNL
from sbijax.mcmc import sample_with_slice
from sbijax.mcmc import sample_with_nuts


def prior_model_fns():
@@ -84,9 +84,6 @@ def _flow(method, **kwargs):
def run():
y_observed = jnp.array([-2.0, 1.0])

log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: jax.vmap(log_density_partial)(x)

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

@@ -107,7 +104,9 @@ def run():
)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
slice_samples = sample_with_slice(
log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: log_density_partial(**x)
slice_samples = sample_with_nuts(
sample_key, log_density, prior_simulator_fn
)
slice_samples = slice_samples.reshape(-1, 2)
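A side effect of switching this example from the vmapped slice sampler to NUTS is that the reference log density is now evaluated on a single dictionary of parameters rather than a batched array, hence the ** unpacking. A minimal, self-contained sketch of that pattern; the quadratic log density and the "theta" key are illustrative assumptions, not the example's actual model:

from functools import partial

from jax import numpy as jnp


def log_density_fn(theta, y):
    # stand-in log density for illustration only
    return -0.5 * jnp.sum((y - theta) ** 2)


y_observed = jnp.array([-2.0, 1.0])
log_density_partial = partial(log_density_fn, y=y_observed)
# the NUTS sampler is assumed to pass parameters as a dict/pytree, hence **x
log_density = lambda x: log_density_partial(**x)

print(log_density({"theta": jnp.zeros(2)}))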
3 changes: 1 addition & 2 deletions examples/bivariate_gaussian_snr.py
@@ -43,7 +43,7 @@ def run():
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(5):
for i in range(2):
data, _ = snr.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
@@ -54,7 +54,6 @@
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
batch_size=100,
)

rng_key = jr.PRNGKey(23)
11 changes: 5 additions & 6 deletions examples/slcp_snass.py
@@ -20,8 +20,8 @@
from surjectors.nn import MADE
from surjectors.util import unstack

from sbijax import SNASSS
from sbijax.nn import make_snasss_net
from sbijax import SNASS
from sbijax._src.nn.make_snass_networks import make_snass_net


def prior_model_fns():
@@ -125,15 +125,15 @@ def run():
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SNASSS(
estim = SNASS(
fns,
make_model(5),
make_snasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1]),
make_snass_net([64, 64, 5], [64, 64, 1]),
)
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(5):
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
params=params,
@@ -144,7 +144,6 @@
jr.fold_in(jr.PRNGKey(23), i),
data=data,
optimizer=optimizer,
batch_size=100,
)

rng_key = jr.PRNGKey(23)
35 changes: 10 additions & 25 deletions examples/slcp_ssnl.py
@@ -15,7 +15,6 @@
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from jax import vmap
from surjectors import (
AffineMaskedAutoregressiveInferenceFunnel,
Chain,
@@ -27,7 +26,7 @@
from surjectors.util import unstack

from sbijax import SNL
from sbijax.mcmc import sample_with_slice
from sbijax.mcmc import sample_with_nuts


def prior_model_fns():
@@ -160,50 +159,36 @@ def _flow(method, **kwargs):

def run(use_surjectors):
len_theta = 5
# this is the thetas used in SNL
# thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6])
y_observed = jnp.array(
[
[
-0.9707123,
-2.9461224,
-0.4494722,
-3.4231849,
-0.13285634,
-3.364017,
-0.85367596,
-2.4271638,
]
]
thetas = jnp.linspace(-2.0, 2.0, len_theta)
y_0 = simulator_fn(jr.PRNGKey(0), thetas.reshape(-1, len_theta)).reshape(
-1, 8
)

log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: vmap(log_density_partial)(x)

prior_simulator_fn, prior_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_fn), simulator_fn

snl = SNL(fns, make_model(y_observed.shape[1], use_surjectors))
snl = SNL(fns, make_model(y_0.shape[1], use_surjectors))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(5):
data, _ = snl.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
params=params,
observable=y_observed,
observable=y_0,
data=data,
sampler="slice",
)
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
snl_samples, _ = snl.sample_posterior(sample_key, params, y_0)

sample_key, rng_key = jr.split(rng_key)
slice_samples = sample_with_slice(
log_density_partial = partial(log_density_fn, y=y_0)
log_density = lambda x: log_density_partial(**x)
slice_samples = sample_with_nuts(
sample_key,
log_density,
prior_simulator_fn,
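Instead of a hard-coded eight-dimensional observation, the SSNL example now simulates its observation from a fixed parameter grid, so the data-generating parameters are known. A minimal sketch of that step in isolation; the simulator here is a toy stand-in for the example's own SLCP simulator_fn, not the real one:

from jax import numpy as jnp
from jax import random as jr


def simulator_fn(seed, theta):
    # toy stand-in: returns 8 noisy summaries per parameter vector
    return theta.repeat(2, axis=-1)[..., :8] + 0.1 * jr.normal(seed, (theta.shape[0], 8))


len_theta = 5
thetas = jnp.linspace(-2.0, 2.0, len_theta)
y_0 = simulator_fn(jr.PRNGKey(0), thetas.reshape(-1, len_theta)).reshape(-1, 8)
print(y_0.shape)  # (1, 8)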
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,6 +27,8 @@ dependencies = [
"optax>=0.1.3",
"surjectors>=0.3.0",
"tfp-nightly>=0.20.0.dev20230404",
"tensorflow==2.15.0",
"tensorflow-datasets==4.9.3",
"tqdm>=4.64.1"
]
dynamic = ["version"]
2 changes: 1 addition & 1 deletion sbijax/__init__.py
@@ -2,7 +2,7 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.7"
__version__ = "0.1.8"


from sbijax._src.abc.smc_abc import SMCABC
24 changes: 3 additions & 21 deletions sbijax/_src/_sne_base.py
@@ -6,7 +6,8 @@
from jax import random as jr

from sbijax._src._sbi_base import SBI
from sbijax._src.generator import as_batch_iterators, named_dataset
from sbijax._src.util.data import stack_data
from sbijax._src.util.dataloader import as_batch_iterators, named_dataset


# pylint: disable=too-many-arguments,unused-argument
@@ -61,7 +62,7 @@ def simulate_data_and_possibly_append(
if data is None:
d_new = new_data
else:
d_new = self.stack_data(data, new_data)
d_new = stack_data(data, new_data)
return d_new, diagnostics

@abc.abstractmethod
@@ -135,25 +136,6 @@ def simulate_data(

return new_data, diagnostics

@staticmethod
def stack_data(data, also_data):
"""Stack two data sets.
Args:
data: one data set
also_data: another data set
Returns:
returns the stack of the two data sets
"""
if data is None:
return also_data
if also_data is None:
return data
return named_dataset(
*[jnp.vstack([a, b]) for a, b in zip(data, also_data)]
)

@staticmethod
def as_iterators(
rng_key, data, batch_size, percentage_data_as_validation_set
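The stack_data helper removed above is now imported from sbijax._src.util.data instead of living on the class. A plausible sketch of that standalone function, mirroring the deleted static method; the import of named_dataset from the new dataloader module is an assumption, not shown in this diff:

from jax import numpy as jnp

from sbijax._src.util.dataloader import named_dataset  # assumed location


def stack_data(data, also_data):
    """Stack two (named) data sets row-wise."""
    if data is None:
        return also_data
    if also_data is None:
        return data
    return named_dataset(
        *[jnp.vstack([a, b]) for a, b in zip(data, also_data)]
    )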
107 changes: 0 additions & 107 deletions sbijax/_src/generator.py

This file was deleted.

2 changes: 1 addition & 1 deletion sbijax/_src/mcmc/__init__.py
@@ -1,6 +1,6 @@
from sbijax._src.mcmc.diagnostics import mcmc_diagnostics
from sbijax._src.mcmc.irmh import sample_with_imh
from sbijax._src.mcmc.mala import sample_with_mala
from sbijax._src.mcmc.nuts import sample_with_nuts
from sbijax._src.mcmc.rmh import sample_with_rmh
from sbijax._src.mcmc.sample import mcmc_diagnostics
from sbijax._src.mcmc.slice import sample_with_slice
File renamed without changes.