From 61763672e7a60826dc2c69a46a55ef1177aa85fb Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Sun, 18 Aug 2024 15:52:54 +0200
Subject: [PATCH] Deprecate distrax and update examples (#48)

---
 .gitattributes                         |   1 +
 .github/workflows/examples.yaml        |  47 ++++++++++++
 examples/bivariate_gaussian-sabc.py    |  42 ----------
 examples/bivariate_gaussian-smcabc.py  |  10 ++-
 examples/mixture_model-cmpe.py         |  50 ++++++++++++
 examples/mixture_model-nle.py          |  14 ++--
 examples/mixture_model-npe.py          |  50 ++++++++++++
 examples/mixture_model-nre.py          |  50 ++++++++++++
 examples/slcp-fmpe.py                  | 102 +++++++++++++++++++++++++
 examples/slcp-nass_nle.py              |  19 +++--
 examples/slcp-nass_smcabc.py           |  17 +++--
 examples/slcp-snle.py                  |  58 +++++---------
 examples/two_moons-fmpe.py             |   0
 examples/two_moons-slice.py            |  73 ------------------
 mypy.ini                               |   2 -
 pyproject.toml                         |   5 --
 sbijax/__init__.py                     |   2 +-
 sbijax/_src/nn/make_continuous_flow.py |   4 +-
 sbijax/_src/nn/make_resnet.py          |   2 +-
 19 files changed, 360 insertions(+), 188 deletions(-)
 create mode 100644 .gitattributes
 create mode 100644 .github/workflows/examples.yaml
 delete mode 100644 examples/bivariate_gaussian-sabc.py
 delete mode 100644 examples/two_moons-fmpe.py
 delete mode 100644 examples/two_moons-slice.py
 delete mode 100644 mypy.ini

diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..5be91f9
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+*.ipynb linguist-vendored
diff --git a/.github/workflows/examples.yaml b/.github/workflows/examples.yaml
new file mode 100644
index 0000000..d2fe4da
--- /dev/null
+++ b/.github/workflows/examples.yaml
@@ -0,0 +1,47 @@
+name: examples
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  precommit:
+    name: Pre-commit checks
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+  examples:
+    runs-on: ubuntu-latest
+    needs:
+      - precommit
+    strategy:
+      matrix:
+        python-version: [3.11]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install hatch matplotlib
+      - name: Build package
+        run: |
+          pip install jaxlib jax
+          pip install .
+      - name: Run tests
+        run: |
+          python examples/bivariate_gaussian-smcabc.py --n-rounds 1
+          python examples/mixture_model-cmpe.py --n-iter 10
+          python examples/mixture_model-nle.py --n-iter 10
+          python examples/mixture_model-nle.py --n-iter 10 --use-spf
+          python examples/mixture_model-npe.py --n-iter 10
+          python examples/mixture_model-nre.py --n-iter 10
+          python examples/slcp-fmpe.py --n-iter 10
+          python examples/slcp-nass_nle.py --n-iter 10 --n-rounds 1
+          python examples/slcp-nass_smcabc.py --n-iter 10 --n-rounds 1
+          python examples/slcp-snle.py --n-iter 10 --n-rounds 1
diff --git a/examples/bivariate_gaussian-sabc.py b/examples/bivariate_gaussian-sabc.py
deleted file mode 100644
index d988088..0000000
--- a/examples/bivariate_gaussian-sabc.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""
-Example using sequential neural likelihood estimation on a bivariate Gaussian
-"""
-
-import matplotlib.pyplot as plt
-from jax import numpy as jnp
-from jax import random as jr
-from tensorflow_probability.substrates.jax import distributions as tfd
-
-from sbijax import NLE, plot_posterior
-from sbijax.nn import make_mdn
-
-
-def prior_fn():
-    prior = tfd.JointDistributionNamed(dict(
-        mean=tfd.Normal(jnp.zeros(2), 1.0),
-        scale=tfd.HalfNormal(jnp.ones(1)),
-    ), batch_ndims=0)
-    return prior
-
-
-def simulator_fn(seed, theta):
-    p = tfd.Normal(jnp.zeros_like(theta["mean"]), 1.0)
-    y = theta["mean"] + theta["scale"] * p.sample(seed=seed)
-    return y
-
-
-def run():
-    y_observed = jnp.array([-2.0, 1.0])
-    fns = prior_fn, simulator_fn
-    model = NLE(fns, make_mdn(2, 10))
-
-    data, _ = model.simulate_data(jr.PRNGKey(11))
-    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
-    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
-
-    plot_posterior(inference_result)
-    plt.show()
-
-
-if __name__ == "__main__":
-    run()
diff --git a/examples/bivariate_gaussian-smcabc.py b/examples/bivariate_gaussian-smcabc.py
index 4910fc6..37a012f 100644
--- a/examples/bivariate_gaussian-smcabc.py
+++ b/examples/bivariate_gaussian-smcabc.py
@@ -2,6 +2,7 @@
 Demonstrates sequential Monte Carlo ABC on a simple
 bivariate Gaussian example.
 """
+import argparse
 
 import jax
 import matplotlib.pyplot as plt
@@ -35,18 +36,21 @@ def distance_fn(y_simulated, y_observed):
     return dist
 
 
-def run():
+def run(n_rounds):
     y_observed = jnp.array([-2.0, 1.0])
     fns = prior_fn, simulator_fn
 
     smc = SMCABC(fns, summary_fn, distance_fn)
     smc_samples, _ = smc.sample_posterior(
-        jr.PRNGKey(1), y_observed, 10, 1000, 0.85, 500
+        jr.PRNGKey(1), y_observed, n_rounds=n_rounds, n_particles=1000, ess_min=500
     )
 
     plot_posterior(smc_samples)
     plt.show()
 
 
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-rounds", type=int, default=10)
+    args = parser.parse_args()
+    run(args.n_rounds)
diff --git a/examples/mixture_model-cmpe.py b/examples/mixture_model-cmpe.py
index e69de29..ca41d46 100644
--- a/examples/mixture_model-cmpe.py
+++ b/examples/mixture_model-cmpe.py
@@ -0,0 +1,50 @@
+"""Consistency model posterior estimation example.
+
+Demonstrates CMPE on a simple mixture model.
+"""
+import argparse
+
+import matplotlib.pyplot as plt
+from jax import numpy as jnp, random as jr
+from tensorflow_probability.substrates.jax import distributions as tfd
+
+from sbijax import plot_posterior, CMPE
+from sbijax.nn import make_cm
+
+
+def prior_fn():
+    prior = tfd.JointDistributionNamed(dict(
+        theta=tfd.Normal(jnp.zeros(2), 1)
+    ), batch_ndims=0)
+    return prior
+
+
+def simulator_fn(seed, theta):
+    mean = theta["theta"].reshape(-1, 2)
+    n = mean.shape[0]
+    data_key, cat_key = jr.split(seed)
+    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
+    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
+    y = tfd.Normal(mean, scales).sample(seed=data_key)
+    return y
+
+
+def run(n_iter):
+    y_observed = jnp.array([-2.0, 1.0])
+    fns = prior_fn, simulator_fn
+    neural_network = make_cm(2, 64)
+    model = CMPE(fns, neural_network)
+
+    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
+    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
+    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
+
+    plot_posterior(inference_result)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=10)
+    args = parser.parse_args()
+    run(args.n_iter)
diff --git a/examples/mixture_model-nle.py b/examples/mixture_model-nle.py
index a61d28d..0b78ca4 100644
--- a/examples/mixture_model-nle.py
+++ b/examples/mixture_model-nle.py
@@ -1,4 +1,7 @@
+"""Neural likelihood estimation example.
 
+Demonstrates NLE on a simple mixture model.
+"""
 import matplotlib.pyplot as plt
 from jax import numpy as jnp, random as jr
 from tensorflow_probability.substrates.jax import distributions as tfd
@@ -24,14 +27,14 @@ def simulator_fn(seed, theta):
     return y
 
 
-def run(use_spf):
+def run(use_spf, n_iter):
     y_observed = jnp.array([-2.0, 1.0])
     fns = prior_fn, simulator_fn
-    neural_network = make_spf(2, -5.0, 5.0, n_params=3) if use_spf else make_mdn(2, 10)
+    neural_network = make_spf(2, -5.0, 5.0, n_params=10) if use_spf else make_mdn(2, 10)
     model = NLE(fns, neural_network)
 
     data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
-    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
+    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
     inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
 
     plot_posterior(inference_result)
@@ -41,6 +44,7 @@ def run(use_spf):
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
-    parser.add_argument("--use-spf", action="store_true", default=True)
+    parser.add_argument("--use-spf", action="store_true", default=False)
+    parser.add_argument("--n-iter", type=int, default=1_000)
     args = parser.parse_args()
-    run(False)
+    run(args.use_spf, args.n_iter)
diff --git a/examples/mixture_model-npe.py b/examples/mixture_model-npe.py
index e69de29..5251067 100644
--- a/examples/mixture_model-npe.py
+++ b/examples/mixture_model-npe.py
@@ -0,0 +1,50 @@
+"""Neural posterior estimation example.
+
+Demonstrates NPE on a simple mixture model.
+"""
+import argparse
+
+import matplotlib.pyplot as plt
+from jax import numpy as jnp, random as jr
+from tensorflow_probability.substrates.jax import distributions as tfd
+
+from sbijax import plot_posterior, NPE
+from sbijax.nn import make_maf
+
+
+def prior_fn():
+    prior = tfd.JointDistributionNamed(dict(
+        theta=tfd.Normal(jnp.zeros(2), 1)
+    ), batch_ndims=0)
+    return prior
+
+
+def simulator_fn(seed, theta):
+    mean = theta["theta"].reshape(-1, 2)
+    n = mean.shape[0]
+    data_key, cat_key = jr.split(seed)
+    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
+    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
+    y = tfd.Normal(mean, scales).sample(seed=data_key)
+    return y
+
+
+def run(n_iter):
+    y_observed = jnp.array([-2.0, 1.0])
+    fns = prior_fn, simulator_fn
+    neural_network = make_maf(2)
+    model = NPE(fns, neural_network)
+
+    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
+    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
+    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
+
+    plot_posterior(inference_result)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_iter)
diff --git a/examples/mixture_model-nre.py b/examples/mixture_model-nre.py
index e69de29..8d276be 100644
--- a/examples/mixture_model-nre.py
+++ b/examples/mixture_model-nre.py
@@ -0,0 +1,50 @@
+"""Neural ratio estimation example.
+
+Demonstrates NRE on a simple mixture model.
+"""
+import argparse
+
+import matplotlib.pyplot as plt
+from jax import numpy as jnp, random as jr
+from tensorflow_probability.substrates.jax import distributions as tfd
+
+from sbijax import plot_posterior, NRE
+from sbijax.nn import make_mlp
+
+
+def prior_fn():
+    prior = tfd.JointDistributionNamed(dict(
+        theta=tfd.Normal(jnp.zeros(2), 1)
+    ), batch_ndims=0)
+    return prior
+
+
+def simulator_fn(seed, theta):
+    mean = theta["theta"].reshape(-1, 2)
+    n = mean.shape[0]
+    data_key, cat_key = jr.split(seed)
+    categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
+    scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
+    y = tfd.Normal(mean, scales).sample(seed=data_key)
+    return y
+
+
+def run(n_iter):
+    y_observed = jnp.array([-2.0, 1.0])
+    fns = prior_fn, simulator_fn
+    neural_network = make_mlp()
+    model = NRE(fns, neural_network)
+
+    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
+    params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
+    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
+
+    plot_posterior(inference_result)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_iter)
diff --git a/examples/slcp-fmpe.py b/examples/slcp-fmpe.py
index e69de29..d729bf5 100644
--- a/examples/slcp-fmpe.py
+++ b/examples/slcp-fmpe.py
@@ -0,0 +1,102 @@
+"""Flow matching posterior estimation.
+
+Demonstrates FMPE on the simple likelihood complex posterior model.
+"""
+import optax
+from jax import numpy as jnp
+from jax import random as jr
+from matplotlib import pyplot as plt
+from tensorflow_probability.substrates.jax import distributions as tfd
+
+from sbijax import FMPE
+from sbijax import inference_data_as_dictionary
+from sbijax.nn import make_cnf
+
+
+def prior_fn():
+    prior = tfd.JointDistributionNamed(dict(
+        theta=tfd.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0))
+    ), batch_ndims=0)
+    return prior
+
+
+def simulator_fn(seed, theta):
+    theta = theta["theta"]
+    theta = theta[:, None, :]
+    us_key, noise_key = jr.split(seed)
+
+    def _unpack_params(ps):
+        m0 = ps[..., [0]]
+        m1 = ps[..., [1]]
+        s0 = ps[..., [2]] ** 2
+        s1 = ps[..., [3]] ** 2
+        r = jnp.tanh(ps[..., [4]])
+        return m0, m1, s0, s1, r
+
+    m0, m1, s0, s1, r = _unpack_params(theta)
+    us = tfd.Normal(0.0, 1.0).sample(
+        seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
+    )
+    xs = jnp.empty_like(us)
+    xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0)
+    y = xs.at[:, :, :, 1].set(
+        s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
+    )
+    y = y.reshape((*theta.shape[:1], 8))
+    return y
+
+
+def run(n_iter):
+    y_observed = jnp.array([[
+        -0.9707123,
+        -2.9461224,
+        -0.4494722,
+        -3.4231849,
+        -0.13285634,
+        -3.364017,
+        -0.85367596,
+        -2.4271638,
+    ]])
+
+    n_dim_theta = 5
+    n_layers, hidden_size = 5, 128
+    neural_network = make_cnf(n_dim_theta, n_layers, hidden_size)
+    fns = prior_fn, simulator_fn
+    fmpe = FMPE(fns, neural_network)
+
+    data, _ = fmpe.simulate_data(
+        jr.PRNGKey(1),
+        n_simulations=20_000,
+    )
+    fmpe_params, info = fmpe.fit(
+        jr.PRNGKey(2),
+        data=data,
+        optimizer=optax.adam(0.001),
+        n_iter=n_iter,
+        n_early_stopping_delta=0.00001,
+        n_early_stopping_patience=30
+    )
+    inference_results, diagnostics = fmpe.sample_posterior(
+        jr.PRNGKey(5), fmpe_params, y_observed, n_samples=25_000
+    )
+
+    samples = inference_data_as_dictionary(inference_results.posterior)["theta"]
+    _, axes = plt.subplots(figsize=(12, 10), nrows=5, ncols=5)
+    for i in range(0, 5):
+        for j in range(0, 5):
+            ax = axes[i, j]
+            if i < j:
+                ax.axis('off')
+            else:
+                ax.hexbin(samples[..., j], samples[..., i], gridsize=50, bins='log')
+    for i in range(5):
+        axes[i, i].hist(samples[..., i], color="black")
+    plt.show()
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_iter)
diff --git a/examples/slcp-nass_nle.py b/examples/slcp-nass_nle.py
index 4378cc7..2301f79 100644
--- a/examples/slcp-nass_nle.py
+++ b/examples/slcp-nass_nle.py
@@ -51,7 +51,7 @@ def _unpack_params(ps):
     return y
 
 
-def run():
+def run(n_rounds, n_iter):
     y_observed = jnp.array([[
         -0.9707123,
         -2.9461224,
@@ -63,11 +63,12 @@ def run():
         -2.4271638,
     ]])
     fns = prior_fn, simulator_fn
-    model_nass = NASS(fns, make_nass_net([64, 64, 5], [64, 64, 1]))
+    neural_network = make_nass_net(5, (64, 64))
+    model_nass = NASS(fns, neural_network)
     model_nle = NLE(fns, make_maf(5))
 
     data, params_nle, params_nass = None, {}, {}
-    for i in range(5):
+    for i in range(n_rounds):
         simulate_key, nass_key, nle_key = jr.split(jr.fold_in(jr.PRNGKey(1), i), 3)
         s_observed = model_nass.summarize(params_nass, y_observed)
         data, _ = model_nle.simulate_data_and_possibly_append(
@@ -76,9 +77,9 @@ def run():
             observable=s_observed,
             data=data,
         )
-        params_nass, _ = model_nass.fit(nass_key, data=data)
+        params_nass, _ = model_nass.fit(nass_key, data=data, n_iter=n_iter)
         summaries = model_nass.summarize(params_nass, data)
-        params_nle, _ = model_nle.fit(nle_key, data=summaries)
+        params_nle, _ = model_nle.fit(nle_key, data=summaries, n_iter=n_iter)
 
     s_observed = model_nass.summarize(params_nass, y_observed)
     inference_results, _ = model_nle.sample_posterior(jr.PRNGKey(3), params_nle, s_observed)
@@ -97,6 +98,10 @@ def run():
     plt.show()
 
 
-
 if __name__ == "__main__":
-    run()
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    parser.add_argument("--n-rounds", type=int, default=15)
+    args = parser.parse_args()
+    run(args.n_rounds, args.n_iter)
diff --git a/examples/slcp-nass_smcabc.py b/examples/slcp-nass_smcabc.py
index 6a71154..584747c 100644
--- a/examples/slcp-nass_smcabc.py
+++ b/examples/slcp-nass_smcabc.py
@@ -3,7 +3,7 @@
 Demonstrates neural approximate sufficient statistics with SMCABC on the
 simple likelihood complex posterior model.
 """
-
+import argparse
 
 import distrax
 import jax
@@ -60,7 +60,7 @@ def distance_fn(y_simulated, y_observed):
     return dist
 
 
-def run():
+def run(n_rounds, n_iter):
     y_observed = jnp.array([[
         -0.9707123,
         -2.9461224,
@@ -72,10 +72,10 @@ def run():
         -2.4271638,
     ]])
     fns = prior_fn, simulator_fn
-    model_nass = NASS(fns, make_nass_net([64, 64, 5], [64, 64, 1]))
+    model_nass = NASS(fns, make_nass_net(5, (64, 64)))
 
     data, _ = model_nass.simulate_data(jr.PRNGKey(1), n_simulations=20_000)
-    params_nass, _ = model_nass.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
+    params_nass, _ = model_nass.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
 
     def summary_fn(y):
         s = model_nass.summarize(params_nass, y)
@@ -83,7 +83,7 @@ def summary_fn(y):
 
     model_smc = SMCABC(fns, summary_fn, distance_fn)
     inference_results, _ = model_smc.sample_posterior(
-        jr.PRNGKey(3), y_observed, n_rounds=10, n_particles=5_000, eps_step=0.825, ess_min=2_000
+        jr.PRNGKey(3), y_observed, n_rounds=n_rounds, n_particles=5_000, eps_step=0.825, ess_min=2_000
     )
 
     samples = inference_data_as_dictionary(inference_results.posterior)["theta"]
@@ -100,6 +100,9 @@ def summary_fn(y):
     plt.show()
 
 
-
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-rounds", type=int, default=15)
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_rounds, args.n_iter)
diff --git a/examples/slcp-snle.py b/examples/slcp-snle.py
index caf0f4b..df751be 100644
--- a/examples/slcp-snle.py
+++ b/examples/slcp-snle.py
@@ -1,18 +1,15 @@
-"""SNLE example.
+"""Surjective neural likelihood estimation example.
 
 Demonstrates sequential surjective neural likelihood estimation on the
 simple likelihood complex posterior model.
 """
 import argparse
-from functools import partial
 
 import distrax
 import haiku as hk
 import jax
 import matplotlib.pyplot as plt
 import optax
-import pandas as pd
-import seaborn as sns
 from jax import numpy as jnp
 from jax import random as jr
 from jax import scipy as jsp
@@ -28,7 +25,6 @@
 from tensorflow_probability.substrates.jax import distributions as tfd
 
 from sbijax import SNLE, inference_data_as_dictionary
-from sbijax.mcmc import sample_with_nuts
 from sbijax.nn import make_maf
 
 
@@ -117,7 +113,7 @@ def _flow(method, **kwargs):
                 _decoder_fn(n_dimension - n_latent),
                 conditioner=MADE(
                     n_latent,
-                    [50, n_latent * 2],
+                    [64, 64],
                     2,
                     w_init=hk.initializers.TruncatedNormal(0.001),
                     b_init=jnp.zeros,
@@ -132,7 +128,7 @@ def _flow(method, **kwargs):
                 bijector_fn=_bijector_fn,
                 conditioner=MADE(
                     n_dimension,
-                    [50, n_dimension * 2],
+                    [64, 64],
                     2,
                     w_init=hk.initializers.TruncatedNormal(0.001),
                     b_init=jnp.zeros,
@@ -156,7 +152,7 @@ def _flow(method, **kwargs):
     return td
 
 
-def run(n_iter):
+def run(n_rounds, n_iter):
     y_obs = jnp.array([[
         -0.9707123,
         -2.9461224,
@@ -174,7 +170,7 @@ def run(n_iter):
     optimizer = optax.adam(1e-3)
 
     data, params = None, {}
-    for i in range(10):
+    for i in range(n_rounds):
         data, _ = snl.simulate_data_and_possibly_append(
             jr.fold_in(jr.PRNGKey(1), i),
             params=params,
@@ -188,42 +184,24 @@ def run(n_iter):
     sample_key, rng_key = jr.split(jr.PRNGKey(3))
     inference_results, _ = snl.sample_posterior(sample_key, params, y_obs)
 
-    sample_key, rng_key = jr.split(rng_key)
-    log_density_partial = partial(log_density_fn, y=y_obs)
-    log_density = lambda x: log_density_partial(**x)
-    slice_samples = sample_with_nuts(
-        sample_key,
-        log_density,
-        prior_fn().sample,
-    )
-    slice_samples = slice_samples['theta'].reshape(-1, 5)
-    snl_samples = inference_data_as_dictionary(inference_results.posterior)["theta"]
-
-    g = sns.PairGrid(pd.DataFrame(slice_samples))
-    g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2)
-    g.map_diag(plt.hist, color="black")
-    for ax in g.axes.flatten():
-        ax.set_xlim(-5, 5)
-        ax.set_ylim(-5, 5)
-    g.fig.set_figheight(5)
-    g.fig.set_figwidth(5)
-    plt.show()
-
-    fig, axes = plt.subplots(5, 2)
+    samples = inference_data_as_dictionary(inference_results.posterior)["theta"]
+    _, axes = plt.subplots(figsize=(12, 10), nrows=5, ncols=5)
+    for i in range(0, 5):
+        for j in range(0, 5):
+            ax = axes[i, j]
+            if i < j:
+                ax.axis('off')
+            else:
+                ax.hexbin(samples[..., j], samples[..., i], gridsize=50,
+                          bins='log')
     for i in range(5):
-        sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0])
-        sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1])
-        axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$")
-        axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$")
-        for j in range(2):
-            axes[i, j].set_xlim(-5, 5)
-    sns.despine()
-    plt.tight_layout()
+        axes[i, i].hist(samples[..., i], color="black")
     plt.show()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--n-rounds", type=int, default=15)
     parser.add_argument("--n-iter", type=int, default=1_000)
     args = parser.parse_args()
-    run(args.n_iter)
+    run(args.n_rounds, args.n_iter)
diff --git a/examples/two_moons-fmpe.py b/examples/two_moons-fmpe.py
deleted file mode 100644
index e69de29..0000000
diff --git a/examples/two_moons-slice.py b/examples/two_moons-slice.py
deleted file mode 100644
index d5631cf..0000000
--- a/examples/two_moons-slice.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""
-Example using consistency model posterior estimation on a bivariate Gaussian
-"""
-from functools import partial
-
-import jax
-import numpy as np
-from jax import numpy as jnp
-from jax import random as jr
-from matplotlib import pyplot as plt
-from tensorflow_probability.substrates.jax import distributions as tfd
-
-from sbijax.mcmc import sample_with_slice
-
-
-def prior_fn():
-    prior = tfd.JointDistributionNamed(dict(
-        theta=tfd.Normal(jnp.full(2, 0.0), 1.0)
-    ), batch_ndims=0)
-    return prior
-
-
-def _map_fun_inv(theta, y):
-    ang = jnp.array([-jnp.pi / 4.0])
-    c = jnp.cos(ang)
-    s = jnp.sin(ang)
-    z0 = (c * theta[:, 0] - s * theta[:, 1]).reshape(-1, 1)
-    z1 = (s * theta[:, 0] + c * theta[:, 1]).reshape(-1, 1)
-    return y - jnp.concatenate([-jnp.abs(z0), z1], axis=1)
-
-
-def likelihood_fn(y, theta):
-    theta = theta.reshape(1, 2)
-    p = _map_fun_inv(theta, y).reshape(1, -1)
-
-    u = p[:, 0] - 0.25
-    v = p[:, 1]
-
-    r = jnp.sqrt(u ** 2 + v ** 2)
-    ll = -0.5 * ((r - 0.1) / 0.01) ** 2 - 0.5 * jnp.log(2 * jnp.array([jnp.pi]) * 0.01 ** 2)
-    ll = jnp.where(
-        u < 0.0,
-        jnp.array(-jnp.inf),
-        ll
-    )
-    return ll
-
-
-def log_density_fn(theta, y):
-    prior_lp = prior_fn().log_prob(theta)
-    likelihood_lp = likelihood_fn(y, theta)
-    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
-    #jax.debug.print("🤯 {x} 🤯", x=lp)
-    return lp
-
-
-def run():
-    y_observed = jnp.array([-0.6396706, 0.16234657])
-
-    log_density_partial = partial(log_density_fn, y=y_observed)
-    log_density = lambda x: jax.vmap(log_density_partial)(x)
-    samples = sample_with_slice(
-        jr.PRNGKey(0),
-        log_density,
-        prior_fn().sample, n_chains=4, n_samples=10000, n_warmup=5000
-    )
-    samples = np.array(samples.reshape(-1, 2))
-    plt.scatter(samples[:, 0], samples[:, 1])
-    plt.show()
-
-
-if __name__ == "__main__":
-    run()
diff --git a/mypy.ini b/mypy.ini
deleted file mode 100644
index 5ce758b..0000000
--- a/mypy.ini
+++ /dev/null
@@ -1,2 +0,0 @@
-[mypy]
-exclude = manuscript+supplement
diff --git a/pyproject.toml b/pyproject.toml
index 9dbc761..7219c23 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,6 @@ requires-python = ">=3.9"
 dependencies = [
     "arviz>=0.17.1",
     "blackjax-nightly>=1.0.0.post17",
-    "distrax>=0.1.2",
     "dm-haiku>=0.0.9",
     "matplotlib>=3.6.2",
     "optax>=0.1.3",
@@ -67,10 +66,6 @@ test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'
 [tool.bandit]
 skips = ["B101"]
 
-[tool.mypy]
-exclude = ["manuscript+supplement"]
-
-
 [tool.ruff]
 line-length = 80
 exclude = ["*_test.py", "docs/**", "examples/**", "manuscript+supplement/**"]
diff --git a/sbijax/__init__.py b/sbijax/__init__.py
index 6ee5291..97c0352 100644
--- a/sbijax/__init__.py
+++ b/sbijax/__init__.py
@@ -1,6 +1,6 @@
 """sbijax: Simulation-based inference in JAX."""
 
-__version__ = "0.3.0"
+__version__ = "0.3.1"
 
 import os
 
diff --git a/sbijax/_src/nn/make_continuous_flow.py b/sbijax/_src/nn/make_continuous_flow.py
index e16d936..a49ae9e 100644
--- a/sbijax/_src/nn/make_continuous_flow.py
+++ b/sbijax/_src/nn/make_continuous_flow.py
@@ -150,8 +150,8 @@ def make_cnf(
     n_dimension: int,
     n_layers: int = 2,
     hidden_size: int = 64,
-    activation: Callable = jax.nn.tanh,
-    dropout_rate: float = 0.2,
+    activation: Callable = jax.nn.relu,
+    dropout_rate: float = 0.1,
     do_batch_norm: bool = False,
     batch_norm_decay: float = 0.2,
 ):
diff --git a/sbijax/_src/nn/make_resnet.py b/sbijax/_src/nn/make_resnet.py
index 3a3a043..27f9844 100644
--- a/sbijax/_src/nn/make_resnet.py
+++ b/sbijax/_src/nn/make_resnet.py
@@ -53,7 +53,7 @@ def __init__(
         n_layers: int,
         hidden_size: int,
         activation: Callable = jax.nn.relu,
-        dropout_rate: float = 0.2,
+        dropout_rate: float = 0.1,
         do_batch_norm: bool = True,
         batch_norm_decay: float = 0.1,
     ):
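
For reference, every refreshed example in this patch follows the same structure: define a prior_fn and a simulator_fn, pair them with a network from sbijax.nn, simulate training data, fit, and sample the posterior. The following is a minimal sketch of that pattern, condensed from the mixture-model examples above; the simplified Gaussian simulator and the reduced n_simulations and n_iter values are illustrative assumptions only, not part of the patch.

"""Minimal sketch of the shared example pattern used throughout this patch."""
import matplotlib.pyplot as plt
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

from sbijax import NPE, plot_posterior
from sbijax.nn import make_maf


def prior_fn():
    # Standard-normal prior over a two-dimensional parameter,
    # as in the mixture-model examples.
    return tfd.JointDistributionNamed(
        dict(theta=tfd.Normal(jnp.zeros(2), 1.0)), batch_ndims=0
    )


def simulator_fn(seed, theta):
    # Illustrative stand-in simulator: Gaussian noise around theta.
    return tfd.Normal(theta["theta"], 0.1).sample(seed=seed)


def run():
    y_observed = jnp.array([-2.0, 1.0])
    model = NPE((prior_fn, simulator_fn), make_maf(2))

    # Simulate training data, fit the density estimator, then sample the posterior.
    data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=1_000)
    params, _ = model.fit(jr.PRNGKey(2), data=data, n_iter=100)
    inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)

    plot_posterior(inference_result)
    plt.show()


if __name__ == "__main__":
    run()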